mirror of
https://github.com/c-cube/tiny_httpd.git
synced 2025-12-06 11:15:35 -05:00
feat server: new notion of Upgrade handler
this handles `connection: upgrade` endpoints with a generic connection-oriented handler. The main goal is to support websockets.
This commit is contained in:
parent
f416f7272d
commit
d3a4dbc5b0
2 changed files with 197 additions and 38 deletions
|
|
@ -46,7 +46,8 @@ module Response_code = struct
|
|||
let[@inline] is_success n = n >= 200 && n < 400
|
||||
end
|
||||
|
||||
type 'a resp_result = ('a, Response_code.t * string) result
|
||||
type resp_error = Response_code.t * string
|
||||
type 'a resp_result = ('a, resp_error) result
|
||||
|
||||
let unwrap_resp_result = function
|
||||
| Ok x -> x
|
||||
|
|
@ -633,6 +634,26 @@ end
|
|||
|
||||
type server_sent_generator = (module SERVER_SENT_GENERATOR)
|
||||
|
||||
(** Handler that upgrades to another protocol *)
|
||||
module type UPGRADE_HANDLER = sig
|
||||
type handshake_state
|
||||
(** Some specific state returned after handshake *)
|
||||
|
||||
val name : string
|
||||
(** Name in the "upgrade" header *)
|
||||
|
||||
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result
|
||||
(** Perform the handshake and upgrade the connection. The returned
|
||||
code is [101] alongside these headers. *)
|
||||
|
||||
val handle_connection : handshake_state -> IO.Input.t -> IO.Output.t -> unit
|
||||
(** Take control of the connection and take it from there *)
|
||||
end
|
||||
|
||||
type upgrade_handler = (module UPGRADE_HANDLER)
|
||||
|
||||
exception Upgrade of unit Request.t * upgrade_handler
|
||||
|
||||
module type IO_BACKEND = sig
|
||||
val init_addr : unit -> string
|
||||
val init_port : unit -> int
|
||||
|
|
@ -644,6 +665,16 @@ module type IO_BACKEND = sig
|
|||
(** Server that can listen on a port and handle clients. *)
|
||||
end
|
||||
|
||||
type handler_result =
|
||||
| Handle of cb_path_handler
|
||||
| Fail of resp_error
|
||||
| Upgrade of upgrade_handler
|
||||
|
||||
let unwrap_handler_result req = function
|
||||
| Handle x -> x
|
||||
| Fail (c, s) -> raise (Bad_req (c, s))
|
||||
| Upgrade up -> raise (Upgrade (req, up))
|
||||
|
||||
type t = {
|
||||
backend: (module IO_BACKEND);
|
||||
mutable tcp_server: IO.TCP_server.t option;
|
||||
|
|
@ -653,8 +684,7 @@ type t = {
|
|||
mutable middlewares: (int * Middleware.t) list; (** Global middlewares *)
|
||||
mutable middlewares_sorted: (int * Middleware.t) list lazy_t;
|
||||
(** sorted version of {!middlewares} *)
|
||||
mutable path_handlers:
|
||||
(unit Request.t -> cb_path_handler resp_result option) list;
|
||||
mutable path_handlers: (unit Request.t -> handler_result option) list;
|
||||
(** path handlers *)
|
||||
buf_pool: Buf.t Pool.t;
|
||||
}
|
||||
|
|
@ -726,7 +756,7 @@ let set_top_handler self f = self.handler <- f
|
|||
and makes it into a handler. *)
|
||||
let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
|
||||
~tr_req self (route : _ Route.t) f =
|
||||
let ph req : cb_path_handler resp_result option =
|
||||
let ph req : handler_result option =
|
||||
match meth with
|
||||
| Some m when m <> req.Request.meth -> None (* ignore *)
|
||||
| _ ->
|
||||
|
|
@ -736,11 +766,11 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
|
|||
(match accept req with
|
||||
| Ok () ->
|
||||
Some
|
||||
(Ok
|
||||
(Handle
|
||||
(fun oc ->
|
||||
Middleware.apply_l middlewares @@ fun req ~resp ->
|
||||
tr_req oc req ~resp handler))
|
||||
| Error _ as e -> Some e)
|
||||
| Error err -> Some (Fail err))
|
||||
| None -> None (* path didn't match *))
|
||||
in
|
||||
self.path_handlers <- ph :: self.path_handlers
|
||||
|
|
@ -821,6 +851,22 @@ let add_route_server_sent_handler ?accept self route f =
|
|||
in
|
||||
add_route_handler_ self ?accept ~meth:`GET route ~tr_req f
|
||||
|
||||
let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = [])
|
||||
(self : t) route f : unit =
|
||||
let ph req : handler_result option =
|
||||
if req.Request.meth <> `GET then
|
||||
None
|
||||
else (
|
||||
match accept req with
|
||||
| Ok () ->
|
||||
(match Route.eval req.Request.path_components route f with
|
||||
| Some up -> Some (Upgrade up)
|
||||
| None -> None (* path didn't match *))
|
||||
| Error err -> Some (Fail err)
|
||||
)
|
||||
in
|
||||
self.path_handlers <- ph :: self.path_handlers
|
||||
|
||||
let get_max_connection_ ?(max_connections = 64) () : int =
|
||||
let max_connections = max 4 max_connections in
|
||||
max_connections
|
||||
|
|
@ -929,7 +975,9 @@ module Unix_tcp_server_ = struct
|
|||
Log.info (fun k ->
|
||||
k "done with client on %s, exiting"
|
||||
@@ str_of_sockaddr client_addr);
|
||||
(try Unix.close client_sock
|
||||
(try
|
||||
Unix.shutdown client_sock Unix.SHUTDOWN_ALL;
|
||||
Unix.close client_sock
|
||||
with e ->
|
||||
Log.error (fun k ->
|
||||
k "error when closing sock for client %s: %s"
|
||||
|
|
@ -1030,32 +1078,19 @@ let find_map f l =
|
|||
in
|
||||
aux f l
|
||||
|
||||
let string_as_list_contains_ (s : string) (sub : string) : bool =
|
||||
let fragments = String.split_on_char ',' s in
|
||||
List.exists (fun fragment -> String.trim fragment = sub) fragments
|
||||
|
||||
(* handle client on [ic] and [oc] *)
|
||||
let client_handle_for (self : t) ~client_addr ic oc : unit =
|
||||
Pool.with_resource self.buf_pool @@ fun buf ->
|
||||
Pool.with_resource self.buf_pool @@ fun buf_res ->
|
||||
let is = Byte_stream.of_input ~buf_size:self.buf_size ic in
|
||||
let continue = ref true in
|
||||
while !continue && running self do
|
||||
Log.debug (fun k -> k "read next request");
|
||||
let (module B) = self.backend in
|
||||
match
|
||||
Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is
|
||||
with
|
||||
| Ok None -> continue := false (* client is done *)
|
||||
| Error (c, s) ->
|
||||
(* connection error, close *)
|
||||
let res = Response.make_raw ~code:c s in
|
||||
(try Response.output_ ~buf:buf_res oc res with Sys_error _ -> ());
|
||||
continue := false
|
||||
| Ok (Some req) ->
|
||||
Log.debug (fun k ->
|
||||
k "parsed request: %s" (Format.asprintf "@[%a@]" Request.pp_ req));
|
||||
|
||||
if Request.close_after_req req then continue := false;
|
||||
|
||||
(* how to log the response to this query *)
|
||||
let log_response (resp : Response.t) =
|
||||
let log_response (req : _ Request.t) (resp : Response.t) =
|
||||
if not Log.dummy then (
|
||||
let msgf k =
|
||||
let elapsed = B.get_time_s () -. req.start_time in
|
||||
|
|
@ -1071,11 +1106,93 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
|
|||
)
|
||||
in
|
||||
|
||||
(* handle generic exception *)
|
||||
let handle_exn e =
|
||||
let resp =
|
||||
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
|
||||
in
|
||||
if not Log.dummy then
|
||||
Log.error (fun k ->
|
||||
k "response to %s code=%d" (str_of_sockaddr client_addr) resp.code);
|
||||
Response.output_ ~buf:buf_res oc resp
|
||||
in
|
||||
|
||||
let handle_bad_req req e =
|
||||
let resp =
|
||||
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
|
||||
in
|
||||
log_response req resp;
|
||||
Response.output_ ~buf:buf_res oc resp
|
||||
in
|
||||
|
||||
let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit =
|
||||
Log.debug (fun k -> k "upgrade connection");
|
||||
try
|
||||
(* check headers *)
|
||||
(match Request.get_header req "connection" with
|
||||
| Some str when string_as_list_contains_ str "Upgrade" -> ()
|
||||
| _ -> bad_reqf 426 "connection header must contain 'Upgrade'");
|
||||
(match Request.get_header req "upgrade" with
|
||||
| Some u when u = UP.name -> ()
|
||||
| Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u
|
||||
| None -> bad_reqf 426 "expected 'connection: upgrade' header");
|
||||
|
||||
(* ok, this is the upgrade we expected *)
|
||||
match UP.handshake req with
|
||||
| Error msg ->
|
||||
(* fail the upgrade *)
|
||||
Log.error (fun k -> k "upgrade failed: %s" msg);
|
||||
let resp = Response.make_raw ~code:429 "upgrade required" in
|
||||
log_response req resp;
|
||||
Response.output_ ~buf:buf_res oc resp
|
||||
| Ok (headers, handshake_st) ->
|
||||
(* send the upgrade reply *)
|
||||
let headers =
|
||||
[ "connection", "upgrade"; "upgrade", UP.name ] @ headers
|
||||
in
|
||||
let resp = Response.make_string ~code:101 ~headers (Ok "") in
|
||||
log_response req resp;
|
||||
Response.output_ ~buf:buf_res oc resp;
|
||||
|
||||
(* now, give the whole connection over to the upgraded connection.
|
||||
Make sure to give the leftovers from [is] as well, if any.
|
||||
There might not be any because the first message doesn't normally come
|
||||
directly in the same packet as the handshake, but still. *)
|
||||
let ic =
|
||||
if is.len > 0 then (
|
||||
Log.debug (fun k -> k "LEFTOVERS! %d B" is.len);
|
||||
IO.Input.append (IO.Input.of_slice is.bs is.off is.len) ic
|
||||
) else
|
||||
ic
|
||||
in
|
||||
|
||||
UP.handle_connection handshake_st ic oc
|
||||
with e -> handle_bad_req req e
|
||||
in
|
||||
|
||||
let continue = ref true in
|
||||
|
||||
let handle_one_req () =
|
||||
match
|
||||
Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is
|
||||
with
|
||||
| Ok None -> continue := false (* client is done *)
|
||||
| Error (c, s) ->
|
||||
(* connection error, close *)
|
||||
let res = Response.make_raw ~code:c s in
|
||||
(try Response.output_ ~buf:buf_res oc res with Sys_error _ -> ());
|
||||
continue := false
|
||||
| Ok (Some req) ->
|
||||
Log.debug (fun k ->
|
||||
k "parsed request: %s" (Format.asprintf "@[%a@]" Request.pp_ req));
|
||||
|
||||
if Request.close_after_req req then continue := false;
|
||||
|
||||
(try
|
||||
(* is there a handler for this path? *)
|
||||
let base_handler =
|
||||
match find_map (fun ph -> ph req) self.path_handlers with
|
||||
| Some f -> unwrap_resp_result f
|
||||
| Some f -> unwrap_handler_result req f
|
||||
| None -> fun _oc req ~resp -> resp (self.handler req)
|
||||
in
|
||||
|
||||
|
|
@ -1108,7 +1225,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
|
|||
try
|
||||
if Headers.get "connection" r.Response.headers = Some "close" then
|
||||
continue := false;
|
||||
log_response r;
|
||||
log_response req r;
|
||||
Response.output_ ~buf:buf_res oc r
|
||||
with Sys_error _ -> continue := false
|
||||
in
|
||||
|
|
@ -1123,16 +1240,23 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
|
|||
| Bad_req (code, s) ->
|
||||
continue := false;
|
||||
let resp = Response.make_raw ~code s in
|
||||
log_response resp;
|
||||
log_response req resp;
|
||||
Response.output_ ~buf:buf_res oc resp
|
||||
| e ->
|
||||
continue := false;
|
||||
let resp =
|
||||
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
|
||||
| Upgrade _ as e -> raise e
|
||||
| e -> handle_bad_req req e)
|
||||
in
|
||||
log_response resp;
|
||||
Response.output_ ~buf:buf_res oc resp)
|
||||
|
||||
try
|
||||
while !continue && running self do
|
||||
Log.debug (fun k -> k "read next request");
|
||||
handle_one_req ()
|
||||
done
|
||||
with
|
||||
| Upgrade (req, up) ->
|
||||
(* upgrades take over the whole connection, we won't process
|
||||
any further request *)
|
||||
handle_upgrade req up
|
||||
| e -> handle_exn e
|
||||
|
||||
let client_handler (self : t) : IO.TCP_server.conn_handler =
|
||||
{ IO.TCP_server.handle = client_handle_for self }
|
||||
|
|
|
|||
|
|
@ -645,6 +645,41 @@ val add_route_server_sent_handler :
|
|||
|
||||
@since 0.9 *)
|
||||
|
||||
(** {2 Upgrade handlers}
|
||||
|
||||
These handlers upgrade the connection to another protocol.
|
||||
@since NEXT_RELEASE *)
|
||||
|
||||
(** Handler that upgrades to another protocol.
|
||||
@since NEXT_RELEASE *)
|
||||
module type UPGRADE_HANDLER = sig
|
||||
type handshake_state
|
||||
(** Some specific state returned after handshake *)
|
||||
|
||||
val name : string
|
||||
(** Name in the "upgrade" header *)
|
||||
|
||||
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result
|
||||
(** Perform the handshake and upgrade the connection. The returned
|
||||
code is [101] alongside these headers.
|
||||
In case the handshake fails, this only returns [Error log_msg].
|
||||
The connection is closed without further ado. *)
|
||||
|
||||
val handle_connection :
|
||||
handshake_state -> Tiny_httpd_io.Input.t -> Tiny_httpd_io.Output.t -> unit
|
||||
(** Take control of the connection and take it from there *)
|
||||
end
|
||||
|
||||
type upgrade_handler = (module UPGRADE_HANDLER)
|
||||
|
||||
val add_upgrade_handler :
|
||||
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
|
||||
?middlewares:Middleware.t list ->
|
||||
t ->
|
||||
('a, upgrade_handler) Route.t ->
|
||||
'a ->
|
||||
unit
|
||||
|
||||
(** {2 Run the server} *)
|
||||
|
||||
val running : t -> bool
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue