fix middlewares: merge-sort per-request middleares and global ones

This commit is contained in:
Simon Cruanes 2024-02-27 15:14:12 -05:00
parent 1debf0f688
commit bcc208cf59
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -60,12 +60,12 @@ module type IO_BACKEND = sig
end end
type handler_result = type handler_result =
| Handle of cb_path_handler | Handle of (int * Middleware.t) list * cb_path_handler
| Fail of resp_error | Fail of resp_error
| Upgrade of upgrade_handler | Upgrade of upgrade_handler
let unwrap_handler_result req = function let unwrap_handler_result req = function
| Handle x -> x | Handle (l, h) -> l, h
| Fail (c, s) -> raise (Bad_req (c, s)) | Fail (c, s) -> raise (Bad_req (c, s))
| Upgrade up -> raise (Upgrade (req, up)) | Upgrade up -> raise (Upgrade (req, up))
@ -101,6 +101,9 @@ let active_connections (self : t) =
| None -> 0 | None -> 0
| Some s -> s.active_connections () | Some s -> s.active_connections ()
let sort_middlewares_ l =
List.stable_sort (fun (s1, _) (s2, _) -> compare s1 s2) l
let add_middleware ~stage self m = let add_middleware ~stage self m =
let stage = let stage =
match stage with match stage with
@ -109,9 +112,7 @@ let add_middleware ~stage self m =
| `Stage n -> n | `Stage n -> n
in in
self.middlewares <- (stage, m) :: self.middlewares; self.middlewares <- (stage, m) :: self.middlewares;
self.middlewares_sorted <- self.middlewares_sorted <- lazy (sort_middlewares_ self.middlewares)
lazy
(List.stable_sort (fun (s1, _) (s2, _) -> compare s1 s2) self.middlewares)
let add_decode_request_cb self f = let add_decode_request_cb self f =
(* turn it into a middleware *) (* turn it into a middleware *)
@ -145,6 +146,7 @@ let set_top_handler self f = self.handler <- f
and makes it into a handler. *) and makes it into a handler. *)
let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
~tr_req self (route : _ Route.t) f = ~tr_req self (route : _ Route.t) f =
let middlewares = List.map (fun h -> 5, h) middlewares in
let ph req : handler_result option = let ph req : handler_result option =
match meth with match meth with
| Some m when m <> req.Request.meth -> None (* ignore *) | Some m when m <> req.Request.meth -> None (* ignore *)
@ -156,9 +158,7 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
| Ok () -> | Ok () ->
Some Some
(Handle (Handle
(fun oc -> (middlewares, fun oc req ~resp -> tr_req oc req ~resp handler))
Middleware.apply_l middlewares @@ fun req ~resp ->
tr_req oc req ~resp handler))
| Error err -> Some (Fail err)) | Error err -> Some (Fail err))
| None -> None (* path didn't match *)) | None -> None (* path didn't match *))
in in
@ -409,10 +409,10 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
(try (try
(* is there a handler for this path? *) (* is there a handler for this path? *)
let base_handler = let handler_middlewares, base_handler =
match find_map (fun ph -> ph req) self.path_handlers with match find_map (fun ph -> ph req) self.path_handlers with
| Some f -> unwrap_handler_result req f | Some f -> unwrap_handler_result req f
| None -> fun _oc req ~resp -> resp (self.handler req) | None -> [], fun _oc req ~resp -> resp (self.handler req)
in in
(* handle expect/continue *) (* handle expect/continue *)
@ -424,12 +424,21 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
| Some s -> bad_reqf 417 "unknown expectation %s" s | Some s -> bad_reqf 417 "unknown expectation %s" s
| None -> ()); | None -> ());
(* merge per-request middlewares with the server-global middlewares *)
let global_middlewares = Lazy.force self.middlewares_sorted in
let all_middlewares =
if handler_middlewares = [] then
global_middlewares
else
sort_middlewares_
(List.rev_append handler_middlewares self.middlewares)
in
(* apply middlewares *) (* apply middlewares *)
let handler oc = let handler oc =
List.fold_right List.fold_right
(fun (_, m) h -> m h) (fun (_, m) h -> m h)
(Lazy.force self.middlewares_sorted) all_middlewares (base_handler oc)
(base_handler oc)
in in
(* now actually read request's body into a stream *) (* now actually read request's body into a stream *)