mirror of
https://github.com/c-cube/tiny_httpd.git
synced 2025-12-06 11:15:35 -05:00
feat server: new notion of Upgrade handler
this handles `connection: upgrade` endpoints with a generic connection-oriented handler. The main goal is to support websockets.
This commit is contained in:
parent
f416f7272d
commit
d3a4dbc5b0
2 changed files with 197 additions and 38 deletions
|
|
@ -46,7 +46,8 @@ module Response_code = struct
|
||||||
let[@inline] is_success n = n >= 200 && n < 400
|
let[@inline] is_success n = n >= 200 && n < 400
|
||||||
end
|
end
|
||||||
|
|
||||||
type 'a resp_result = ('a, Response_code.t * string) result
|
type resp_error = Response_code.t * string
|
||||||
|
type 'a resp_result = ('a, resp_error) result
|
||||||
|
|
||||||
let unwrap_resp_result = function
|
let unwrap_resp_result = function
|
||||||
| Ok x -> x
|
| Ok x -> x
|
||||||
|
|
@ -633,6 +634,26 @@ end
|
||||||
|
|
||||||
type server_sent_generator = (module SERVER_SENT_GENERATOR)
|
type server_sent_generator = (module SERVER_SENT_GENERATOR)
|
||||||
|
|
||||||
|
(** Handler that upgrades to another protocol *)
|
||||||
|
module type UPGRADE_HANDLER = sig
|
||||||
|
type handshake_state
|
||||||
|
(** Some specific state returned after handshake *)
|
||||||
|
|
||||||
|
val name : string
|
||||||
|
(** Name in the "upgrade" header *)
|
||||||
|
|
||||||
|
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result
|
||||||
|
(** Perform the handshake and upgrade the connection. The returned
|
||||||
|
code is [101] alongside these headers. *)
|
||||||
|
|
||||||
|
val handle_connection : handshake_state -> IO.Input.t -> IO.Output.t -> unit
|
||||||
|
(** Take control of the connection and take it from there *)
|
||||||
|
end
|
||||||
|
|
||||||
|
type upgrade_handler = (module UPGRADE_HANDLER)
|
||||||
|
|
||||||
|
exception Upgrade of unit Request.t * upgrade_handler
|
||||||
|
|
||||||
module type IO_BACKEND = sig
|
module type IO_BACKEND = sig
|
||||||
val init_addr : unit -> string
|
val init_addr : unit -> string
|
||||||
val init_port : unit -> int
|
val init_port : unit -> int
|
||||||
|
|
@ -644,6 +665,16 @@ module type IO_BACKEND = sig
|
||||||
(** Server that can listen on a port and handle clients. *)
|
(** Server that can listen on a port and handle clients. *)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
type handler_result =
|
||||||
|
| Handle of cb_path_handler
|
||||||
|
| Fail of resp_error
|
||||||
|
| Upgrade of upgrade_handler
|
||||||
|
|
||||||
|
let unwrap_handler_result req = function
|
||||||
|
| Handle x -> x
|
||||||
|
| Fail (c, s) -> raise (Bad_req (c, s))
|
||||||
|
| Upgrade up -> raise (Upgrade (req, up))
|
||||||
|
|
||||||
type t = {
|
type t = {
|
||||||
backend: (module IO_BACKEND);
|
backend: (module IO_BACKEND);
|
||||||
mutable tcp_server: IO.TCP_server.t option;
|
mutable tcp_server: IO.TCP_server.t option;
|
||||||
|
|
@ -653,8 +684,7 @@ type t = {
|
||||||
mutable middlewares: (int * Middleware.t) list; (** Global middlewares *)
|
mutable middlewares: (int * Middleware.t) list; (** Global middlewares *)
|
||||||
mutable middlewares_sorted: (int * Middleware.t) list lazy_t;
|
mutable middlewares_sorted: (int * Middleware.t) list lazy_t;
|
||||||
(** sorted version of {!middlewares} *)
|
(** sorted version of {!middlewares} *)
|
||||||
mutable path_handlers:
|
mutable path_handlers: (unit Request.t -> handler_result option) list;
|
||||||
(unit Request.t -> cb_path_handler resp_result option) list;
|
|
||||||
(** path handlers *)
|
(** path handlers *)
|
||||||
buf_pool: Buf.t Pool.t;
|
buf_pool: Buf.t Pool.t;
|
||||||
}
|
}
|
||||||
|
|
@ -726,7 +756,7 @@ let set_top_handler self f = self.handler <- f
|
||||||
and makes it into a handler. *)
|
and makes it into a handler. *)
|
||||||
let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
|
let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
|
||||||
~tr_req self (route : _ Route.t) f =
|
~tr_req self (route : _ Route.t) f =
|
||||||
let ph req : cb_path_handler resp_result option =
|
let ph req : handler_result option =
|
||||||
match meth with
|
match meth with
|
||||||
| Some m when m <> req.Request.meth -> None (* ignore *)
|
| Some m when m <> req.Request.meth -> None (* ignore *)
|
||||||
| _ ->
|
| _ ->
|
||||||
|
|
@ -736,11 +766,11 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
|
||||||
(match accept req with
|
(match accept req with
|
||||||
| Ok () ->
|
| Ok () ->
|
||||||
Some
|
Some
|
||||||
(Ok
|
(Handle
|
||||||
(fun oc ->
|
(fun oc ->
|
||||||
Middleware.apply_l middlewares @@ fun req ~resp ->
|
Middleware.apply_l middlewares @@ fun req ~resp ->
|
||||||
tr_req oc req ~resp handler))
|
tr_req oc req ~resp handler))
|
||||||
| Error _ as e -> Some e)
|
| Error err -> Some (Fail err))
|
||||||
| None -> None (* path didn't match *))
|
| None -> None (* path didn't match *))
|
||||||
in
|
in
|
||||||
self.path_handlers <- ph :: self.path_handlers
|
self.path_handlers <- ph :: self.path_handlers
|
||||||
|
|
@ -821,6 +851,22 @@ let add_route_server_sent_handler ?accept self route f =
|
||||||
in
|
in
|
||||||
add_route_handler_ self ?accept ~meth:`GET route ~tr_req f
|
add_route_handler_ self ?accept ~meth:`GET route ~tr_req f
|
||||||
|
|
||||||
|
let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = [])
|
||||||
|
(self : t) route f : unit =
|
||||||
|
let ph req : handler_result option =
|
||||||
|
if req.Request.meth <> `GET then
|
||||||
|
None
|
||||||
|
else (
|
||||||
|
match accept req with
|
||||||
|
| Ok () ->
|
||||||
|
(match Route.eval req.Request.path_components route f with
|
||||||
|
| Some up -> Some (Upgrade up)
|
||||||
|
| None -> None (* path didn't match *))
|
||||||
|
| Error err -> Some (Fail err)
|
||||||
|
)
|
||||||
|
in
|
||||||
|
self.path_handlers <- ph :: self.path_handlers
|
||||||
|
|
||||||
let get_max_connection_ ?(max_connections = 64) () : int =
|
let get_max_connection_ ?(max_connections = 64) () : int =
|
||||||
let max_connections = max 4 max_connections in
|
let max_connections = max 4 max_connections in
|
||||||
max_connections
|
max_connections
|
||||||
|
|
@ -929,7 +975,9 @@ module Unix_tcp_server_ = struct
|
||||||
Log.info (fun k ->
|
Log.info (fun k ->
|
||||||
k "done with client on %s, exiting"
|
k "done with client on %s, exiting"
|
||||||
@@ str_of_sockaddr client_addr);
|
@@ str_of_sockaddr client_addr);
|
||||||
(try Unix.close client_sock
|
(try
|
||||||
|
Unix.shutdown client_sock Unix.SHUTDOWN_ALL;
|
||||||
|
Unix.close client_sock
|
||||||
with e ->
|
with e ->
|
||||||
Log.error (fun k ->
|
Log.error (fun k ->
|
||||||
k "error when closing sock for client %s: %s"
|
k "error when closing sock for client %s: %s"
|
||||||
|
|
@ -1030,32 +1078,19 @@ let find_map f l =
|
||||||
in
|
in
|
||||||
aux f l
|
aux f l
|
||||||
|
|
||||||
|
let string_as_list_contains_ (s : string) (sub : string) : bool =
|
||||||
|
let fragments = String.split_on_char ',' s in
|
||||||
|
List.exists (fun fragment -> String.trim fragment = sub) fragments
|
||||||
|
|
||||||
(* handle client on [ic] and [oc] *)
|
(* handle client on [ic] and [oc] *)
|
||||||
let client_handle_for (self : t) ~client_addr ic oc : unit =
|
let client_handle_for (self : t) ~client_addr ic oc : unit =
|
||||||
Pool.with_resource self.buf_pool @@ fun buf ->
|
Pool.with_resource self.buf_pool @@ fun buf ->
|
||||||
Pool.with_resource self.buf_pool @@ fun buf_res ->
|
Pool.with_resource self.buf_pool @@ fun buf_res ->
|
||||||
let is = Byte_stream.of_input ~buf_size:self.buf_size ic in
|
let is = Byte_stream.of_input ~buf_size:self.buf_size ic in
|
||||||
let continue = ref true in
|
|
||||||
while !continue && running self do
|
|
||||||
Log.debug (fun k -> k "read next request");
|
|
||||||
let (module B) = self.backend in
|
let (module B) = self.backend in
|
||||||
match
|
|
||||||
Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is
|
|
||||||
with
|
|
||||||
| Ok None -> continue := false (* client is done *)
|
|
||||||
| Error (c, s) ->
|
|
||||||
(* connection error, close *)
|
|
||||||
let res = Response.make_raw ~code:c s in
|
|
||||||
(try Response.output_ ~buf:buf_res oc res with Sys_error _ -> ());
|
|
||||||
continue := false
|
|
||||||
| Ok (Some req) ->
|
|
||||||
Log.debug (fun k ->
|
|
||||||
k "parsed request: %s" (Format.asprintf "@[%a@]" Request.pp_ req));
|
|
||||||
|
|
||||||
if Request.close_after_req req then continue := false;
|
|
||||||
|
|
||||||
(* how to log the response to this query *)
|
(* how to log the response to this query *)
|
||||||
let log_response (resp : Response.t) =
|
let log_response (req : _ Request.t) (resp : Response.t) =
|
||||||
if not Log.dummy then (
|
if not Log.dummy then (
|
||||||
let msgf k =
|
let msgf k =
|
||||||
let elapsed = B.get_time_s () -. req.start_time in
|
let elapsed = B.get_time_s () -. req.start_time in
|
||||||
|
|
@ -1071,11 +1106,93 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
|
||||||
)
|
)
|
||||||
in
|
in
|
||||||
|
|
||||||
|
(* handle generic exception *)
|
||||||
|
let handle_exn e =
|
||||||
|
let resp =
|
||||||
|
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
|
||||||
|
in
|
||||||
|
if not Log.dummy then
|
||||||
|
Log.error (fun k ->
|
||||||
|
k "response to %s code=%d" (str_of_sockaddr client_addr) resp.code);
|
||||||
|
Response.output_ ~buf:buf_res oc resp
|
||||||
|
in
|
||||||
|
|
||||||
|
let handle_bad_req req e =
|
||||||
|
let resp =
|
||||||
|
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
|
||||||
|
in
|
||||||
|
log_response req resp;
|
||||||
|
Response.output_ ~buf:buf_res oc resp
|
||||||
|
in
|
||||||
|
|
||||||
|
let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit =
|
||||||
|
Log.debug (fun k -> k "upgrade connection");
|
||||||
|
try
|
||||||
|
(* check headers *)
|
||||||
|
(match Request.get_header req "connection" with
|
||||||
|
| Some str when string_as_list_contains_ str "Upgrade" -> ()
|
||||||
|
| _ -> bad_reqf 426 "connection header must contain 'Upgrade'");
|
||||||
|
(match Request.get_header req "upgrade" with
|
||||||
|
| Some u when u = UP.name -> ()
|
||||||
|
| Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u
|
||||||
|
| None -> bad_reqf 426 "expected 'connection: upgrade' header");
|
||||||
|
|
||||||
|
(* ok, this is the upgrade we expected *)
|
||||||
|
match UP.handshake req with
|
||||||
|
| Error msg ->
|
||||||
|
(* fail the upgrade *)
|
||||||
|
Log.error (fun k -> k "upgrade failed: %s" msg);
|
||||||
|
let resp = Response.make_raw ~code:429 "upgrade required" in
|
||||||
|
log_response req resp;
|
||||||
|
Response.output_ ~buf:buf_res oc resp
|
||||||
|
| Ok (headers, handshake_st) ->
|
||||||
|
(* send the upgrade reply *)
|
||||||
|
let headers =
|
||||||
|
[ "connection", "upgrade"; "upgrade", UP.name ] @ headers
|
||||||
|
in
|
||||||
|
let resp = Response.make_string ~code:101 ~headers (Ok "") in
|
||||||
|
log_response req resp;
|
||||||
|
Response.output_ ~buf:buf_res oc resp;
|
||||||
|
|
||||||
|
(* now, give the whole connection over to the upgraded connection.
|
||||||
|
Make sure to give the leftovers from [is] as well, if any.
|
||||||
|
There might not be any because the first message doesn't normally come
|
||||||
|
directly in the same packet as the handshake, but still. *)
|
||||||
|
let ic =
|
||||||
|
if is.len > 0 then (
|
||||||
|
Log.debug (fun k -> k "LEFTOVERS! %d B" is.len);
|
||||||
|
IO.Input.append (IO.Input.of_slice is.bs is.off is.len) ic
|
||||||
|
) else
|
||||||
|
ic
|
||||||
|
in
|
||||||
|
|
||||||
|
UP.handle_connection handshake_st ic oc
|
||||||
|
with e -> handle_bad_req req e
|
||||||
|
in
|
||||||
|
|
||||||
|
let continue = ref true in
|
||||||
|
|
||||||
|
let handle_one_req () =
|
||||||
|
match
|
||||||
|
Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is
|
||||||
|
with
|
||||||
|
| Ok None -> continue := false (* client is done *)
|
||||||
|
| Error (c, s) ->
|
||||||
|
(* connection error, close *)
|
||||||
|
let res = Response.make_raw ~code:c s in
|
||||||
|
(try Response.output_ ~buf:buf_res oc res with Sys_error _ -> ());
|
||||||
|
continue := false
|
||||||
|
| Ok (Some req) ->
|
||||||
|
Log.debug (fun k ->
|
||||||
|
k "parsed request: %s" (Format.asprintf "@[%a@]" Request.pp_ req));
|
||||||
|
|
||||||
|
if Request.close_after_req req then continue := false;
|
||||||
|
|
||||||
(try
|
(try
|
||||||
(* is there a handler for this path? *)
|
(* is there a handler for this path? *)
|
||||||
let base_handler =
|
let base_handler =
|
||||||
match find_map (fun ph -> ph req) self.path_handlers with
|
match find_map (fun ph -> ph req) self.path_handlers with
|
||||||
| Some f -> unwrap_resp_result f
|
| Some f -> unwrap_handler_result req f
|
||||||
| None -> fun _oc req ~resp -> resp (self.handler req)
|
| None -> fun _oc req ~resp -> resp (self.handler req)
|
||||||
in
|
in
|
||||||
|
|
||||||
|
|
@ -1108,7 +1225,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
|
||||||
try
|
try
|
||||||
if Headers.get "connection" r.Response.headers = Some "close" then
|
if Headers.get "connection" r.Response.headers = Some "close" then
|
||||||
continue := false;
|
continue := false;
|
||||||
log_response r;
|
log_response req r;
|
||||||
Response.output_ ~buf:buf_res oc r
|
Response.output_ ~buf:buf_res oc r
|
||||||
with Sys_error _ -> continue := false
|
with Sys_error _ -> continue := false
|
||||||
in
|
in
|
||||||
|
|
@ -1123,16 +1240,23 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
|
||||||
| Bad_req (code, s) ->
|
| Bad_req (code, s) ->
|
||||||
continue := false;
|
continue := false;
|
||||||
let resp = Response.make_raw ~code s in
|
let resp = Response.make_raw ~code s in
|
||||||
log_response resp;
|
log_response req resp;
|
||||||
Response.output_ ~buf:buf_res oc resp
|
Response.output_ ~buf:buf_res oc resp
|
||||||
| e ->
|
| Upgrade _ as e -> raise e
|
||||||
continue := false;
|
| e -> handle_bad_req req e)
|
||||||
let resp =
|
|
||||||
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
|
|
||||||
in
|
in
|
||||||
log_response resp;
|
|
||||||
Response.output_ ~buf:buf_res oc resp)
|
try
|
||||||
|
while !continue && running self do
|
||||||
|
Log.debug (fun k -> k "read next request");
|
||||||
|
handle_one_req ()
|
||||||
done
|
done
|
||||||
|
with
|
||||||
|
| Upgrade (req, up) ->
|
||||||
|
(* upgrades take over the whole connection, we won't process
|
||||||
|
any further request *)
|
||||||
|
handle_upgrade req up
|
||||||
|
| e -> handle_exn e
|
||||||
|
|
||||||
let client_handler (self : t) : IO.TCP_server.conn_handler =
|
let client_handler (self : t) : IO.TCP_server.conn_handler =
|
||||||
{ IO.TCP_server.handle = client_handle_for self }
|
{ IO.TCP_server.handle = client_handle_for self }
|
||||||
|
|
|
||||||
|
|
@ -645,6 +645,41 @@ val add_route_server_sent_handler :
|
||||||
|
|
||||||
@since 0.9 *)
|
@since 0.9 *)
|
||||||
|
|
||||||
|
(** {2 Upgrade handlers}
|
||||||
|
|
||||||
|
These handlers upgrade the connection to another protocol.
|
||||||
|
@since NEXT_RELEASE *)
|
||||||
|
|
||||||
|
(** Handler that upgrades to another protocol.
|
||||||
|
@since NEXT_RELEASE *)
|
||||||
|
module type UPGRADE_HANDLER = sig
|
||||||
|
type handshake_state
|
||||||
|
(** Some specific state returned after handshake *)
|
||||||
|
|
||||||
|
val name : string
|
||||||
|
(** Name in the "upgrade" header *)
|
||||||
|
|
||||||
|
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result
|
||||||
|
(** Perform the handshake and upgrade the connection. The returned
|
||||||
|
code is [101] alongside these headers.
|
||||||
|
In case the handshake fails, this only returns [Error log_msg].
|
||||||
|
The connection is closed without further ado. *)
|
||||||
|
|
||||||
|
val handle_connection :
|
||||||
|
handshake_state -> Tiny_httpd_io.Input.t -> Tiny_httpd_io.Output.t -> unit
|
||||||
|
(** Take control of the connection and take it from there *)
|
||||||
|
end
|
||||||
|
|
||||||
|
type upgrade_handler = (module UPGRADE_HANDLER)
|
||||||
|
|
||||||
|
val add_upgrade_handler :
|
||||||
|
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
|
||||||
|
?middlewares:Middleware.t list ->
|
||||||
|
t ->
|
||||||
|
('a, upgrade_handler) Route.t ->
|
||||||
|
'a ->
|
||||||
|
unit
|
||||||
|
|
||||||
(** {2 Run the server} *)
|
(** {2 Run the server} *)
|
||||||
|
|
||||||
val running : t -> bool
|
val running : t -> bool
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue