From 5d7637becc89c402d3d5bc0324d5c5b366e76ea2 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 3 Jun 2023 20:54:28 -0400 Subject: [PATCH] server: add `IO_BACKEND` abstraction; implement a unix version of it this doesn't change the `create`+`run` version, but makes it possible to create a server that doesn't use unix IOs. --- src/Tiny_httpd_buf.ml | 10 ++ src/Tiny_httpd_buf.mli | 8 + src/Tiny_httpd_io.ml | 26 ++- src/Tiny_httpd_server.ml | 368 ++++++++++++++++++++++++++------------ src/Tiny_httpd_server.mli | 41 ++++- 5 files changed, 332 insertions(+), 121 deletions(-) diff --git a/src/Tiny_httpd_buf.ml b/src/Tiny_httpd_buf.ml index 2c706180..e3e2faa2 100644 --- a/src/Tiny_httpd_buf.ml +++ b/src/Tiny_httpd_buf.ml @@ -20,6 +20,16 @@ let add_bytes (self : t) s i len : unit = Bytes.blit s i self.bytes self.i len; self.i <- self.i + len +let[@inline] add_string self str : unit = + add_bytes self (Bytes.unsafe_of_string str) 0 (String.length str) + +let add_buffer (self : t) (buf : Buffer.t) : unit = + let len = Buffer.length buf in + if self.i + len >= Bytes.length self.bytes then + resize self (self.i + (self.i / 2) + len + 10); + Buffer.blit buf 0 self.bytes self.i len; + self.i <- self.i + len + let contents (self : t) : string = Bytes.sub_string self.bytes 0 self.i let contents_and_clear (self : t) : string = diff --git a/src/Tiny_httpd_buf.mli b/src/Tiny_httpd_buf.mli index b500ccaf..2bcfe58b 100644 --- a/src/Tiny_httpd_buf.mli +++ b/src/Tiny_httpd_buf.mli @@ -24,3 +24,11 @@ val contents_and_clear : t -> string val add_bytes : t -> bytes -> int -> int -> unit (** Append given bytes slice to the buffer. @since 0.5 *) + +val add_string : t -> string -> unit +(** Add string. + @since NEXT_RELEASE *) + +val add_buffer : t -> Buffer.t -> unit +(** Append bytes from buffer. + @since NEXT_RELEASE *) diff --git a/src/Tiny_httpd_io.ml b/src/Tiny_httpd_io.ml index 8d36a2a4..749f53d9 100644 --- a/src/Tiny_httpd_io.ml +++ b/src/Tiny_httpd_io.ml @@ -8,6 +8,8 @@ @since NEXT_RELEASE *) +module Buf = Tiny_httpd_buf + module In_channel = struct type t = { input: bytes -> int -> int -> int; @@ -67,6 +69,11 @@ module Out_channel = struct self.output (Bytes.unsafe_of_string str) 0 (String.length str) let[@inline] close self : unit = self.close () + let[@inline] flush self : unit = self.flush () + + let output_buf (self : t) (buf : Buf.t) : unit = + let b = Buf.bytes_slice buf in + output self b 0 (Buf.size buf) end (** A TCP server abstraction *) @@ -77,9 +84,20 @@ module TCP_server = struct } type t = { - listen: handle:conn_handler -> unit -> unit; - (** Blocking call to start listening for incoming connections. - Uses the connection handler to handle individual client connections. *) - endpoint: unit -> Unix.inet_addr * int; (** Endpoint we listen on *) + endpoint: unit -> string * int; + (** Endpoint we listen on. This can only be called from within [serve]. *) + active_connections: unit -> int; + (** Number of connections currently active *) + running: unit -> bool; (** Is the server currently running? *) + stop: unit -> unit; + (** Ask the server to stop. This might not take effect immediately. *) } + (** Running server. *) + + type builder = { + serve: after_init:(t -> unit) -> handle:conn_handler -> unit -> unit; + (** Blocking call to listen for incoming connections and handle them. + Uses the connection handler to handle individual client connections. *) + } + (** A TCP server implementation. *) end diff --git a/src/Tiny_httpd_server.ml b/src/Tiny_httpd_server.ml index 9378faae..24086a08 100644 --- a/src/Tiny_httpd_server.ml +++ b/src/Tiny_httpd_server.ml @@ -18,6 +18,7 @@ let _debug k = module Buf = Tiny_httpd_buf module Byte_stream = Tiny_httpd_stream +module IO = Tiny_httpd_io exception Bad_req of int * string @@ -423,9 +424,19 @@ module Response = struct Format.fprintf out "{@[code=%d;@ headers=[@[%a@]];@ body=%a@]}" self.code Headers.pp self.headers pp_body self.body - let output_ (oc : out_channel) (self : t) : unit = - Printf.fprintf oc "HTTP/1.1 %d %s\r\n" self.code + let output_ ?(buf = Buf.create ~size:256 ()) (oc : IO.Out_channel.t) + (self : t) : unit = + (* double indirection: + - print into [buffer] using [bprintf] + - transfer to [buf_] so we can output from there *) + let tmp_buffer = Buffer.create 32 in + + (* write start of reply *) + Printf.bprintf tmp_buffer "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); + Buf.add_buffer buf tmp_buffer; + Buffer.clear tmp_buffer; + let body, is_chunked = match self.body with | `String s when String.length s > 1024 * 500 -> @@ -447,19 +458,29 @@ module Response = struct _debug (fun k -> k "output response: %s" (Format.asprintf "%a" pp { self with body = `String "<…>" })); - List.iter (fun (k, v) -> Printf.fprintf oc "%s: %s\r\n" k v) headers; - output_string oc "\r\n"; + + (* write headers *) + List.iter + (fun (k, v) -> + Printf.bprintf tmp_buffer "%s: %s\r\n" k v; + Buf.add_buffer buf tmp_buffer; + Buffer.clear tmp_buffer) + headers; + + IO.Out_channel.output_buf oc buf; + IO.Out_channel.output_string oc "\r\n"; + (match body with | `String "" | `Void -> () - | `String s -> output_string oc s + | `String s -> IO.Out_channel.output_string oc s | `Stream str -> (try - Byte_stream.output_chunked oc str; + Byte_stream.output_chunked' oc str; Byte_stream.close str with e -> Byte_stream.close str; raise e)); - flush oc + IO.Out_channel.flush oc end (* semaphore, for limiting concurrency. *) @@ -593,7 +614,7 @@ module Middleware = struct end (* a request handler. handles a single request. *) -type cb_path_handler = out_channel -> Middleware.handler +type cb_path_handler = IO.Out_channel.t -> Middleware.handler module type SERVER_SENT_GENERATOR = sig val set_headers : Headers.t -> unit @@ -606,18 +627,26 @@ end type server_sent_generator = (module SERVER_SENT_GENERATOR) +module type IO_BACKEND = sig + val init_addr : unit -> string + val init_port : unit -> int + + val spawn : (unit -> unit) -> unit + (** function used to spawn a new thread to handle a + new client connection. By default it is {!Thread.create} but one + could use a thread pool instead.*) + + val get_time_s : unit -> float + (** obtain the current timestamp in seconds. *) + + val tcp_server : unit -> Tiny_httpd_io.TCP_server.builder + (** Server that can listen on a port and handle clients. *) +end + type t = { - addr: string; (** Address at creation *) - port: int; (** Port at creation *) - mutable sock: Unix.file_descr option; (** Socket *) - timeout: float; - sem_max_connections: Sem_.t; - (* semaphore to restrict the number of active concurrent connections *) - new_thread: (unit -> unit) -> unit; - (* a function to run the given callback in a separate thread (or thread pool) *) - masksigpipe: bool; + backend: (module IO_BACKEND); + mutable tcp_server: IO.TCP_server.t option; buf_size: int; - get_time_s: unit -> float; mutable handler: string Request.t -> Response.t; (** toplevel handler, if any *) mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) @@ -626,9 +655,6 @@ type t = { mutable path_handlers: (unit Request.t -> cb_path_handler resp_result option) list; (** path handlers *) - mutable running: bool; - (** true while the server is running. no need to protect with a mutex, - writes should be atomic enough. *) } let get_addr_ sock = @@ -636,17 +662,24 @@ let get_addr_ sock = | Unix.ADDR_INET (addr, port) -> addr, port | _ -> invalid_arg "httpd: address is not INET" -let addr self = - match self.sock with - | None -> self.addr - | Some s -> Unix.string_of_inet_addr (fst @@ get_addr_ s) +let addr (self : t) = + match self.tcp_server with + | None -> + let (module B) = self.backend in + B.init_addr () + | Some s -> fst @@ s.endpoint () -let port self = - match self.sock with - | None -> self.port - | Some sock -> snd @@ get_addr_ sock +let port (self : t) = + match self.tcp_server with + | None -> + let (module B) = self.backend in + B.init_port () + | Some s -> snd @@ s.endpoint () -let active_connections self = Sem_.num_acquired self.sem_max_connections - 1 +let active_connections (self : t) = + match self.tcp_server with + | None -> 0 + | Some s -> s.active_connections () let add_middleware ~stage self m = let stage = @@ -726,8 +759,10 @@ let[@inline] _opt_iter ~f o = | None -> () | Some x -> f x +exception Exit_SSE + let add_route_server_sent_handler ?accept self route f = - let tr_req oc req ~resp f = + let tr_req (oc : IO.Out_channel.t) req ~resp f = let req = Request.read_body_full ~buf_size:self.buf_size req in let headers = ref Headers.(empty |> set "content-type" "text/event-stream") @@ -746,16 +781,20 @@ let add_route_server_sent_handler ?accept self route f = ) in + let[@inline] writef fmt = + Printf.ksprintf (IO.Out_channel.output_string oc) fmt + in + let send_event ?event ?id ?retry ~data () : unit = send_response_idempotent_ (); - _opt_iter event ~f:(fun e -> Printf.fprintf oc "event: %s\n" e); - _opt_iter id ~f:(fun e -> Printf.fprintf oc "id: %s\n" e); - _opt_iter retry ~f:(fun e -> Printf.fprintf oc "retry: %s\n" e); + _opt_iter event ~f:(fun e -> writef "event: %s\n" e); + _opt_iter id ~f:(fun e -> writef "id: %s\n" e); + _opt_iter retry ~f:(fun e -> writef "retry: %s\n" e); let l = String.split_on_char '\n' data in - List.iter (fun s -> Printf.fprintf oc "data: %s\n" s) l; - output_string oc "\n"; + List.iter (fun s -> writef "data: %s\n" s) l; + IO.Out_channel.output_string oc "\n"; (* finish group *) - flush oc + IO.Out_channel.flush oc in let module SSG = struct let set_headers h = @@ -765,32 +804,26 @@ let add_route_server_sent_handler ?accept self route f = ) let send_event = send_event - let close () = raise Exit + let close () = raise Exit_SSE end in - try f req (module SSG : SERVER_SENT_GENERATOR) with Exit -> close_out oc + try f req (module SSG : SERVER_SENT_GENERATOR) + with Exit_SSE -> IO.Out_channel.close oc in add_route_handler_ self ?accept ~meth:`GET route ~tr_req f -let create ?(masksigpipe = true) ?(max_connections = 32) ?(timeout = 0.0) - ?(buf_size = 16 * 1_024) ?(get_time_s = Unix.gettimeofday) - ?(new_thread = fun f -> ignore (Thread.create f () : Thread.t)) - ?(addr = "127.0.0.1") ?(port = 8080) ?sock ?(middlewares = []) () : t = - let handler _req = Response.fail ~code:404 "no top handler" in +let get_max_connection_ ?(max_connections = 64) () : int = let max_connections = max 4 max_connections in + max_connections + +let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = + let handler _req = Response.fail ~code:404 "no top handler" in let self = { - new_thread; - addr; - port; - sock; - masksigpipe; + backend; + tcp_server = None; handler; buf_size; - running = true; - sem_max_connections = Sem_.create max_connections; path_handlers = []; - timeout; - get_time_s; middlewares = []; middlewares_sorted = lazy []; } @@ -798,7 +831,149 @@ let create ?(masksigpipe = true) ?(max_connections = 32) ?(timeout = 0.0) List.iter (fun (stage, m) -> add_middleware self ~stage m) middlewares; self -let stop s = s.running <- false +let is_ipv6_str addr : bool = String.contains addr ':' + +module Unix_tcp_server_ = struct + type t = { + addr: string; + port: int; + max_connections: int; + sem_max_connections: Sem_.t; + (** semaphore to restrict the number of active concurrent connections *) + mutable sock: Unix.file_descr option; (** Socket *) + new_thread: (unit -> unit) -> unit; + timeout: float; + masksigpipe: bool; + mutable running: bool; (* TODO: use an atomic? *) + } + + let to_tcp_server (self : t) : IO.TCP_server.builder = + { + IO.TCP_server.serve = + (fun ~after_init ~handle () : unit -> + if self.masksigpipe then + ignore (Unix.sigprocmask Unix.SIG_BLOCK [ Sys.sigpipe ] : _ list); + let sock, should_bind = + match self.sock with + | Some s -> + ( s, + false + (* Because we're getting a socket from the caller (e.g. systemd) *) + ) + | None -> + ( Unix.socket + (if is_ipv6_str self.addr then + Unix.PF_INET6 + else + Unix.PF_INET) + Unix.SOCK_STREAM 0, + true (* Because we're creating the socket ourselves *) ) + in + Unix.clear_nonblock sock; + Unix.setsockopt_optint sock Unix.SO_LINGER None; + if should_bind then ( + let inet_addr = Unix.inet_addr_of_string self.addr in + Unix.setsockopt sock Unix.SO_REUSEADDR true; + Unix.bind sock (Unix.ADDR_INET (inet_addr, self.port)); + let n_listen = 2 * self.max_connections in + Unix.listen sock n_listen + ); + + self.sock <- Some sock; + + let tcp_server = + { + IO.TCP_server.stop = (fun () -> self.running <- false); + running = (fun () -> self.running); + active_connections = + (fun () -> Sem_.num_acquired self.sem_max_connections - 1); + endpoint = + (fun () -> + let addr, port = get_addr_ sock in + Unix.string_of_inet_addr addr, port); + } + in + after_init tcp_server; + + (* how to handle a single client *) + let handle_client_unix_ (client_sock : Unix.file_descr) : unit = + Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); + Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); + let oc = + IO.Out_channel.of_out_channel + @@ Unix.out_channel_of_descr client_sock + in + let ic = IO.In_channel.of_unix_fd client_sock in + handle.handle ic oc; + _debug (fun k -> k "done with client, exiting"); + (try Unix.close client_sock + with e -> + _debug (fun k -> + k "error when closing sock: %s" (Printexc.to_string e))); + () + in + + while self.running do + (* limit concurrency *) + Sem_.acquire 1 self.sem_max_connections; + try + let client_sock, _ = Unix.accept sock in + self.new_thread (fun () -> + try + handle_client_unix_ client_sock; + Sem_.release 1 self.sem_max_connections + with e -> + (try Unix.close client_sock with _ -> ()); + Sem_.release 1 self.sem_max_connections; + raise e) + with e -> + Sem_.release 1 self.sem_max_connections; + _debug (fun k -> + k "Unix.accept or Thread.create raised an exception: %s" + (Printexc.to_string e)) + done; + ()); + } +end + +let create ?(masksigpipe = true) ?max_connections ?(timeout = 0.0) ?buf_size + ?(get_time_s = Unix.gettimeofday) + ?(new_thread = fun f -> ignore (Thread.create f () : Thread.t)) + ?(addr = "127.0.0.1") ?(port = 8080) ?sock ?middlewares () : t = + let max_connections = get_max_connection_ ?max_connections () in + let server = + { + Unix_tcp_server_.addr; + new_thread; + running = true; + port; + sock; + max_connections; + sem_max_connections = Sem_.create max_connections; + masksigpipe; + timeout; + } + in + let tcp_server_builder = Unix_tcp_server_.to_tcp_server server in + let module B = struct + let init_addr () = addr + let init_port () = port + let get_time_s = get_time_s + let spawn f = new_thread f + let tcp_server () = tcp_server_builder + end in + let backend = (module B : IO_BACKEND) in + create_from ?buf_size ?middlewares ~backend () + +let stop (self : t) = + match self.tcp_server with + | None -> () + | Some s -> s.stop () + +let running (self : t) = + match self.tcp_server with + | None -> false + | Some s -> s.running () let find_map f l = let rec aux f = function @@ -810,16 +985,15 @@ let find_map f l = in aux f l -let handle_client_ (self : t) (client_sock : Unix.file_descr) : unit = - Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); - Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); - let oc = Unix.out_channel_of_descr client_sock in +(* handle client on [ic] and [oc] *) +let client_handle_for (self : t) ic oc : unit = let buf = Buf.create ~size:self.buf_size () in - let is = Byte_stream.of_fd ~buf_size:self.buf_size client_sock in + let is = Byte_stream.of_input ~buf_size:self.buf_size ic in let continue = ref true in - while !continue && self.running do + while !continue && running self do _debug (fun k -> k "read next request"); - match Request.parse_req_start ~get_time_s:self.get_time_s ~buf is with + let (module B) = self.backend in + match Request.parse_req_start ~get_time_s:B.get_time_s ~buf is with | Ok None -> continue := false (* client is done *) | Error (c, s) -> (* connection error, close *) @@ -833,7 +1007,7 @@ let handle_client_ (self : t) (client_sock : Unix.file_descr) : unit = (try (* is there a handler for this path? *) - let handler = + let base_handler = match find_map (fun ph -> ph req) self.path_handlers with | Some f -> unwrap_resp_result f | None -> @@ -857,7 +1031,7 @@ let handle_client_ (self : t) (client_sock : Unix.file_descr) : unit = List.fold_right (fun (_, m) h -> m h) (Lazy.force self.middlewares_sorted) - (handler oc) + (base_handler oc) in (* now actually read request's body into a stream *) @@ -889,62 +1063,24 @@ let handle_client_ (self : t) (client_sock : Unix.file_descr) : unit = continue := false; Response.output_ oc @@ Response.fail ~code:500 "server error: %s" (Printexc.to_string e)) - done; - _debug (fun k -> k "done with client, exiting"); - (try Unix.close client_sock - with e -> - _debug (fun k -> k "error when closing sock: %s" (Printexc.to_string e))); - () + done -let is_ipv6 self = String.contains self.addr ':' +let client_handler (self : t) : IO.TCP_server.conn_handler = + { IO.TCP_server.handle = client_handle_for self } +let is_ipv6 (self : t) = + let (module B) = self.backend in + is_ipv6_str (B.init_addr ()) + +(* TODO: init TCP server *) let run ?(after_init = ignore) (self : t) : (unit, _) result = try - if self.masksigpipe then - ignore (Unix.sigprocmask Unix.SIG_BLOCK [ Sys.sigpipe ] : _ list); - let sock, should_bind = - match self.sock with - | Some s -> - ( s, - false - (* Because we're getting a socket from the caller (e.g. systemd) *) ) - | None -> - ( Unix.socket - (if is_ipv6 self then - Unix.PF_INET6 - else - Unix.PF_INET) - Unix.SOCK_STREAM 0, - true (* Because we're creating the socket ourselves *) ) - in - Unix.clear_nonblock sock; - Unix.setsockopt_optint sock Unix.SO_LINGER None; - if should_bind then ( - let inet_addr = Unix.inet_addr_of_string self.addr in - Unix.setsockopt sock Unix.SO_REUSEADDR true; - Unix.bind sock (Unix.ADDR_INET (inet_addr, self.port)); - Unix.listen sock (2 * self.sem_max_connections.Sem_.n) - ); - self.sock <- Some sock; - after_init (); - while self.running do - (* limit concurrency *) - Sem_.acquire 1 self.sem_max_connections; - try - let client_sock, _ = Unix.accept sock in - self.new_thread (fun () -> - try - handle_client_ self client_sock; - Sem_.release 1 self.sem_max_connections - with e -> - (try Unix.close client_sock with _ -> ()); - Sem_.release 1 self.sem_max_connections; - raise e) - with e -> - Sem_.release 1 self.sem_max_connections; - _debug (fun k -> - k "Unix.accept or Thread.create raised an exception: %s" - (Printexc.to_string e)) - done; + let (module B) = self.backend in + let server = B.tcp_server () in + server.serve + ~after_init:(fun tcp_server -> + self.tcp_server <- Some tcp_server; + after_init ()) + ~handle:(client_handler self) (); Ok () with e -> Error e diff --git a/src/Tiny_httpd_server.mli b/src/Tiny_httpd_server.mli index 3ff0eac7..88011f56 100644 --- a/src/Tiny_httpd_server.mli +++ b/src/Tiny_httpd_server.mli @@ -369,7 +369,7 @@ val create : ?middlewares:([ `Encoding | `Stage of int ] * Middleware.t) list -> unit -> t -(** Create a new webserver. +(** Create a new webserver using UNIX abstractions. The server will not do anything until {!run} is called on it. Before starting the server, one can use {!add_path_handler} and @@ -401,6 +401,41 @@ val create : This parameter exists since 0.11. *) +(** A backend that provides IO operations, network operations, etc. *) +module type IO_BACKEND = sig + val init_addr : unit -> string + val init_port : unit -> int + + val spawn : (unit -> unit) -> unit + (** function used to spawn a new thread to handle a + new client connection. By default it is {!Thread.create} but one + could use a thread pool instead.*) + + val get_time_s : unit -> float + (** obtain the current timestamp in seconds. *) + + val tcp_server : unit -> Tiny_httpd_io.TCP_server.builder + (** Server that can listen on a port and handle clients. *) +end + +val create_from : + ?buf_size:int -> + ?middlewares:([ `Encoding | `Stage of int ] * Middleware.t) list -> + backend:(module IO_BACKEND) -> + unit -> + t +(** Create a new webserver using provided backend. + + The server will not do anything until {!run} is called on it. + Before starting the server, one can use {!add_path_handler} and + {!set_top_handler} to specify how to handle incoming requests. + + @param buf_size size for buffers (since 0.11) + @param middlewares see {!add_middleware} for more details. + + @since NEXT_RELEASE +*) + val addr : t -> string (** Address on which the server listens. *) @@ -556,6 +591,10 @@ val add_route_server_sent_handler : (** {2 Run the server} *) +val running : t -> bool +(** Is the server running? + @since NEXT_RELEASE *) + val stop : t -> unit (** Ask the server to stop. This might not have an immediate effect as {!run} might currently be waiting on IO. *)