diff --git a/examples/echo.ml b/examples/echo.ml index 3b7c7700..f77bbe72 100644 --- a/examples/echo.ml +++ b/examples/echo.ml @@ -1,6 +1,39 @@ module S = Tiny_httpd +let now_ = Unix.gettimeofday + +(* util: a little middleware collecting statistics *) +let middleware_stat () : S.Middleware.t * (unit -> string) = + let n_req = ref 0 in + let total_time_ = ref 0. in + let parse_time_ = ref 0. in + let build_time_ = ref 0. in + let write_time_ = ref 0. in + + let m h req ~resp = + incr n_req; + let t1 = S.Request.start_time req in + let t2 = now_ () in + h req ~resp:(fun response -> + let t3 = now_ () in + resp response; + let t4 = now_ () in + total_time_ := !total_time_ +. (t4 -. t1); + parse_time_ := !parse_time_ +. (t2 -. t1); + build_time_ := !build_time_ +. (t3 -. t2); + write_time_ := !write_time_ +. (t4 -. t3); + ) + and get_stat () = + Printf.sprintf "%d requests (average response time: %.3fms = %.3fms + %.3fms + %.3fms)" + !n_req (!total_time_ /. float !n_req *. 1e3) + (!parse_time_ /. float !n_req *. 1e3) + (!build_time_ /. float !n_req *. 1e3) + (!write_time_ /. float !n_req *. 1e3) + in + m, get_stat + + let () = let port_ = ref 8080 in let j = ref 32 in @@ -10,12 +43,19 @@ let () = "--debug", Arg.Unit (fun () -> S._enable_debug true), " enable debug"; "-j", Arg.Set_int j, " maximum number of connections"; ]) (fun _ -> raise (Arg.Bad "")) "echo [option]*"; + let server = S.create ~port:!port_ ~max_connections:!j () in - Tiny_httpd_camlzip.setup ~compress_above:1024 ~buf_size:(1024*1024) server; + Tiny_httpd_camlzip.setup ~compress_above:1024 ~buf_size:(16*1024) server; + + let m_stats, get_stats = middleware_stat () in + S.add_middleware server ~stage:(`Stage 1) m_stats; + (* say hello *) S.add_route_handler ~meth:`GET server S.Route.(exact "hello" @/ string @/ return) (fun name _req -> S.Response.make_string (Ok ("hello " ^name ^"!\n"))); + + (* compressed file access *) S.add_route_handler ~meth:`GET server S.Route.(exact "zcat" @/ string_urlencoded @/ return) (fun path _req -> @@ -33,6 +73,7 @@ let () = in S.Response.make_stream ~headers:mime_type (Ok str) ); + (* echo request *) S.add_route_handler server S.Route.(exact "echo" @/ return) @@ -43,6 +84,8 @@ let () = in S.Response.make_string (Ok (Format.asprintf "echo:@ %a@ (query: %s)@." S.Request.pp req q))); + + (* file upload *) S.add_route_handler_stream ~meth:`PUT server S.Route.(exact "upload" @/ string @/ return) (fun path req -> @@ -56,6 +99,28 @@ let () = with e -> S.Response.fail ~code:500 "couldn't upload file: %s" (Printexc.to_string e) ); + + (* stats *) + S.add_route_handler server S.Route.(exact "stats" @/ return) + (fun _req -> + let stats = get_stats() in + S.Response.make_string @@ Ok stats + ); + + (* main page *) + S.add_route_handler server S.Route.(return) + (fun _req -> + let s = "\n\ +

welcome!\n

endpoints are:\n

" + in + S.Response.make_string ~headers:["content-type", "text/html"] @@ Ok s); + Printf.printf "listening on http://%s:%d\n%!" (S.addr server) (S.port server); match S.run server with | Ok () -> () diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml index c0d92bfd..b0e52dbb 100644 --- a/src/Tiny_httpd.ml +++ b/src/Tiny_httpd.ml @@ -66,10 +66,10 @@ module Byte_stream = struct bs_close=(fun () -> ()); } - let of_chan_ ~close ic : t = + let of_chan_ ?(buf_size=16 * 1024) ~close ic : t = let i = ref 0 in let len = ref 0 in - let buf = Bytes.make 4096 ' ' in + let buf = Bytes.make buf_size ' ' in { bs_fill_buf=(fun () -> if !i >= !len then ( i := 0; @@ -116,10 +116,10 @@ module Byte_stream = struct let of_string s : t = of_bytes (Bytes.unsafe_of_string s) - let with_file file f = + let with_file ?buf_size file f = let ic = open_in file in try - let x = f (of_chan ic) in + let x = f (of_chan ?buf_size ic) in close_in ic; x with e -> @@ -367,6 +367,7 @@ module Request = struct path_components: string list; query: (string*string) list; body: 'body; + start_time: float; } let headers self = self.headers @@ -374,13 +375,16 @@ module Request = struct let meth self = self.meth let path self = self.path let body self = self.body + let start_time self = self.start_time let query self = self.query let get_header ?f self h = Headers.get ?f h self.headers let get_header_int self h = match get_header self h with | Some x -> (try Some (int_of_string x) with _ -> None) | None -> None - let set_header self k v = {self with headers=Headers.set k v self.headers} + let set_header k v self = {self with headers=Headers.set k v self.headers} + let update_headers f self = {self with headers=f self.headers} + let set_body b self = {self with body=b} let pp_comp_ out comp = Format.fprintf out "[%s]" @@ -481,6 +485,7 @@ module Request = struct let parse_req_start ~buf (bs:byte_stream) : unit t option resp_result = try let line = Byte_stream.read_line ~buf bs in + let start_time = Unix.gettimeofday () in let meth, path = try let m, p, v = Scanf.sscanf line "%s %s HTTP/1.%d\r" (fun x y z->x,y,z) in @@ -506,7 +511,7 @@ module Request = struct | Error e -> bad_reqf 400 "invalid query: %s" e in Ok (Some {meth; query; host; path; path_components; - headers; body=()}) + headers; body=(); start_time; }) with | End_of_file | Sys_error _ -> Ok None | Bad_req (c,s) -> Error (c,s) @@ -540,9 +545,10 @@ module Request = struct | e -> Error (400, Printexc.to_string e) - let read_body_full (self:byte_stream t) : string t = + let read_body_full ?buf_size (self:byte_stream t) : string t = try - let body = Byte_stream.read_all self.body in + let buf = Buf_.create ?size:buf_size () in + let body = Byte_stream.read_all ~buf self.body in { self with body } with | Bad_req _ as e -> raise e @@ -581,6 +587,12 @@ module Response = struct body: body; } + let set_body body self = {self with body} + let set_headers headers self = {self with headers} + let update_headers f self = {self with headers=f self.headers} + let set_header k v self = {self with headers = Headers.set k v self.headers} + let set_code code self = {self with code} + let make_raw ?(headers=[]) ~code body : t = (* add content length to response *) let headers = @@ -787,12 +799,21 @@ module Route = struct let pp out x = Format.pp_print_string out (to_string x) end +module Middleware = struct + type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit + type t = handler -> handler + + (** Apply a list of middlewares to [h] *) + let apply_l (l:t list) (h:handler) : handler = + List.fold_right (fun m h -> m h) l h + + let[@inline] nil : t = fun h -> h +end + (* a request handler. handles a single request. *) type cb_path_handler = out_channel -> - byte_stream Request.t -> - resp:(Response.t -> unit) -> - unit + Middleware.handler module type SERVER_SENT_GENERATOR = sig val set_headers : Headers.t -> unit @@ -823,19 +844,20 @@ type t = { masksigpipe: bool; + buf_size: int; + mutable handler: (string Request.t -> Response.t); (* toplevel handler, if any *) + mutable middlewares : (int * Middleware.t) list; + (** Global middlewares *) + + mutable middlewares_sorted : (int * Middleware.t) list lazy_t; + (* sorted version of {!middlewares} *) + mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list; (* path handlers *) - mutable cb_decode_req: - (unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) list; - (* middleware to decode requests *) - - mutable cb_encode_resp: (unit Request.t -> Response.t -> Response.t option) list; - (* middleware to encode responses *) - mutable running: bool; (* true while the server is running. no need to protect with a mutex, writes should be atomic enough. *) @@ -846,15 +868,48 @@ let port self = self.port let active_connections self = Sem_.num_acquired self.sem_max_connections - 1 -let add_decode_request_cb self f = self.cb_decode_req <- f :: self.cb_decode_req -let add_encode_response_cb self f = self.cb_encode_resp <- f :: self.cb_encode_resp +let add_middleware ~stage self m = + let stage = match stage with + | `Encoding -> 0 + | `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage" + | `Stage n -> n + in + self.middlewares <- (stage,m) :: self.middlewares; + self.middlewares_sorted <- lazy ( + List.stable_sort (fun (s1,_) (s2,_) -> compare s1 s2) self.middlewares + ) + +let add_decode_request_cb self f = + (* turn it into a middleware *) + let m h req ~resp = + (* see if [f] modifies the stream *) + let req0 = {req with Request.body=()} in + match f req0 with + | None -> h req ~resp (* pass through *) + | Some (req1, tr_stream) -> + let req = {req1 with Request.body=tr_stream req.Request.body} in + h req ~resp + in + add_middleware self ~stage:`Encoding m + +let add_encode_response_cb self f = + let m h req ~resp = + h req ~resp:(fun r -> + let req0 = {req with Request.body=()} in + (* now transform [r] if we want to *) + match f req0 r with + | None -> resp r + | Some r' -> resp r') + in + add_middleware self ~stage:`Encoding m + let set_top_handler self f = self.handler <- f (* route the given handler. @param tr_req wraps the actual concrete function returned by the route and makes it into a handler. *) let add_route_handler_ - ?(accept=fun _req -> Ok ()) + ?(accept=fun _req -> Ok ()) ?(middlewares=[]) ?meth ~tr_req self (route:_ Route.t) f = let ph req : cb_path_handler resp_result option = match meth with @@ -864,7 +919,10 @@ let add_route_handler_ | Some handler -> (* we have a handler, do we accept the request based on its headers? *) begin match accept req with - | Ok () -> Some (Ok (fun oc req ~resp -> tr_req oc req ~resp handler)) + | Ok () -> + Some (Ok (fun oc -> + Middleware.apply_l middlewares @@ + fun req ~resp -> tr_req oc req ~resp handler)) | Error _ as e -> Some e end | None -> @@ -873,13 +931,14 @@ let add_route_handler_ in self.path_handlers <- ph :: self.path_handlers -let add_route_handler (type a) ?accept ?meth self (route:(a,_) Route.t) (f:_) : unit = - let tr_req _oc req ~resp f = resp (f (Request.read_body_full req)) in - add_route_handler_ ?accept ?meth self route ~tr_req f +let add_route_handler (type a) ?accept ?middlewares ?meth + self (route:(a,_) Route.t) (f:_) : unit = + let tr_req _oc req ~resp f = resp (f (Request.read_body_full ~buf_size:self.buf_size req)) in + add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f -let add_route_handler_stream ?accept ?meth self route f = +let add_route_handler_stream ?accept ?middlewares ?meth self route f = let tr_req _oc req ~resp f = resp (f req) in - add_route_handler_ ?accept ?meth self route ~tr_req f + add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f let[@inline] _opt_iter ~f o = match o with | None -> () @@ -887,7 +946,7 @@ let[@inline] _opt_iter ~f o = match o with let add_route_server_sent_handler ?accept self route f = let tr_req oc req ~resp f = - let req = Request.read_body_full req in + let req = Request.read_body_full ~buf_size:self.buf_size req in let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in (* send response once *) @@ -929,15 +988,21 @@ let create ?(masksigpipe=true) ?(max_connections=32) ?(timeout=0.0) + ?(buf_size=16 * 1_024) ?(new_thread=(fun f -> ignore (Thread.create f () : Thread.t))) - ?(addr="127.0.0.1") ?(port=8080) ?sock () : t = + ?(addr="127.0.0.1") ?(port=8080) ?sock + ?(middlewares=[]) + () : t = let handler _req = Response.fail ~code:404 "no top handler" in let max_connections = max 4 max_connections in - { new_thread; addr; port; sock; masksigpipe; handler; + let self = { + new_thread; addr; port; sock; masksigpipe; handler; buf_size; running= true; sem_max_connections=Sem_.create max_connections; path_handlers=[]; timeout; - cb_encode_resp=[]; cb_decode_req=[]; - } + middlewares=[]; middlewares_sorted=lazy []; + } in + List.iter (fun (stage,m) -> add_middleware self ~stage m) middlewares; + self let stop s = s.running <- false @@ -955,8 +1020,8 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = let _ = Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout) in let ic = Unix.in_channel_of_descr client_sock in let oc = Unix.out_channel_of_descr client_sock in - let buf = Buf_.create() in - let is = Byte_stream.of_chan ic in + let buf = Buf_.create ~size:self.buf_size () in + let is = Byte_stream.of_chan ~buf_size:self.buf_size ic in let continue = ref true in while !continue && self.running do _debug (fun k->k "read next request"); @@ -981,7 +1046,10 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = let handler = match find_map (fun ph -> ph req) self.path_handlers with | Some f -> unwrap_resp_result f - | None -> (fun _oc req ~resp -> resp (self.handler (Request.read_body_full req))) + | None -> + (fun _oc req ~resp -> + let body_str = Request.read_body_full ~buf_size:self.buf_size req in + resp (self.handler body_str)) in (* handle expect/continue *) @@ -993,33 +1061,22 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = | None -> () end; - (* preprocess request's input stream *) - let req0, tr_stream = - List.fold_left - (fun (req,tr) cb -> - match cb req with - | None -> req, tr - | Some (r',f) -> r', (fun is -> tr is |> f)) - (req, (fun is->is)) self.cb_decode_req - in - (* now actually read request's body into a stream *) - let req = - Request.parse_body_ ~tr_stream ~buf {req0 with body=is} - |> unwrap_resp_result + (* apply middlewares *) + let handler = + fun oc -> + List.fold_right (fun (_, m) h -> m h) + (Lazy.force self.middlewares_sorted) (handler oc) in - (* how to post-process response accordingly *) - let post_process_resp resp = - List.fold_left - (fun resp cb -> match cb req0 resp with None -> resp | Some r' -> r') - resp self.cb_encode_resp + (* now actually read request's body into a stream *) + let req = + Request.parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=is} + |> unwrap_resp_result in (* how to reply *) let resp r = - try - let r = post_process_resp r in - Response.output_ oc r + try Response.output_ oc r with Sys_error _ -> continue := false in diff --git a/src/Tiny_httpd.mli b/src/Tiny_httpd.mli index 21752f7c..0d59275c 100644 --- a/src/Tiny_httpd.mli +++ b/src/Tiny_httpd.mli @@ -129,10 +129,10 @@ module Byte_stream : sig val empty : t - val of_chan : in_channel -> t + val of_chan : ?buf_size:int -> in_channel -> t (** Make a buffered stream from the given channel. *) - val of_chan_close_noerr : in_channel -> t + val of_chan_close_noerr : ?buf_size:int -> in_channel -> t (** Same as {!of_chan} but the [close] method will never fail. *) val of_bytes : ?i:int -> ?len:int -> bytes -> t @@ -149,7 +149,7 @@ module Byte_stream : sig (** Write the stream to the channel. @since 0.3 *) - val with_file : string -> (t -> 'a) -> 'a + val with_file : ?buf_size:int -> string -> (t -> 'a) -> 'a (** Open a file with given name, and obtain an input stream on its content. When the function returns, the stream (and file) are closed. *) @@ -227,6 +227,7 @@ module Request : sig path_components: string list; query: (string*string) list; body: 'body; + start_time: float; (** @since NEXT_RELEASE *) } (** A request with method, path, host, headers, and a body, sent by a client. @@ -253,7 +254,16 @@ module Request : sig val get_header_int : _ t -> string -> int option - val set_header : 'a t -> string -> string -> 'a t + val set_header : string -> string -> 'a t -> 'a t + (** [set_header k v req] sets [k: v] in the request [req]'s headers. *) + + val update_headers : (Headers.t -> Headers.t) -> 'a t -> 'a t + (** Modify headers + @since 0.11 *) + + val set_body : 'a -> _ t -> 'a t + (** [set_body b req] returns a new query whose body is [b]. + @since 0.11 *) val host : _ t -> string (** Host field of the request. It also appears in the headers. *) @@ -271,14 +281,20 @@ module Request : sig val body : 'b t -> 'b (** Request body, possibly empty. *) + val start_time : _ t -> float + (** time stamp (from {!Unix.gettimeofday}) after parsing the first line of the request + @since NEXT_RELEASE *) + val limit_body_size : max_size:int -> byte_stream t -> byte_stream t (** Limit the body size to [max_size] bytes, or return a [413] error. @since 0.3 *) - val read_body_full : byte_stream t -> string t - (** Read the whole body into a string. Potentially blocking. *) + val read_body_full : ?buf_size:int -> byte_stream t -> string t + (** Read the whole body into a string. Potentially blocking. + + @param buf_size initial size of underlying buffer (since NEXT_RELEASE) *) (**/**) (* for testing purpose, do not use *) @@ -318,13 +334,33 @@ module Response : sig (** Body of a response, either as a simple string, or a stream of bytes, or nothing (for server-sent events). *) - type t = { + type t = private { code: Response_code.t; (** HTTP response code. See {!Response_code}. *) headers: Headers.t; (** Headers of the reply. Some will be set by [Tiny_httpd] automatically. *) body: body; (** Body of the response. Can be empty. *) } (** A response to send back to a client. *) + val set_body : body -> t -> t + (** Set the body of the response. + @since 0.11 *) + + val set_header : string -> string -> t -> t + (** Set a header. + @since 0.11 *) + + val update_headers : (Headers.t -> Headers.t) -> t -> t + (** Modify headers + @since 0.11 *) + + val set_headers : Headers.t -> t -> t + (** Set all headers. + @since 0.11 *) + + val set_code : Response_code.t -> t -> t + (** Set the response code. + @since 0.11 *) + val make_raw : ?headers:Headers.t -> code:Response_code.t -> @@ -426,6 +462,31 @@ module Route : sig @since 0.7 *) end +(** {2 Middlewares} + + A middleware can be inserted in a handler to modify or observe + its behavior. + + @since NEXT_RELEASE +*) +module Middleware : sig + type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit + (** Handlers are functions returning a response to a request. + The response can be delayed, hence the use of a continuation + as the [resp] parameter. *) + + type t = handler -> handler + (** A middleware is a handler transformation. + + It takes the existing handler [h], + and returns a new one which, given a query, modify it or log it + before passing it to [h], or fail. It can also log or modify or drop + the response. *) + + val nil : t + (** Trivial middleware that does nothing. *) +end + (** {2 Main Server type} *) type t @@ -435,10 +496,12 @@ val create : ?masksigpipe:bool -> ?max_connections:int -> ?timeout:float -> + ?buf_size:int -> ?new_thread:((unit -> unit) -> unit) -> ?addr:string -> ?port:int -> ?sock:Unix.file_descr -> + ?middlewares:([`Encoding | `Stage of int] * Middleware.t) list -> unit -> t (** Create a new webserver. @@ -450,10 +513,14 @@ val create : @param masksigpipe if true, block the signal {!Sys.sigpipe} which otherwise tends to kill client threads when they try to write on broken sockets. Default: [true]. + @param buf_size size for buffers (since NEXT_RELEASE) + @param new_thread a function used to spawn a new thread to handle a new client connection. By default it is {!Thread.create} but one could use a thread pool instead. + @param middlewares see {!add_middleware} for more details. + @param max_connections maximum number of simultaneous connections. @param timeout connection is closed if the socket does not do read or write for the amount of second. Default: 0.0 which means no timeout. @@ -482,20 +549,36 @@ val active_connections : t -> int val add_decode_request_cb : t -> (unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) -> unit +[@@deprecated "use add_middleware"] (** Add a callback for every request. The callback can provide a stream transformer and a new request (with modified headers, typically). A possible use is to handle decompression by looking for a [Transfer-Encoding] header and returning a stream transformer that decompresses on the fly. + + @deprecated use {!add_middleware} instead *) val add_encode_response_cb: t -> (unit Request.t -> Response.t -> Response.t option) -> unit +[@@deprecated "use add_middleware"] (** Add a callback for every request/response pair. Similarly to {!add_encode_response_cb} the callback can return a new response, for example to compress it. The callback is given the query with only its headers, as well as the current response. + + @deprecated use {!add_middleware} instead +*) + +val add_middleware : + stage:[`Encoding | `Stage of int] -> + t -> Middleware.t -> unit +(** Add a middleware to every request/response pair. + @param stage specify when middleware applies. + Encoding comes first (outermost layer), then stages in increasing order. + @raise Invalid_argument if stage is [`Stage n] where [n < 1] + @since NEXT_RELEASE *) (** {2 Request handlers} *) @@ -509,6 +592,7 @@ val set_top_handler : t -> (string Request.t -> Response.t) -> unit val add_route_handler : ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> ?meth:Meth.t -> t -> ('a, string Request.t -> Response.t) Route.t -> 'a -> @@ -534,6 +618,7 @@ val add_route_handler : val add_route_handler_stream : ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> ?meth:Meth.t -> t -> ('a, byte_stream Request.t -> Response.t) Route.t -> 'a -> diff --git a/src/camlzip/Tiny_httpd_camlzip.ml b/src/camlzip/Tiny_httpd_camlzip.ml index 0ff8c20e..217b878a 100644 --- a/src/camlzip/Tiny_httpd_camlzip.ml +++ b/src/camlzip/Tiny_httpd_camlzip.ml @@ -2,7 +2,7 @@ module S = Tiny_httpd module BS = Tiny_httpd.Byte_stream -let mk_decode_deflate_stream_ ~buf_size () (is:S.byte_stream) : S.byte_stream = +let decode_deflate_stream_ ~buf_size (is:S.byte_stream) : S.byte_stream = S._debug (fun k->k "wrap stream with deflate.decode"); let buf = Bytes.make buf_size ' ' in let buf_len = ref 0 in @@ -145,7 +145,8 @@ let has_deflate s = try Scanf.sscanf s "deflate, %s" (fun _ -> true) with _ -> false -let cb_decode_compressed_stream ~buf_size (req:unit S.Request.t) : _ option = +(* decompress [req]'s body if needed *) +let decompress_req_stream_ ~buf_size (req:BS.t S.Request.t) : _ S.Request.t = match S.Request.get_header ~f:String.trim req "Transfer-Encoding" with (* TODO | Some "gzip" -> @@ -155,49 +156,63 @@ let cb_decode_compressed_stream ~buf_size (req:unit S.Request.t) : _ option = | Some s when has_deflate s -> begin match Scanf.sscanf s "deflate, %s" (fun s -> s) with | tr' -> - let req' = S.Request.set_header req "Transfer-Encoding" tr' in - Some (req', mk_decode_deflate_stream_ ~buf_size ()) - | exception _ -> None + let body' = S.Request.body req |> decode_deflate_stream_ ~buf_size in + req + |> S.Request.set_header "Transfer-Encoding" tr' + |> S.Request.set_body body' + | exception _ -> req end - | _ -> None + | _ -> req -let cb_encode_compressed_stream +let compress_resp_stream_ ~compress_above - ~buf_size (req:_ S.Request.t) (resp:S.Response.t) : _ option = + ~buf_size + (req:_ S.Request.t) (resp:S.Response.t) : S.Response.t = + + (* headers for compressed stream *) + let update_headers h = + h + |> S.Headers.remove "Content-Length" + |> S.Headers.set "Content-Encoding" "deflate" + in + if accept_deflate req then ( - let set_headers h = - h - |> S.Headers.remove "Content-Length" - |> S.Headers.set "Content-Encoding" "deflate" - in match resp.body with | `String s when String.length s > compress_above -> + (* big string, we compress *) S._debug (fun k->k "encode str response with deflate (size %d, threshold %d)" (String.length s) compress_above); let body = encode_deflate_stream_ ~buf_size @@ S.Byte_stream.of_string s in - Some { - resp with - headers=set_headers resp.headers; body=`Stream body; - } + resp + |> S.Response.update_headers update_headers + |> S.Response.set_body (`Stream body) + | `Stream str -> S._debug (fun k->k "encode stream response with deflate"); - Some { - resp with - headers= set_headers resp.headers; - body=`Stream (encode_deflate_stream_ ~buf_size str); - } - | `String _ | `Void -> None - ) else None + resp + |> S.Response.update_headers update_headers + |> S.Response.set_body (`Stream (encode_deflate_stream_ ~buf_size str)) + + | `String _ | `Void -> resp + ) else resp + +let middleware + ?(compress_above=16 * 1024) + ?(buf_size=16 * 1_024) + () : S.Middleware.t = + let buf_size = max buf_size 1_024 in + fun h req ~resp -> + let req = decompress_req_stream_ ~buf_size req in + h req + ~resp:(fun response -> + resp @@ compress_resp_stream_ ~buf_size ~compress_above req response) let setup - ?(compress_above=500*1024) - ?(buf_size=48 * 1_024) (server:S.t) : unit = - let buf_size = max buf_size 1_024 in - S._debug (fun k->k "setup gzip support (buf-size %d)" buf_size); - S.add_decode_request_cb server (cb_decode_compressed_stream ~buf_size); - S.add_encode_response_cb server (cb_encode_compressed_stream ~compress_above ~buf_size); - () + ?compress_above ?buf_size server = + let m = middleware ?compress_above ?buf_size () in + S._debug (fun k->k "setup gzip support"); + S.add_middleware ~stage:`Encoding server m diff --git a/src/camlzip/Tiny_httpd_camlzip.mli b/src/camlzip/Tiny_httpd_camlzip.mli index dd2e3cb6..d086e8e6 100644 --- a/src/camlzip/Tiny_httpd_camlzip.mli +++ b/src/camlzip/Tiny_httpd_camlzip.mli @@ -1,8 +1,13 @@ +val middleware : + ?compress_above:int -> + ?buf_size:int -> unit -> + Tiny_httpd.Middleware.t + val setup : ?compress_above:int -> ?buf_size:int -> Tiny_httpd.t -> unit -(** Install callbacks for tiny_httpd to be able to encode/decode +(** Install middleware for tiny_httpd to be able to encode/decode compressed streams @param compress_above threshold above with string responses are compressed @param buf_size size of the underlying buffer for compression/decompression *) diff --git a/tests/upload_chunked.sh b/tests/upload_chunked.sh index 9cbd2d7d..a574798d 100755 --- a/tests/upload_chunked.sh +++ b/tests/upload_chunked.sh @@ -1,6 +1,6 @@ #!/usr/bin/env sh -rm data +if [ -f data ]; then rm data ; fi SERVER=$1 PORT=8087