diff --git a/src/lwt/dune b/src/lwt/dune new file mode 100644 index 00000000..76e9b6ae --- /dev/null +++ b/src/lwt/dune @@ -0,0 +1,6 @@ + +(library + (name tiny_httpd_lwt) + (public_name tiny_httpd_lwt) + (enabled_if (>= %{ocaml_version} 5.0)) + (libraries tiny_httpd lwt lwt.unix)) diff --git a/src/lwt/task.ml b/src/lwt/task.ml new file mode 100644 index 00000000..d1e615f8 --- /dev/null +++ b/src/lwt/task.ml @@ -0,0 +1,74 @@ +module ED = Effect.Deep + +type _ Effect.t += Await : 'a Lwt.t -> 'a Effect.t + +(** Queue of microtasks that are ready *) +let tasks : (unit -> unit) Queue.t = Queue.create () + +let[@inline] push_task f : unit = Queue.push f tasks + +let on_uncaught_exn : (exn -> Printexc.raw_backtrace -> unit) ref = + ref (fun exn bt -> + Printf.eprintf "lwt_task: uncaught task exception:\n%s\n%s\n%!" + (Printexc.to_string exn) + (Printexc.raw_backtrace_to_string bt)) + +let run_all_tasks () : unit = + (* use local queue to prevent the hook from running forever in case + tasks keep scheduling new tasks. *) + let local = Queue.create () in + Queue.transfer tasks local; + while not (Queue.is_empty local) do + let t = Queue.pop local in + try t () + with exn -> + let bt = Printexc.get_raw_backtrace () in + !on_uncaught_exn exn bt + done; + (* make sure we don't sleep forever if there's no lwt promise + ready but [tasks] contains ready tasks *) + if not (Queue.is_empty tasks) then ignore (Lwt.pause () : unit Lwt.t) + +let () = + let _hook1 = Lwt_main.Enter_iter_hooks.add_first run_all_tasks in + let _hook2 = Lwt_main.Leave_iter_hooks.add_first run_all_tasks in + () + +let await (fut : 'a Lwt.t) : 'a = + match Lwt.state fut with + | Lwt.Return x -> x + | Lwt.Fail exn -> raise exn + | Lwt.Sleep -> Effect.perform (Await fut) + +(** the main effect handler *) +let handler : _ ED.effect_handler = + let effc : type b. b Effect.t -> ((b, unit) ED.continuation -> 'a) option = + function + | Await fut -> + Some + (fun k -> + Lwt.on_any fut + (fun res -> push_task (fun () -> ED.continue k res)) + (fun exn -> push_task (fun () -> ED.discontinue k exn))) + | _ -> None + in + + { effc } + +let run_inside_effect_handler_ (type a) (promise : a Lwt.u) f () : unit = + let res = ref (Error (Failure "not resolved")) in + let run_f_and_set_res () = + (try + let r = f () in + res := Ok r + with exn -> res := Error exn); + Lwt.wakeup_later_result promise !res + in + ED.try_with run_f_and_set_res () handler + +let run f : _ Lwt.t = + let lwt, resolve = Lwt.wait () in + push_task (run_inside_effect_handler_ resolve f); + lwt + +let run_async f : unit = ignore (run f : unit Lwt.t) diff --git a/src/lwt/task.mli b/src/lwt/task.mli new file mode 100644 index 00000000..7b326dc0 --- /dev/null +++ b/src/lwt/task.mli @@ -0,0 +1,9 @@ +(** Direct style tasks for Lwt *) + +val run : (unit -> 'a) -> 'a Lwt.t +(** Run a microtask *) + +val run_async : (unit -> unit) -> unit + +val await : 'a Lwt.t -> 'a +(** Can only be used inside {!run} *) diff --git a/src/lwt/tiny_httpd_lwt.ml b/src/lwt/tiny_httpd_lwt.ml new file mode 100644 index 00000000..909c177c --- /dev/null +++ b/src/lwt/tiny_httpd_lwt.ml @@ -0,0 +1,175 @@ +module IO = Tiny_httpd.IO +module H = Tiny_httpd.Server +module Pool = Tiny_httpd.Pool +module Slice = IO.Slice +module Log = Tiny_httpd.Log + +let spf = Printf.sprintf +let ( let@ ) = ( @@ ) + +type 'a with_args = + ?addr:string -> + ?port:int -> + ?unix_sock:string -> + ?max_connections:int -> + ?max_buf_pool_size:int -> + ?buf_size:int -> + 'a + +let get_max_connection_ ?(max_connections = 64) () : int = + let max_connections = max 4 max_connections in + max_connections + +let buf_size = 16 * 1024 + +let show_sockaddr = function + | Unix.ADDR_UNIX s -> s + | Unix.ADDR_INET (addr, port) -> + spf "%s:%d" (Unix.string_of_inet_addr addr) port + +let ic_of_channel (ic : Lwt_io.input_channel) : IO.Input.t = + object + inherit Iostream.In_buf.t_from_refill () + + method private refill (sl : Slice.t) = + assert (sl.len = 0); + let n = + Lwt_io.read_into ic sl.bytes 0 (Bytes.length sl.bytes) |> Task.await + in + sl.len <- n + + method close () = Lwt_io.close ic |> Task.await + end + +let oc_of_channel (oc : Lwt_io.output_channel) : IO.Output.t = + object + method flush () : unit = Lwt_io.flush oc |> Task.await + + method output buf i len = + Lwt_io.write_from_exactly oc buf i len |> Task.await + + method output_char c = Lwt_io.write_char oc c |> Task.await + method close () = Lwt_io.close oc |> Task.await + end + +let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size + ?(buf_size = buf_size) () : (module H.IO_BACKEND) = + let buf_pool = + Pool.create ?max_size:max_buf_pool_size + ~mk_item:(fun () -> Lwt_bytes.create buf_size) + () + in + + let addr, port, (sockaddr : Unix.sockaddr) = + match addr, port, unix_sock with + | _, _, Some s -> Printf.sprintf "unix:%s" s, 0, Unix.ADDR_UNIX s + | addr, port, None -> + let addr = Option.value ~default:"127.0.0.1" addr in + let sockaddr, port = + match Lwt_unix.getaddrinfo addr "" [] |> Task.await, port with + | { Unix.ai_addr = ADDR_INET (h, _); _ } :: _, None -> + let p = 8080 in + Unix.ADDR_INET (h, p), p + | { Unix.ai_addr = ADDR_INET (h, _); _ } :: _, Some p -> + Unix.ADDR_INET (h, p), p + | _ -> + failwith @@ Printf.sprintf "Could not parse TCP address from %S" addr + in + addr, port, sockaddr + in + + let module M = struct + let init_addr () = addr + let init_port () = port + let get_time_s () = Unix.gettimeofday () + let max_connections = get_max_connection_ ?max_connections () + + let pool_size = + match max_buf_pool_size with + | Some n -> n + | None -> min 4096 (max_connections * 2) + + let tcp_server () : IO.TCP_server.builder = + { + IO.TCP_server.serve = + (fun ~after_init ~handle () : unit -> + let running = Atomic.make true in + let active_conns = Atomic.make 0 in + + (* Eio.Switch.on_release sw (fun () -> Atomic.set running false); *) + let port = ref port in + + let server_loop : unit Lwt.t = + let@ () = Task.run in + let backlog = max_connections in + let sock = + Lwt_unix.socket ~cloexec:true + (Unix.domain_of_sockaddr sockaddr) + Unix.SOCK_STREAM 0 + in + Lwt_unix.bind sock sockaddr |> Task.await; + Lwt_unix.listen sock backlog; + + (* recover real port, if any *) + (match Unix.getsockname (Lwt_unix.unix_file_descr sock) with + | Unix.ADDR_INET (_, p) -> port := p + | _ -> ()); + + let handle_client client_addr fd : unit = + Atomic.incr active_conns; + let@ () = Task.run_async in + let@ () = + Fun.protect ~finally:(fun () -> + Log.debug (fun k -> + k "Tiny_httpd_lwt: client handler returned"); + Atomic.decr active_conns) + in + + let@ buf_ic = Pool.with_resource buf_pool in + let@ buf_oc = Pool.with_resource buf_pool in + let ic = + ic_of_channel @@ Lwt_io.of_fd ~mode:Input ~buffer:buf_ic fd + in + let oc = + oc_of_channel @@ Lwt_io.of_fd ~mode:Output ~buffer:buf_ic fd + in + try handle.handle ~client_addr ic oc + with exn -> + let bt = Printexc.get_raw_backtrace () in + Log.error (fun k -> + k "Client handler for %s failed with %s\n%s" + (show_sockaddr client_addr) + (Printexc.to_string exn) + (Printexc.raw_backtrace_to_string bt)) + in + + while Atomic.get running do + let fd, addr = Lwt_unix.accept sock |> Task.await in + handle_client addr fd + done + in + + let tcp_server : IO.TCP_server.t = + { + running = (fun () -> Atomic.get running); + stop = + (fun () -> + Atomic.set running false; + Task.await server_loop); + endpoint = (fun () -> addr, !port); + active_connections = (fun () -> Atomic.get active_conns); + } + in + + after_init tcp_server); + } + end in + (module M) + +let create ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size ?buf_size + ?middlewares () : H.t = + let backend = + io_backend ?addr ?port ?unix_sock ?max_buf_pool_size ?max_connections + ?buf_size () + in + H.create_from ?buf_size ?middlewares ~backend () diff --git a/src/lwt/tiny_httpd_lwt.mli b/src/lwt/tiny_httpd_lwt.mli new file mode 100644 index 00000000..56bd6868 --- /dev/null +++ b/src/lwt/tiny_httpd_lwt.mli @@ -0,0 +1,26 @@ +(** Lwt backend for Tiny_httpd. + + This only works on OCaml 5 because it uses effect handlers to use Lwt in + direct style. + + {b NOTE}: this is very experimental and will absolutely change over time, + @since NEXT_RELEASE *) + +type 'a with_args = + ?addr:string -> + ?port:int -> + ?unix_sock:string -> + ?max_connections:int -> + ?max_buf_pool_size:int -> + ?buf_size:int -> + 'a + +val io_backend : (unit -> (module Tiny_httpd.Server.IO_BACKEND)) with_args +(** Create a server *) + +val create : + (?middlewares:([ `Encoding | `Stage of int ] * Tiny_httpd.Middleware.t) list -> + unit -> + Tiny_httpd.Server.t) + with_args +(** Create a server *)