diff --git a/examples/echo.ml b/examples/echo.ml
index 3b7c7700..f77bbe72 100644
--- a/examples/echo.ml
+++ b/examples/echo.ml
@@ -1,6 +1,39 @@
module S = Tiny_httpd
+let now_ = Unix.gettimeofday
+
+(* util: a little middleware collecting statistics *)
+let middleware_stat () : S.Middleware.t * (unit -> string) =
+ let n_req = ref 0 in
+ let total_time_ = ref 0. in
+ let parse_time_ = ref 0. in
+ let build_time_ = ref 0. in
+ let write_time_ = ref 0. in
+
+ let m h req ~resp =
+ incr n_req;
+ let t1 = S.Request.start_time req in
+ let t2 = now_ () in
+ h req ~resp:(fun response ->
+ let t3 = now_ () in
+ resp response;
+ let t4 = now_ () in
+ total_time_ := !total_time_ +. (t4 -. t1);
+ parse_time_ := !parse_time_ +. (t2 -. t1);
+ build_time_ := !build_time_ +. (t3 -. t2);
+ write_time_ := !write_time_ +. (t4 -. t3);
+ )
+ and get_stat () =
+ Printf.sprintf "%d requests (average response time: %.3fms = %.3fms + %.3fms + %.3fms)"
+ !n_req (!total_time_ /. float !n_req *. 1e3)
+ (!parse_time_ /. float !n_req *. 1e3)
+ (!build_time_ /. float !n_req *. 1e3)
+ (!write_time_ /. float !n_req *. 1e3)
+ in
+ m, get_stat
+
+
let () =
let port_ = ref 8080 in
let j = ref 32 in
@@ -10,12 +43,19 @@ let () =
"--debug", Arg.Unit (fun () -> S._enable_debug true), " enable debug";
"-j", Arg.Set_int j, " maximum number of connections";
]) (fun _ -> raise (Arg.Bad "")) "echo [option]*";
+
let server = S.create ~port:!port_ ~max_connections:!j () in
- Tiny_httpd_camlzip.setup ~compress_above:1024 ~buf_size:(1024*1024) server;
+ Tiny_httpd_camlzip.setup ~compress_above:1024 ~buf_size:(16*1024) server;
+
+ let m_stats, get_stats = middleware_stat () in
+ S.add_middleware server ~stage:(`Stage 1) m_stats;
+
(* say hello *)
S.add_route_handler ~meth:`GET server
S.Route.(exact "hello" @/ string @/ return)
(fun name _req -> S.Response.make_string (Ok ("hello " ^name ^"!\n")));
+
+ (* compressed file access *)
S.add_route_handler ~meth:`GET server
S.Route.(exact "zcat" @/ string_urlencoded @/ return)
(fun path _req ->
@@ -33,6 +73,7 @@ let () =
in
S.Response.make_stream ~headers:mime_type (Ok str)
);
+
(* echo request *)
S.add_route_handler server
S.Route.(exact "echo" @/ return)
@@ -43,6 +84,8 @@ let () =
in
S.Response.make_string
(Ok (Format.asprintf "echo:@ %a@ (query: %s)@." S.Request.pp req q)));
+
+ (* file upload *)
S.add_route_handler_stream ~meth:`PUT server
S.Route.(exact "upload" @/ string @/ return)
(fun path req ->
@@ -56,6 +99,28 @@ let () =
with e ->
S.Response.fail ~code:500 "couldn't upload file: %s" (Printexc.to_string e)
);
+
+ (* stats *)
+ S.add_route_handler server S.Route.(exact "stats" @/ return)
+ (fun _req ->
+ let stats = get_stats() in
+ S.Response.make_string @@ Ok stats
+ );
+
+ (* main page *)
+ S.add_route_handler server S.Route.(return)
+ (fun _req ->
+ let s = "
\n\
+ welcome!\n
endpoints are:\n
\
+ /hello/'name' (GET)
\n\
+ /echo/ (GET) echoes back query
\n\
+ /upload/'path' (PUT) to upload a file
\n\
+ /zcat/'path' (GET) to download a file (compressed)
\n\
+ /stats/ (GET) to access statistics
\n\
+
"
+ in
+ S.Response.make_string ~headers:["content-type", "text/html"] @@ Ok s);
+
Printf.printf "listening on http://%s:%d\n%!" (S.addr server) (S.port server);
match S.run server with
| Ok () -> ()
diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml
index c0d92bfd..b0e52dbb 100644
--- a/src/Tiny_httpd.ml
+++ b/src/Tiny_httpd.ml
@@ -66,10 +66,10 @@ module Byte_stream = struct
bs_close=(fun () -> ());
}
- let of_chan_ ~close ic : t =
+ let of_chan_ ?(buf_size=16 * 1024) ~close ic : t =
let i = ref 0 in
let len = ref 0 in
- let buf = Bytes.make 4096 ' ' in
+ let buf = Bytes.make buf_size ' ' in
{ bs_fill_buf=(fun () ->
if !i >= !len then (
i := 0;
@@ -116,10 +116,10 @@ module Byte_stream = struct
let of_string s : t =
of_bytes (Bytes.unsafe_of_string s)
- let with_file file f =
+ let with_file ?buf_size file f =
let ic = open_in file in
try
- let x = f (of_chan ic) in
+ let x = f (of_chan ?buf_size ic) in
close_in ic;
x
with e ->
@@ -367,6 +367,7 @@ module Request = struct
path_components: string list;
query: (string*string) list;
body: 'body;
+ start_time: float;
}
let headers self = self.headers
@@ -374,13 +375,16 @@ module Request = struct
let meth self = self.meth
let path self = self.path
let body self = self.body
+ let start_time self = self.start_time
let query self = self.query
let get_header ?f self h = Headers.get ?f h self.headers
let get_header_int self h = match get_header self h with
| Some x -> (try Some (int_of_string x) with _ -> None)
| None -> None
- let set_header self k v = {self with headers=Headers.set k v self.headers}
+ let set_header k v self = {self with headers=Headers.set k v self.headers}
+ let update_headers f self = {self with headers=f self.headers}
+ let set_body b self = {self with body=b}
let pp_comp_ out comp =
Format.fprintf out "[%s]"
@@ -481,6 +485,7 @@ module Request = struct
let parse_req_start ~buf (bs:byte_stream) : unit t option resp_result =
try
let line = Byte_stream.read_line ~buf bs in
+ let start_time = Unix.gettimeofday () in
let meth, path =
try
let m, p, v = Scanf.sscanf line "%s %s HTTP/1.%d\r" (fun x y z->x,y,z) in
@@ -506,7 +511,7 @@ module Request = struct
| Error e -> bad_reqf 400 "invalid query: %s" e
in
Ok (Some {meth; query; host; path; path_components;
- headers; body=()})
+ headers; body=(); start_time; })
with
| End_of_file | Sys_error _ -> Ok None
| Bad_req (c,s) -> Error (c,s)
@@ -540,9 +545,10 @@ module Request = struct
| e ->
Error (400, Printexc.to_string e)
- let read_body_full (self:byte_stream t) : string t =
+ let read_body_full ?buf_size (self:byte_stream t) : string t =
try
- let body = Byte_stream.read_all self.body in
+ let buf = Buf_.create ?size:buf_size () in
+ let body = Byte_stream.read_all ~buf self.body in
{ self with body }
with
| Bad_req _ as e -> raise e
@@ -581,6 +587,12 @@ module Response = struct
body: body;
}
+ let set_body body self = {self with body}
+ let set_headers headers self = {self with headers}
+ let update_headers f self = {self with headers=f self.headers}
+ let set_header k v self = {self with headers = Headers.set k v self.headers}
+ let set_code code self = {self with code}
+
let make_raw ?(headers=[]) ~code body : t =
(* add content length to response *)
let headers =
@@ -787,12 +799,21 @@ module Route = struct
let pp out x = Format.pp_print_string out (to_string x)
end
+module Middleware = struct
+ type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit
+ type t = handler -> handler
+
+ (** Apply a list of middlewares to [h] *)
+ let apply_l (l:t list) (h:handler) : handler =
+ List.fold_right (fun m h -> m h) l h
+
+ let[@inline] nil : t = fun h -> h
+end
+
(* a request handler. handles a single request. *)
type cb_path_handler =
out_channel ->
- byte_stream Request.t ->
- resp:(Response.t -> unit) ->
- unit
+ Middleware.handler
module type SERVER_SENT_GENERATOR = sig
val set_headers : Headers.t -> unit
@@ -823,19 +844,20 @@ type t = {
masksigpipe: bool;
+ buf_size: int;
+
mutable handler: (string Request.t -> Response.t);
(* toplevel handler, if any *)
+ mutable middlewares : (int * Middleware.t) list;
+ (** Global middlewares *)
+
+ mutable middlewares_sorted : (int * Middleware.t) list lazy_t;
+ (* sorted version of {!middlewares} *)
+
mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list;
(* path handlers *)
- mutable cb_decode_req:
- (unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) list;
- (* middleware to decode requests *)
-
- mutable cb_encode_resp: (unit Request.t -> Response.t -> Response.t option) list;
- (* middleware to encode responses *)
-
mutable running: bool;
(* true while the server is running. no need to protect with a mutex,
writes should be atomic enough. *)
@@ -846,15 +868,48 @@ let port self = self.port
let active_connections self = Sem_.num_acquired self.sem_max_connections - 1
-let add_decode_request_cb self f = self.cb_decode_req <- f :: self.cb_decode_req
-let add_encode_response_cb self f = self.cb_encode_resp <- f :: self.cb_encode_resp
+let add_middleware ~stage self m =
+ let stage = match stage with
+ | `Encoding -> 0
+ | `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage"
+ | `Stage n -> n
+ in
+ self.middlewares <- (stage,m) :: self.middlewares;
+ self.middlewares_sorted <- lazy (
+ List.stable_sort (fun (s1,_) (s2,_) -> compare s1 s2) self.middlewares
+ )
+
+let add_decode_request_cb self f =
+ (* turn it into a middleware *)
+ let m h req ~resp =
+ (* see if [f] modifies the stream *)
+ let req0 = {req with Request.body=()} in
+ match f req0 with
+ | None -> h req ~resp (* pass through *)
+ | Some (req1, tr_stream) ->
+ let req = {req1 with Request.body=tr_stream req.Request.body} in
+ h req ~resp
+ in
+ add_middleware self ~stage:`Encoding m
+
+let add_encode_response_cb self f =
+ let m h req ~resp =
+ h req ~resp:(fun r ->
+ let req0 = {req with Request.body=()} in
+ (* now transform [r] if we want to *)
+ match f req0 r with
+ | None -> resp r
+ | Some r' -> resp r')
+ in
+ add_middleware self ~stage:`Encoding m
+
let set_top_handler self f = self.handler <- f
(* route the given handler.
@param tr_req wraps the actual concrete function returned by the route
and makes it into a handler. *)
let add_route_handler_
- ?(accept=fun _req -> Ok ())
+ ?(accept=fun _req -> Ok ()) ?(middlewares=[])
?meth ~tr_req self (route:_ Route.t) f =
let ph req : cb_path_handler resp_result option =
match meth with
@@ -864,7 +919,10 @@ let add_route_handler_
| Some handler ->
(* we have a handler, do we accept the request based on its headers? *)
begin match accept req with
- | Ok () -> Some (Ok (fun oc req ~resp -> tr_req oc req ~resp handler))
+ | Ok () ->
+ Some (Ok (fun oc ->
+ Middleware.apply_l middlewares @@
+ fun req ~resp -> tr_req oc req ~resp handler))
| Error _ as e -> Some e
end
| None ->
@@ -873,13 +931,14 @@ let add_route_handler_
in
self.path_handlers <- ph :: self.path_handlers
-let add_route_handler (type a) ?accept ?meth self (route:(a,_) Route.t) (f:_) : unit =
- let tr_req _oc req ~resp f = resp (f (Request.read_body_full req)) in
- add_route_handler_ ?accept ?meth self route ~tr_req f
+let add_route_handler (type a) ?accept ?middlewares ?meth
+ self (route:(a,_) Route.t) (f:_) : unit =
+ let tr_req _oc req ~resp f = resp (f (Request.read_body_full ~buf_size:self.buf_size req)) in
+ add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f
-let add_route_handler_stream ?accept ?meth self route f =
+let add_route_handler_stream ?accept ?middlewares ?meth self route f =
let tr_req _oc req ~resp f = resp (f req) in
- add_route_handler_ ?accept ?meth self route ~tr_req f
+ add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f
let[@inline] _opt_iter ~f o = match o with
| None -> ()
@@ -887,7 +946,7 @@ let[@inline] _opt_iter ~f o = match o with
let add_route_server_sent_handler ?accept self route f =
let tr_req oc req ~resp f =
- let req = Request.read_body_full req in
+ let req = Request.read_body_full ~buf_size:self.buf_size req in
let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in
(* send response once *)
@@ -929,15 +988,21 @@ let create
?(masksigpipe=true)
?(max_connections=32)
?(timeout=0.0)
+ ?(buf_size=16 * 1_024)
?(new_thread=(fun f -> ignore (Thread.create f () : Thread.t)))
- ?(addr="127.0.0.1") ?(port=8080) ?sock () : t =
+ ?(addr="127.0.0.1") ?(port=8080) ?sock
+ ?(middlewares=[])
+ () : t =
let handler _req = Response.fail ~code:404 "no top handler" in
let max_connections = max 4 max_connections in
- { new_thread; addr; port; sock; masksigpipe; handler;
+ let self = {
+ new_thread; addr; port; sock; masksigpipe; handler; buf_size;
running= true; sem_max_connections=Sem_.create max_connections;
path_handlers=[]; timeout;
- cb_encode_resp=[]; cb_decode_req=[];
- }
+ middlewares=[]; middlewares_sorted=lazy [];
+ } in
+ List.iter (fun (stage,m) -> add_middleware self ~stage m) middlewares;
+ self
let stop s = s.running <- false
@@ -955,8 +1020,8 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
let _ = Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout) in
let ic = Unix.in_channel_of_descr client_sock in
let oc = Unix.out_channel_of_descr client_sock in
- let buf = Buf_.create() in
- let is = Byte_stream.of_chan ic in
+ let buf = Buf_.create ~size:self.buf_size () in
+ let is = Byte_stream.of_chan ~buf_size:self.buf_size ic in
let continue = ref true in
while !continue && self.running do
_debug (fun k->k "read next request");
@@ -981,7 +1046,10 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
let handler =
match find_map (fun ph -> ph req) self.path_handlers with
| Some f -> unwrap_resp_result f
- | None -> (fun _oc req ~resp -> resp (self.handler (Request.read_body_full req)))
+ | None ->
+ (fun _oc req ~resp ->
+ let body_str = Request.read_body_full ~buf_size:self.buf_size req in
+ resp (self.handler body_str))
in
(* handle expect/continue *)
@@ -993,33 +1061,22 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
| None -> ()
end;
- (* preprocess request's input stream *)
- let req0, tr_stream =
- List.fold_left
- (fun (req,tr) cb ->
- match cb req with
- | None -> req, tr
- | Some (r',f) -> r', (fun is -> tr is |> f))
- (req, (fun is->is)) self.cb_decode_req
- in
- (* now actually read request's body into a stream *)
- let req =
- Request.parse_body_ ~tr_stream ~buf {req0 with body=is}
- |> unwrap_resp_result
+ (* apply middlewares *)
+ let handler =
+ fun oc ->
+ List.fold_right (fun (_, m) h -> m h)
+ (Lazy.force self.middlewares_sorted) (handler oc)
in
- (* how to post-process response accordingly *)
- let post_process_resp resp =
- List.fold_left
- (fun resp cb -> match cb req0 resp with None -> resp | Some r' -> r')
- resp self.cb_encode_resp
+ (* now actually read request's body into a stream *)
+ let req =
+ Request.parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=is}
+ |> unwrap_resp_result
in
(* how to reply *)
let resp r =
- try
- let r = post_process_resp r in
- Response.output_ oc r
+ try Response.output_ oc r
with Sys_error _ -> continue := false
in
diff --git a/src/Tiny_httpd.mli b/src/Tiny_httpd.mli
index 21752f7c..0d59275c 100644
--- a/src/Tiny_httpd.mli
+++ b/src/Tiny_httpd.mli
@@ -129,10 +129,10 @@ module Byte_stream : sig
val empty : t
- val of_chan : in_channel -> t
+ val of_chan : ?buf_size:int -> in_channel -> t
(** Make a buffered stream from the given channel. *)
- val of_chan_close_noerr : in_channel -> t
+ val of_chan_close_noerr : ?buf_size:int -> in_channel -> t
(** Same as {!of_chan} but the [close] method will never fail. *)
val of_bytes : ?i:int -> ?len:int -> bytes -> t
@@ -149,7 +149,7 @@ module Byte_stream : sig
(** Write the stream to the channel.
@since 0.3 *)
- val with_file : string -> (t -> 'a) -> 'a
+ val with_file : ?buf_size:int -> string -> (t -> 'a) -> 'a
(** Open a file with given name, and obtain an input stream
on its content. When the function returns, the stream (and file) are closed. *)
@@ -227,6 +227,7 @@ module Request : sig
path_components: string list;
query: (string*string) list;
body: 'body;
+ start_time: float; (** @since NEXT_RELEASE *)
}
(** A request with method, path, host, headers, and a body, sent by a client.
@@ -253,7 +254,16 @@ module Request : sig
val get_header_int : _ t -> string -> int option
- val set_header : 'a t -> string -> string -> 'a t
+ val set_header : string -> string -> 'a t -> 'a t
+ (** [set_header k v req] sets [k: v] in the request [req]'s headers. *)
+
+ val update_headers : (Headers.t -> Headers.t) -> 'a t -> 'a t
+ (** Modify headers
+ @since 0.11 *)
+
+ val set_body : 'a -> _ t -> 'a t
+ (** [set_body b req] returns a new query whose body is [b].
+ @since 0.11 *)
val host : _ t -> string
(** Host field of the request. It also appears in the headers. *)
@@ -271,14 +281,20 @@ module Request : sig
val body : 'b t -> 'b
(** Request body, possibly empty. *)
+ val start_time : _ t -> float
+ (** time stamp (from {!Unix.gettimeofday}) after parsing the first line of the request
+ @since NEXT_RELEASE *)
+
val limit_body_size : max_size:int -> byte_stream t -> byte_stream t
(** Limit the body size to [max_size] bytes, or return
a [413] error.
@since 0.3
*)
- val read_body_full : byte_stream t -> string t
- (** Read the whole body into a string. Potentially blocking. *)
+ val read_body_full : ?buf_size:int -> byte_stream t -> string t
+ (** Read the whole body into a string. Potentially blocking.
+
+ @param buf_size initial size of underlying buffer (since NEXT_RELEASE) *)
(**/**)
(* for testing purpose, do not use *)
@@ -318,13 +334,33 @@ module Response : sig
(** Body of a response, either as a simple string,
or a stream of bytes, or nothing (for server-sent events). *)
- type t = {
+ type t = private {
code: Response_code.t; (** HTTP response code. See {!Response_code}. *)
headers: Headers.t; (** Headers of the reply. Some will be set by [Tiny_httpd] automatically. *)
body: body; (** Body of the response. Can be empty. *)
}
(** A response to send back to a client. *)
+ val set_body : body -> t -> t
+ (** Set the body of the response.
+ @since 0.11 *)
+
+ val set_header : string -> string -> t -> t
+ (** Set a header.
+ @since 0.11 *)
+
+ val update_headers : (Headers.t -> Headers.t) -> t -> t
+ (** Modify headers
+ @since 0.11 *)
+
+ val set_headers : Headers.t -> t -> t
+ (** Set all headers.
+ @since 0.11 *)
+
+ val set_code : Response_code.t -> t -> t
+ (** Set the response code.
+ @since 0.11 *)
+
val make_raw :
?headers:Headers.t ->
code:Response_code.t ->
@@ -426,6 +462,31 @@ module Route : sig
@since 0.7 *)
end
+(** {2 Middlewares}
+
+ A middleware can be inserted in a handler to modify or observe
+ its behavior.
+
+ @since NEXT_RELEASE
+*)
+module Middleware : sig
+ type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit
+ (** Handlers are functions returning a response to a request.
+ The response can be delayed, hence the use of a continuation
+ as the [resp] parameter. *)
+
+ type t = handler -> handler
+ (** A middleware is a handler transformation.
+
+ It takes the existing handler [h],
+ and returns a new one which, given a query, modify it or log it
+ before passing it to [h], or fail. It can also log or modify or drop
+ the response. *)
+
+ val nil : t
+ (** Trivial middleware that does nothing. *)
+end
+
(** {2 Main Server type} *)
type t
@@ -435,10 +496,12 @@ val create :
?masksigpipe:bool ->
?max_connections:int ->
?timeout:float ->
+ ?buf_size:int ->
?new_thread:((unit -> unit) -> unit) ->
?addr:string ->
?port:int ->
?sock:Unix.file_descr ->
+ ?middlewares:([`Encoding | `Stage of int] * Middleware.t) list ->
unit ->
t
(** Create a new webserver.
@@ -450,10 +513,14 @@ val create :
@param masksigpipe if true, block the signal {!Sys.sigpipe} which otherwise
tends to kill client threads when they try to write on broken sockets. Default: [true].
+ @param buf_size size for buffers (since NEXT_RELEASE)
+
@param new_thread a function used to spawn a new thread to handle a
new client connection. By default it is {!Thread.create} but one
could use a thread pool instead.
+ @param middlewares see {!add_middleware} for more details.
+
@param max_connections maximum number of simultaneous connections.
@param timeout connection is closed if the socket does not do read or
write for the amount of second. Default: 0.0 which means no timeout.
@@ -482,20 +549,36 @@ val active_connections : t -> int
val add_decode_request_cb :
t ->
(unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) -> unit
+[@@deprecated "use add_middleware"]
(** Add a callback for every request.
The callback can provide a stream transformer and a new request (with
modified headers, typically).
A possible use is to handle decompression by looking for a [Transfer-Encoding]
header and returning a stream transformer that decompresses on the fly.
+
+ @deprecated use {!add_middleware} instead
*)
val add_encode_response_cb:
t -> (unit Request.t -> Response.t -> Response.t option) -> unit
+[@@deprecated "use add_middleware"]
(** Add a callback for every request/response pair.
Similarly to {!add_encode_response_cb} the callback can return a new
response, for example to compress it.
The callback is given the query with only its headers,
as well as the current response.
+
+ @deprecated use {!add_middleware} instead
+*)
+
+val add_middleware :
+ stage:[`Encoding | `Stage of int] ->
+ t -> Middleware.t -> unit
+(** Add a middleware to every request/response pair.
+ @param stage specify when middleware applies.
+ Encoding comes first (outermost layer), then stages in increasing order.
+ @raise Invalid_argument if stage is [`Stage n] where [n < 1]
+ @since NEXT_RELEASE
*)
(** {2 Request handlers} *)
@@ -509,6 +592,7 @@ val set_top_handler : t -> (string Request.t -> Response.t) -> unit
val add_route_handler :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
+ ?middlewares:Middleware.t list ->
?meth:Meth.t ->
t ->
('a, string Request.t -> Response.t) Route.t -> 'a ->
@@ -534,6 +618,7 @@ val add_route_handler :
val add_route_handler_stream :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
+ ?middlewares:Middleware.t list ->
?meth:Meth.t ->
t ->
('a, byte_stream Request.t -> Response.t) Route.t -> 'a ->
diff --git a/src/camlzip/Tiny_httpd_camlzip.ml b/src/camlzip/Tiny_httpd_camlzip.ml
index 0ff8c20e..217b878a 100644
--- a/src/camlzip/Tiny_httpd_camlzip.ml
+++ b/src/camlzip/Tiny_httpd_camlzip.ml
@@ -2,7 +2,7 @@
module S = Tiny_httpd
module BS = Tiny_httpd.Byte_stream
-let mk_decode_deflate_stream_ ~buf_size () (is:S.byte_stream) : S.byte_stream =
+let decode_deflate_stream_ ~buf_size (is:S.byte_stream) : S.byte_stream =
S._debug (fun k->k "wrap stream with deflate.decode");
let buf = Bytes.make buf_size ' ' in
let buf_len = ref 0 in
@@ -145,7 +145,8 @@ let has_deflate s =
try Scanf.sscanf s "deflate, %s" (fun _ -> true)
with _ -> false
-let cb_decode_compressed_stream ~buf_size (req:unit S.Request.t) : _ option =
+(* decompress [req]'s body if needed *)
+let decompress_req_stream_ ~buf_size (req:BS.t S.Request.t) : _ S.Request.t =
match S.Request.get_header ~f:String.trim req "Transfer-Encoding" with
(* TODO
| Some "gzip" ->
@@ -155,49 +156,63 @@ let cb_decode_compressed_stream ~buf_size (req:unit S.Request.t) : _ option =
| Some s when has_deflate s ->
begin match Scanf.sscanf s "deflate, %s" (fun s -> s) with
| tr' ->
- let req' = S.Request.set_header req "Transfer-Encoding" tr' in
- Some (req', mk_decode_deflate_stream_ ~buf_size ())
- | exception _ -> None
+ let body' = S.Request.body req |> decode_deflate_stream_ ~buf_size in
+ req
+ |> S.Request.set_header "Transfer-Encoding" tr'
+ |> S.Request.set_body body'
+ | exception _ -> req
end
- | _ -> None
+ | _ -> req
-let cb_encode_compressed_stream
+let compress_resp_stream_
~compress_above
- ~buf_size (req:_ S.Request.t) (resp:S.Response.t) : _ option =
+ ~buf_size
+ (req:_ S.Request.t) (resp:S.Response.t) : S.Response.t =
+
+ (* headers for compressed stream *)
+ let update_headers h =
+ h
+ |> S.Headers.remove "Content-Length"
+ |> S.Headers.set "Content-Encoding" "deflate"
+ in
+
if accept_deflate req then (
- let set_headers h =
- h
- |> S.Headers.remove "Content-Length"
- |> S.Headers.set "Content-Encoding" "deflate"
- in
match resp.body with
| `String s when String.length s > compress_above ->
+ (* big string, we compress *)
S._debug
(fun k->k "encode str response with deflate (size %d, threshold %d)"
(String.length s) compress_above);
let body =
encode_deflate_stream_ ~buf_size @@ S.Byte_stream.of_string s
in
- Some {
- resp with
- headers=set_headers resp.headers; body=`Stream body;
- }
+ resp
+ |> S.Response.update_headers update_headers
+ |> S.Response.set_body (`Stream body)
+
| `Stream str ->
S._debug (fun k->k "encode stream response with deflate");
- Some {
- resp with
- headers= set_headers resp.headers;
- body=`Stream (encode_deflate_stream_ ~buf_size str);
- }
- | `String _ | `Void -> None
- ) else None
+ resp
+ |> S.Response.update_headers update_headers
+ |> S.Response.set_body (`Stream (encode_deflate_stream_ ~buf_size str))
+
+ | `String _ | `Void -> resp
+ ) else resp
+
+let middleware
+ ?(compress_above=16 * 1024)
+ ?(buf_size=16 * 1_024)
+ () : S.Middleware.t =
+ let buf_size = max buf_size 1_024 in
+ fun h req ~resp ->
+ let req = decompress_req_stream_ ~buf_size req in
+ h req
+ ~resp:(fun response ->
+ resp @@ compress_resp_stream_ ~buf_size ~compress_above req response)
let setup
- ?(compress_above=500*1024)
- ?(buf_size=48 * 1_024) (server:S.t) : unit =
- let buf_size = max buf_size 1_024 in
- S._debug (fun k->k "setup gzip support (buf-size %d)" buf_size);
- S.add_decode_request_cb server (cb_decode_compressed_stream ~buf_size);
- S.add_encode_response_cb server (cb_encode_compressed_stream ~compress_above ~buf_size);
- ()
+ ?compress_above ?buf_size server =
+ let m = middleware ?compress_above ?buf_size () in
+ S._debug (fun k->k "setup gzip support");
+ S.add_middleware ~stage:`Encoding server m
diff --git a/src/camlzip/Tiny_httpd_camlzip.mli b/src/camlzip/Tiny_httpd_camlzip.mli
index dd2e3cb6..d086e8e6 100644
--- a/src/camlzip/Tiny_httpd_camlzip.mli
+++ b/src/camlzip/Tiny_httpd_camlzip.mli
@@ -1,8 +1,13 @@
+val middleware :
+ ?compress_above:int ->
+ ?buf_size:int -> unit ->
+ Tiny_httpd.Middleware.t
+
val setup :
?compress_above:int ->
?buf_size:int -> Tiny_httpd.t -> unit
-(** Install callbacks for tiny_httpd to be able to encode/decode
+(** Install middleware for tiny_httpd to be able to encode/decode
compressed streams
@param compress_above threshold above with string responses are compressed
@param buf_size size of the underlying buffer for compression/decompression *)
diff --git a/tests/upload_chunked.sh b/tests/upload_chunked.sh
index 9cbd2d7d..a574798d 100755
--- a/tests/upload_chunked.sh
+++ b/tests/upload_chunked.sh
@@ -1,6 +1,6 @@
#!/usr/bin/env sh
-rm data
+if [ -f data ]; then rm data ; fi
SERVER=$1
PORT=8087