feat: add a notion of Middleware

this subsumes and deprecates the encoding/decoding callbacks.
This commit is contained in:
Simon Cruanes 2021-12-09 16:43:47 -05:00
parent 2d2ffc722a
commit 6b0000eb6e
No known key found for this signature in database
GPG key ID: 4AC01D0849AA62B6
2 changed files with 119 additions and 41 deletions

View file

@ -789,12 +789,21 @@ module Route = struct
let pp out x = Format.pp_print_string out (to_string x) let pp out x = Format.pp_print_string out (to_string x)
end end
module Middleware = struct
type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit
type t = handler -> handler
(** Apply a list of middlewares to [h] *)
let apply_l (l:t list) (h:handler) : handler =
List.fold_right (fun m h -> m h) l h
let[@inline] nil : t = fun h -> h
end
(* a request handler. handles a single request. *) (* a request handler. handles a single request. *)
type cb_path_handler = type cb_path_handler =
out_channel -> out_channel ->
byte_stream Request.t -> Middleware.handler
resp:(Response.t -> unit) ->
unit
module type SERVER_SENT_GENERATOR = sig module type SERVER_SENT_GENERATOR = sig
val set_headers : Headers.t -> unit val set_headers : Headers.t -> unit
@ -828,16 +837,15 @@ type t = {
mutable handler: (string Request.t -> Response.t); mutable handler: (string Request.t -> Response.t);
(* toplevel handler, if any *) (* toplevel handler, if any *)
mutable middlewares : (int * Middleware.t) list;
(** Global middlewares *)
mutable middlewares_sorted : (int * Middleware.t) list lazy_t;
(* sorted version of {!middlewares} *)
mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list; mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list;
(* path handlers *) (* path handlers *)
mutable cb_decode_req:
(unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) list;
(* middleware to decode requests *)
mutable cb_encode_resp: (unit Request.t -> Response.t -> Response.t option) list;
(* middleware to encode responses *)
mutable running: bool; mutable running: bool;
(* true while the server is running. no need to protect with a mutex, (* true while the server is running. no need to protect with a mutex,
writes should be atomic enough. *) writes should be atomic enough. *)
@ -848,15 +856,48 @@ let port self = self.port
let active_connections self = Sem_.num_acquired self.sem_max_connections - 1 let active_connections self = Sem_.num_acquired self.sem_max_connections - 1
let add_decode_request_cb self f = self.cb_decode_req <- f :: self.cb_decode_req let add_middleware ~stage self m =
let add_encode_response_cb self f = self.cb_encode_resp <- f :: self.cb_encode_resp let stage = match stage with
| `Encoding -> 0
| `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage"
| `Stage n -> n
in
self.middlewares <- (stage,m) :: self.middlewares;
self.middlewares_sorted <- lazy (
List.stable_sort (fun (s1,_) (s2,_) -> compare s1 s2) self.middlewares
)
let add_decode_request_cb self f =
(* turn it into a middleware *)
let m h req ~resp =
(* see if [f] modifies the stream *)
let req0 = {req with Request.body=()} in
match f req0 with
| None -> h req ~resp (* pass through *)
| Some (req1, tr_stream) ->
let req = {req1 with Request.body=tr_stream req.Request.body} in
h req ~resp
in
add_middleware self ~stage:`Encoding m
let add_encode_response_cb self f =
let m h req ~resp =
h req ~resp:(fun r ->
let req0 = {req with Request.body=()} in
(* now transform [r] if we want to *)
match f req0 r with
| None -> resp r
| Some r' -> resp r')
in
add_middleware self ~stage:`Encoding m
let set_top_handler self f = self.handler <- f let set_top_handler self f = self.handler <- f
(* route the given handler. (* route the given handler.
@param tr_req wraps the actual concrete function returned by the route @param tr_req wraps the actual concrete function returned by the route
and makes it into a handler. *) and makes it into a handler. *)
let add_route_handler_ let add_route_handler_
?(accept=fun _req -> Ok ()) ?(accept=fun _req -> Ok ()) ?(middlewares=[])
?meth ~tr_req self (route:_ Route.t) f = ?meth ~tr_req self (route:_ Route.t) f =
let ph req : cb_path_handler resp_result option = let ph req : cb_path_handler resp_result option =
match meth with match meth with
@ -866,7 +907,10 @@ let add_route_handler_
| Some handler -> | Some handler ->
(* we have a handler, do we accept the request based on its headers? *) (* we have a handler, do we accept the request based on its headers? *)
begin match accept req with begin match accept req with
| Ok () -> Some (Ok (fun oc req ~resp -> tr_req oc req ~resp handler)) | Ok () ->
Some (Ok (fun oc ->
Middleware.apply_l middlewares @@
fun req ~resp -> tr_req oc req ~resp handler))
| Error _ as e -> Some e | Error _ as e -> Some e
end end
| None -> | None ->
@ -875,13 +919,14 @@ let add_route_handler_
in in
self.path_handlers <- ph :: self.path_handlers self.path_handlers <- ph :: self.path_handlers
let add_route_handler (type a) ?accept ?meth self (route:(a,_) Route.t) (f:_) : unit = let add_route_handler (type a) ?accept ?middlewares ?meth
self (route:(a,_) Route.t) (f:_) : unit =
let tr_req _oc req ~resp f = resp (f (Request.read_body_full req)) in let tr_req _oc req ~resp f = resp (f (Request.read_body_full req)) in
add_route_handler_ ?accept ?meth self route ~tr_req f add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f
let add_route_handler_stream ?accept ?meth self route f = let add_route_handler_stream ?accept ?middlewares ?meth self route f =
let tr_req _oc req ~resp f = resp (f req) in let tr_req _oc req ~resp f = resp (f req) in
add_route_handler_ ?accept ?meth self route ~tr_req f add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f
let[@inline] _opt_iter ~f o = match o with let[@inline] _opt_iter ~f o = match o with
| None -> () | None -> ()
@ -938,7 +983,7 @@ let create
{ new_thread; addr; port; sock; masksigpipe; handler; { new_thread; addr; port; sock; masksigpipe; handler;
running= true; sem_max_connections=Sem_.create max_connections; running= true; sem_max_connections=Sem_.create max_connections;
path_handlers=[]; timeout; path_handlers=[]; timeout;
cb_encode_resp=[]; cb_decode_req=[]; middlewares=[]; middlewares_sorted=lazy [];
} }
let stop s = s.running <- false let stop s = s.running <- false
@ -983,7 +1028,10 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
let handler = let handler =
match find_map (fun ph -> ph req) self.path_handlers with match find_map (fun ph -> ph req) self.path_handlers with
| Some f -> unwrap_resp_result f | Some f -> unwrap_resp_result f
| None -> (fun _oc req ~resp -> resp (self.handler (Request.read_body_full req))) | None ->
(fun _oc req ~resp ->
let body_str = Request.read_body_full req in
resp (self.handler body_str))
in in
(* handle expect/continue *) (* handle expect/continue *)
@ -995,33 +1043,22 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
| None -> () | None -> ()
end; end;
(* preprocess request's input stream *) (* apply middlewares *)
let req0, tr_stream = let handler =
List.fold_left fun oc ->
(fun (req,tr) cb -> List.fold_right (fun (_, m) h -> m h)
match cb req with (Lazy.force self.middlewares_sorted) (handler oc)
| None -> req, tr
| Some (r',f) -> r', (fun is -> tr is |> f))
(req, (fun is->is)) self.cb_decode_req
in
(* now actually read request's body into a stream *)
let req =
Request.parse_body_ ~tr_stream ~buf {req0 with body=is}
|> unwrap_resp_result
in in
(* how to post-process response accordingly *) (* now actually read request's body into a stream *)
let post_process_resp resp = let req =
List.fold_left Request.parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=is}
(fun resp cb -> match cb req0 resp with None -> resp | Some r' -> r') |> unwrap_resp_result
resp self.cb_encode_resp
in in
(* how to reply *) (* how to reply *)
let resp r = let resp r =
try try Response.output_ oc r
let r = post_process_resp r in
Response.output_ oc r
with Sys_error _ -> continue := false with Sys_error _ -> continue := false
in in

View file

@ -426,6 +426,31 @@ module Route : sig
@since 0.7 *) @since 0.7 *)
end end
(** {2 Middlewares}
A middleware can be inserted in a handler to modify or observe
its behavior.
@since NEXT_RELEASE
*)
module Middleware : sig
type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit
(** Handlers are functions returning a response to a request.
The response can be delayed, hence the use of a continuation
as the [resp] parameter. *)
type t = handler -> handler
(** A middleware is a handler transformation.
It takes the existing handler [h],
and returns a new one which, given a query, modify it or log it
before passing it to [h], or fail. It can also log or modify or drop
the response. *)
val nil : t
(** Trivial middleware that does nothing. *)
end
(** {2 Main Server type} *) (** {2 Main Server type} *)
type t type t
@ -487,6 +512,8 @@ val add_decode_request_cb :
modified headers, typically). modified headers, typically).
A possible use is to handle decompression by looking for a [Transfer-Encoding] A possible use is to handle decompression by looking for a [Transfer-Encoding]
header and returning a stream transformer that decompresses on the fly. header and returning a stream transformer that decompresses on the fly.
@deprecated use {!add_middleware} instead
*) *)
val add_encode_response_cb: val add_encode_response_cb:
@ -496,6 +523,18 @@ val add_encode_response_cb:
response, for example to compress it. response, for example to compress it.
The callback is given the query with only its headers, The callback is given the query with only its headers,
as well as the current response. as well as the current response.
@deprecated use {!add_middleware} instead
*)
val add_middleware :
stage:[`Encoding | `Stage of int] ->
t -> Middleware.t -> unit
(** Add a middleware to every request/response pair.
@param stage specify when middleware applies.
Encoding comes first (outermost layer), then stages in increasing order.
@raise Invalid_argument if stage is [`Stage n] where [n < 1]
@since NEXT_RELEASE
*) *)
(** {2 Request handlers} *) (** {2 Request handlers} *)
@ -509,6 +548,7 @@ val set_top_handler : t -> (string Request.t -> Response.t) -> unit
val add_route_handler : val add_route_handler :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> ?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
?middlewares:Middleware.t list ->
?meth:Meth.t -> ?meth:Meth.t ->
t -> t ->
('a, string Request.t -> Response.t) Route.t -> 'a -> ('a, string Request.t -> Response.t) Route.t -> 'a ->
@ -534,6 +574,7 @@ val add_route_handler :
val add_route_handler_stream : val add_route_handler_stream :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> ?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
?middlewares:Middleware.t list ->
?meth:Meth.t -> ?meth:Meth.t ->
t -> t ->
('a, byte_stream Request.t -> Response.t) Route.t -> 'a -> ('a, byte_stream Request.t -> Response.t) Route.t -> 'a ->