From d3a4dbc5b0a377036bdd0ff59870f26e1bf8b2f4 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 5 Feb 2024 00:27:21 -0500 Subject: [PATCH] feat server: new notion of Upgrade handler this handles `connection: upgrade` endpoints with a generic connection-oriented handler. The main goal is to support websockets. --- src/Tiny_httpd_server.ml | 200 ++++++++++++++++++++++++++++++-------- src/Tiny_httpd_server.mli | 35 +++++++ 2 files changed, 197 insertions(+), 38 deletions(-) diff --git a/src/Tiny_httpd_server.ml b/src/Tiny_httpd_server.ml index ade65cdc..56e9c930 100644 --- a/src/Tiny_httpd_server.ml +++ b/src/Tiny_httpd_server.ml @@ -46,7 +46,8 @@ module Response_code = struct let[@inline] is_success n = n >= 200 && n < 400 end -type 'a resp_result = ('a, Response_code.t * string) result +type resp_error = Response_code.t * string +type 'a resp_result = ('a, resp_error) result let unwrap_resp_result = function | Ok x -> x @@ -633,6 +634,26 @@ end type server_sent_generator = (module SERVER_SENT_GENERATOR) +(** Handler that upgrades to another protocol *) +module type UPGRADE_HANDLER = sig + type handshake_state + (** Some specific state returned after handshake *) + + val name : string + (** Name in the "upgrade" header *) + + val handshake : unit Request.t -> (Headers.t * handshake_state, string) result + (** Perform the handshake and upgrade the connection. The returned + code is [101] alongside these headers. *) + + val handle_connection : handshake_state -> IO.Input.t -> IO.Output.t -> unit + (** Take control of the connection and take it from there *) +end + +type upgrade_handler = (module UPGRADE_HANDLER) + +exception Upgrade of unit Request.t * upgrade_handler + module type IO_BACKEND = sig val init_addr : unit -> string val init_port : unit -> int @@ -644,6 +665,16 @@ module type IO_BACKEND = sig (** Server that can listen on a port and handle clients. *) end +type handler_result = + | Handle of cb_path_handler + | Fail of resp_error + | Upgrade of upgrade_handler + +let unwrap_handler_result req = function + | Handle x -> x + | Fail (c, s) -> raise (Bad_req (c, s)) + | Upgrade up -> raise (Upgrade (req, up)) + type t = { backend: (module IO_BACKEND); mutable tcp_server: IO.TCP_server.t option; @@ -653,8 +684,7 @@ type t = { mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) mutable middlewares_sorted: (int * Middleware.t) list lazy_t; (** sorted version of {!middlewares} *) - mutable path_handlers: - (unit Request.t -> cb_path_handler resp_result option) list; + mutable path_handlers: (unit Request.t -> handler_result option) list; (** path handlers *) buf_pool: Buf.t Pool.t; } @@ -726,7 +756,7 @@ let set_top_handler self f = self.handler <- f and makes it into a handler. *) let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth ~tr_req self (route : _ Route.t) f = - let ph req : cb_path_handler resp_result option = + let ph req : handler_result option = match meth with | Some m when m <> req.Request.meth -> None (* ignore *) | _ -> @@ -736,11 +766,11 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth (match accept req with | Ok () -> Some - (Ok + (Handle (fun oc -> Middleware.apply_l middlewares @@ fun req ~resp -> tr_req oc req ~resp handler)) - | Error _ as e -> Some e) + | Error err -> Some (Fail err)) | None -> None (* path didn't match *)) in self.path_handlers <- ph :: self.path_handlers @@ -821,6 +851,22 @@ let add_route_server_sent_handler ?accept self route f = in add_route_handler_ self ?accept ~meth:`GET route ~tr_req f +let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = []) + (self : t) route f : unit = + let ph req : handler_result option = + if req.Request.meth <> `GET then + None + else ( + match accept req with + | Ok () -> + (match Route.eval req.Request.path_components route f with + | Some up -> Some (Upgrade up) + | None -> None (* path didn't match *)) + | Error err -> Some (Fail err) + ) + in + self.path_handlers <- ph :: self.path_handlers + let get_max_connection_ ?(max_connections = 64) () : int = let max_connections = max 4 max_connections in max_connections @@ -929,7 +975,9 @@ module Unix_tcp_server_ = struct Log.info (fun k -> k "done with client on %s, exiting" @@ str_of_sockaddr client_addr); - (try Unix.close client_sock + (try + Unix.shutdown client_sock Unix.SHUTDOWN_ALL; + Unix.close client_sock with e -> Log.error (fun k -> k "error when closing sock for client %s: %s" @@ -1030,15 +1078,101 @@ let find_map f l = in aux f l +let string_as_list_contains_ (s : string) (sub : string) : bool = + let fragments = String.split_on_char ',' s in + List.exists (fun fragment -> String.trim fragment = sub) fragments + (* handle client on [ic] and [oc] *) let client_handle_for (self : t) ~client_addr ic oc : unit = Pool.with_resource self.buf_pool @@ fun buf -> Pool.with_resource self.buf_pool @@ fun buf_res -> let is = Byte_stream.of_input ~buf_size:self.buf_size ic in + let (module B) = self.backend in + + (* how to log the response to this query *) + let log_response (req : _ Request.t) (resp : Response.t) = + if not Log.dummy then ( + let msgf k = + let elapsed = B.get_time_s () -. req.start_time in + k + ("response to=%s code=%d time=%.3fs path=%S" : _ format4) + (str_of_sockaddr client_addr) + resp.code elapsed req.path + in + if Response_code.is_success resp.code then + Log.info msgf + else + Log.error msgf + ) + in + + (* handle generic exception *) + let handle_exn e = + let resp = + Response.fail ~code:500 "server error: %s" (Printexc.to_string e) + in + if not Log.dummy then + Log.error (fun k -> + k "response to %s code=%d" (str_of_sockaddr client_addr) resp.code); + Response.output_ ~buf:buf_res oc resp + in + + let handle_bad_req req e = + let resp = + Response.fail ~code:500 "server error: %s" (Printexc.to_string e) + in + log_response req resp; + Response.output_ ~buf:buf_res oc resp + in + + let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit = + Log.debug (fun k -> k "upgrade connection"); + try + (* check headers *) + (match Request.get_header req "connection" with + | Some str when string_as_list_contains_ str "Upgrade" -> () + | _ -> bad_reqf 426 "connection header must contain 'Upgrade'"); + (match Request.get_header req "upgrade" with + | Some u when u = UP.name -> () + | Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u + | None -> bad_reqf 426 "expected 'connection: upgrade' header"); + + (* ok, this is the upgrade we expected *) + match UP.handshake req with + | Error msg -> + (* fail the upgrade *) + Log.error (fun k -> k "upgrade failed: %s" msg); + let resp = Response.make_raw ~code:429 "upgrade required" in + log_response req resp; + Response.output_ ~buf:buf_res oc resp + | Ok (headers, handshake_st) -> + (* send the upgrade reply *) + let headers = + [ "connection", "upgrade"; "upgrade", UP.name ] @ headers + in + let resp = Response.make_string ~code:101 ~headers (Ok "") in + log_response req resp; + Response.output_ ~buf:buf_res oc resp; + + (* now, give the whole connection over to the upgraded connection. + Make sure to give the leftovers from [is] as well, if any. + There might not be any because the first message doesn't normally come + directly in the same packet as the handshake, but still. *) + let ic = + if is.len > 0 then ( + Log.debug (fun k -> k "LEFTOVERS! %d B" is.len); + IO.Input.append (IO.Input.of_slice is.bs is.off is.len) ic + ) else + ic + in + + UP.handle_connection handshake_st ic oc + with e -> handle_bad_req req e + in + let continue = ref true in - while !continue && running self do - Log.debug (fun k -> k "read next request"); - let (module B) = self.backend in + + let handle_one_req () = match Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is with @@ -1054,28 +1188,11 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = if Request.close_after_req req then continue := false; - (* how to log the response to this query *) - let log_response (resp : Response.t) = - if not Log.dummy then ( - let msgf k = - let elapsed = B.get_time_s () -. req.start_time in - k - ("response to=%s code=%d time=%.3fs path=%S" : _ format4) - (str_of_sockaddr client_addr) - resp.code elapsed req.path - in - if Response_code.is_success resp.code then - Log.info msgf - else - Log.error msgf - ) - in - (try (* is there a handler for this path? *) let base_handler = match find_map (fun ph -> ph req) self.path_handlers with - | Some f -> unwrap_resp_result f + | Some f -> unwrap_handler_result req f | None -> fun _oc req ~resp -> resp (self.handler req) in @@ -1108,7 +1225,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = try if Headers.get "connection" r.Response.headers = Some "close" then continue := false; - log_response r; + log_response req r; Response.output_ ~buf:buf_res oc r with Sys_error _ -> continue := false in @@ -1123,16 +1240,23 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = | Bad_req (code, s) -> continue := false; let resp = Response.make_raw ~code s in - log_response resp; + log_response req resp; Response.output_ ~buf:buf_res oc resp - | e -> - continue := false; - let resp = - Response.fail ~code:500 "server error: %s" (Printexc.to_string e) - in - log_response resp; - Response.output_ ~buf:buf_res oc resp) - done + | Upgrade _ as e -> raise e + | e -> handle_bad_req req e) + in + + try + while !continue && running self do + Log.debug (fun k -> k "read next request"); + handle_one_req () + done + with + | Upgrade (req, up) -> + (* upgrades take over the whole connection, we won't process + any further request *) + handle_upgrade req up + | e -> handle_exn e let client_handler (self : t) : IO.TCP_server.conn_handler = { IO.TCP_server.handle = client_handle_for self } diff --git a/src/Tiny_httpd_server.mli b/src/Tiny_httpd_server.mli index 67270cdf..3842060f 100644 --- a/src/Tiny_httpd_server.mli +++ b/src/Tiny_httpd_server.mli @@ -645,6 +645,41 @@ val add_route_server_sent_handler : @since 0.9 *) +(** {2 Upgrade handlers} + + These handlers upgrade the connection to another protocol. + @since NEXT_RELEASE *) + +(** Handler that upgrades to another protocol. + @since NEXT_RELEASE *) +module type UPGRADE_HANDLER = sig + type handshake_state + (** Some specific state returned after handshake *) + + val name : string + (** Name in the "upgrade" header *) + + val handshake : unit Request.t -> (Headers.t * handshake_state, string) result + (** Perform the handshake and upgrade the connection. The returned + code is [101] alongside these headers. + In case the handshake fails, this only returns [Error log_msg]. + The connection is closed without further ado. *) + + val handle_connection : + handshake_state -> Tiny_httpd_io.Input.t -> Tiny_httpd_io.Output.t -> unit + (** Take control of the connection and take it from there *) +end + +type upgrade_handler = (module UPGRADE_HANDLER) + +val add_upgrade_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> + t -> + ('a, upgrade_handler) Route.t -> + 'a -> + unit + (** {2 Run the server} *) val running : t -> bool