feat: allow handlers to take streams

close #9
Simon Cruanes 2019-12-02 17:17:46 -06:00
parent 1a657515d9
commit 45bc589e00
3 changed files with 132 additions and 54 deletions
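
In short: path handlers can now receive the request body as a byte_stream instead of a fully-read string. A rough usage sketch, not part of this commit (it assumes the library's top-level module is Tiny_httpd, that create accepts ~port and a final unit argument, and that run's result can be ignored; only add_path_handler_stream, Byte_stream.iter, Request.body and Response.make_raw come from the diff below):

  (* sketch only: a PUT handler that consumes the body incrementally,
     counting bytes without ever materializing the whole body in memory *)
  module S = Tiny_httpd

  let () =
    let server = S.create ~port:8080 () in     (* ~port/() signature assumed *)
    S.add_path_handler_stream server ~meth:`PUT "/upload/%s"
      (fun _name req ->
         let n = ref 0 in
         S.Byte_stream.iter (fun _bytes _i len -> n := !n + len)
           (S.Request.body req);
         S.Response.make_raw ~code:201 (Printf.sprintf "received %d bytes" !n));
    ignore (S.run server)                       (* run's return type assumed *)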

@@ -59,6 +59,12 @@ module Byte_stream = struct
   let close self = self.bs_close()
 
+  let empty = {
+    bs_fill_buf=(fun () -> Bytes.empty, 0, 0);
+    bs_consume=(fun _ -> ());
+    bs_close=(fun () -> ());
+  }
+
   let of_chan_ ~close ic : t =
     let i = ref 0 in
     let len = ref 0 in
@@ -76,7 +82,19 @@ module Byte_stream = struct
   let of_chan = of_chan_ ~close:close_in
   let of_chan_close_noerr = of_chan_ ~close:close_in_noerr
 
+  let rec iter f (self:t) : unit =
+    let s, i, len = self.bs_fill_buf () in
+    if len > 0 then (
+      f s i len;
+      self.bs_consume len;
+      (iter [@tailcall]) f self
+    )
+
+  let to_chan (oc:out_channel) (self:t) =
+    iter (fun s i len -> output oc s i len) self
+
   let of_bytes ?(i=0) ?len s : t =
+    (* invariant: !i+!len is constant *)
     let len =
       ref (
         match len with
@@ -87,9 +105,12 @@ module Byte_stream = struct
     let i = ref i in
     { bs_fill_buf=(fun () -> s, !i, !len);
       bs_close=(fun () -> ());
-      bs_consume=(fun n -> i := !i + n; len := !len - n);
+      bs_consume=(fun n -> assert (n<= !len); i := !i + n; len := !len - n);
     }
 
+  let of_string s : t =
+    of_bytes (Bytes.unsafe_of_string s)
+
   let with_file file f =
     let ic = open_in file in
     try
@@ -103,11 +124,11 @@ module Byte_stream = struct
   (* Read as much as possible into [buf]. *)
   let read_into_buf (self:t) (buf:Buf_.t) : int =
     let s, i, len = self.bs_fill_buf () in
     if len > 0 then (
       Buf_.add_bytes buf s i len;
       self.bs_consume len;
     );
     len
 
   let read_all ?(buf=Buf_.create()) (self:t) : string =
     let continue = ref true in
@@ -283,12 +304,6 @@ module Request = struct
       (Meth.to_string self.meth) self.host Headers.pp self.headers
       self.path self.body
 
-  let read_body_exact (bs:byte_stream) (n:int) : string =
-    let bytes = Bytes.make n ' ' in
-    Byte_stream.read_exactly_ bs bytes n
-      ~too_short:(fun () -> bad_reqf 400 "body is too short");
-    Bytes.unsafe_to_string bytes
-
   (* decode a "chunked" stream into a normal stream *)
   let read_stream_chunked_ ?(buf=Buf_.create()) (bs:byte_stream) : byte_stream =
     let read_next_chunk_len () : int =
@@ -313,6 +328,7 @@ module Request = struct
       if !offset >= !len then (
         if !chunk_size = 0 && !refill then (
           chunk_size := read_next_chunk_len();
+          _debug (fun k->k"read next chunk of size %d" !chunk_size);
         );
         offset := 0;
         len := 0;
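
The debug line added above logs each chunk size as it is decoded. For reference, a chunked body on the wire interleaves hexadecimal chunk lengths (each followed by CRLF) with the chunk data, and ends with a zero-length chunk; a small illustration, not part of the commit:

  (* this stream decodes to "Wikipedia": chunk sizes 0x4 and 0x5,
     then the terminating zero-sized chunk *)
  let _example_chunked_body : string =
    "4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n"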
@@ -331,29 +347,53 @@ module Request = struct
           bytes, !offset, !len
         );
       bs_consume=(fun n -> offset := !offset + n);
-      bs_close=(fun () -> Byte_stream.close bs);
+      bs_close=(fun () ->
+        (* close this overlay, do not close underlying stream *)
+        len := 0; refill:= false);
     }
 
-  let read_body_chunked ~tr_stream ~buf ~size:max_size (bs:byte_stream) : string =
-    _debug (fun k->k "read body with chunked encoding (max-size: %d)" max_size);
-    let is = tr_stream @@ read_stream_chunked_ ~buf bs in
-    let buf_res = Buf_.create() in (* store the accumulated chunks *)
-    (* TODO: extract this as a function [read_all_up_to ~max_size is]? *)
-    let rec read_chunks () =
-      let n = Byte_stream.read_into_buf is buf_res in
-      if n = 0 then (
-        Buf_.contents buf_res (* done *)
-      ) else (
-        (* is the body bigger than expected? *)
-        if max_size>0 && Buf_.size buf_res > max_size then (
-          bad_reqf 413
-            "body size was supposed to be %d, but at least %d bytes received"
-            max_size (Buf_.size buf_res)
-        );
-        read_chunks()
-      )
-    in
-    read_chunks()
+  let limit_body_size_ ~max_size (bs:byte_stream) : byte_stream =
+    _debug (fun k->k "limit size of body to max-size=%d" max_size);
+    let size = ref 0 in
+    { bs_fill_buf = bs.bs_fill_buf;
+      bs_close=bs.bs_close;
+      bs_consume = (fun n ->
+        size := !size + n;
+        if !size > max_size then (
+          (* read too much *)
+          bad_reqf 413
+            "body size was supposed to be %d, but at least %d bytes received"
+            max_size !size
+        );
+        bs.bs_consume n);
+    }
+
+  let limit_body_size ~max_size (req:byte_stream t) : byte_stream t =
+    { req with body=limit_body_size_ ~max_size req.body }
+
+  (* read exactly [size] bytes from the stream *)
+  let read_exactly ~size (bs:byte_stream) : byte_stream =
+    if size=0 then (
+      Byte_stream.empty
+    ) else (
+      let size = ref size in
+      { bs_fill_buf = (fun () ->
+          let buf, i, len = bs.bs_fill_buf () in
+          let len = min len !size in
+          if len = 0 && !size > 0 then (
+            bad_reqf 400 "body is too short"
+          );
+          buf, i, len
+        );
+        bs_close=(fun () ->
+          (* do not close underlying stream *)
+          size := 0);
+        bs_consume = (fun n ->
+          let n = min n !size in
+          size := !size - n;
+          bs.bs_consume n);
+      }
+    )
 
   (* parse request, but not body (yet) *)
   let parse_req_start ~buf (bs:byte_stream) : unit t option resp_result =
@@ -379,7 +419,7 @@ module Request = struct
   (* parse body, given the headers.
      @param tr_stream a transformation of the input stream. *)
-  let parse_body_ ~tr_stream ~buf (req:byte_stream t) : string t resp_result =
+  let parse_body_ ~tr_stream ~buf (req:byte_stream t) : byte_stream t resp_result =
     try
       let size =
         match List.assoc "Content-Length" req.headers |> int_of_string with
@@ -389,9 +429,12 @@ module Request = struct
       in
       let body =
         match get_header ~f:String.trim req "Transfer-Encoding" with
-        | None -> read_body_exact (tr_stream req.body) size
+        | None -> read_exactly ~size @@ tr_stream req.body
         | Some "chunked" ->
-          read_body_chunked ~tr_stream ~buf ~size req.body (* body sent by chunks *)
+          let bs =
+            read_stream_chunked_ ~buf @@ tr_stream req.body (* body sent by chunks *)
+          in
+          if size>0 then limit_body_size_ ~max_size:size bs else bs
         | Some s -> bad_reqf 500 "cannot handle transfer encoding: %s" s
       in
       Ok {req with body}
@@ -513,7 +556,7 @@ module Sem_ = struct
     Mutex.unlock t.mutex
 end
 
-type cb_path_handler = string Request.t -> Response.t
+type cb_path_handler = byte_stream Request.t -> Response.t
 
 type t = {
   addr: string;
@@ -525,7 +568,7 @@ type t = {
   mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list;
   mutable cb_decode_req:
     (unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) list;
-  mutable cb_encode_resp: (string Request.t -> Response.t -> Response.t option) list;
+  mutable cb_encode_resp: (unit Request.t -> Response.t -> Response.t option) list;
   mutable running: bool;
 }
@@ -536,9 +579,9 @@ let add_decode_request_cb self f = self.cb_decode_req <- f :: self.cb_decode_re
 let add_encode_response_cb self f = self.cb_encode_resp <- f :: self.cb_encode_resp
 let set_top_handler self f = self.handler <- f
 
-let add_path_handler
+let add_path_handler_
     ?(accept=fun _req -> Ok ())
-    ?meth self fmt f =
+    ?meth ~tr_req self fmt f =
   let ph req: cb_path_handler resp_result option =
     match meth with
     | Some m when m <> req.Request.meth -> None (* ignore *)
@@ -547,7 +590,7 @@ let add_path_handler
     | handler ->
       (* we have a handler, do we accept the request based on its headers? *)
       begin match accept req with
-        | Ok () -> Some (Ok handler)
+        | Ok () -> Some (Ok (fun req -> handler @@ tr_req req))
         | Error _ as e -> Some e
       end
     | exception _ ->
@@ -556,6 +599,12 @@ let add_path_handler
   in
   self.path_handlers <- ph :: self.path_handlers
 
+let add_path_handler ?accept ?meth self fmt f=
+  add_path_handler_ ?accept ?meth ~tr_req:Request.read_body_full self fmt f
+
+let add_path_handler_stream ?accept ?meth self fmt f=
+  add_path_handler_ ?accept ?meth ~tr_req:(fun x->x) self fmt f
+
 let create
     ?(masksigpipe=true)
     ?(max_connections=32)
@@ -603,7 +652,7 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
       let handler =
         match find_map (fun ph -> ph req) self.path_handlers with
         | Some f -> unwrap_resp_result f
-        | None -> self.handler
+        | None -> (fun req -> self.handler @@ Request.read_body_full req)
       in
       (* handle expectations *)
       begin match Request.get_header ~f:String.trim req "Expect" with
@@ -614,7 +663,7 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
        | None -> ()
      end;
      (* preprocess request's input stream *)
-      let req, tr_stream =
+      let req0, tr_stream =
        List.fold_left
          (fun (req,tr) cb ->
             match cb req with
@@ -624,13 +673,13 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit =
      in
      (* now actually read request's body *)
      let req =
-        Request.parse_body_ ~tr_stream ~buf {req with body=is}
+        Request.parse_body_ ~tr_stream ~buf {req0 with body=is}
        |> unwrap_resp_result
      in
      let resp = handler req in
      (* post-process response *)
      List.fold_left
-        (fun resp cb -> match cb req resp with None -> resp | Some r' -> r')
+        (fun resp cb -> match cb req0 resp with None -> resp | Some r' -> r')
        resp self.cb_encode_resp
    with
    | Bad_req (code,s) ->

@@ -107,6 +107,8 @@ module Byte_stream : sig
   val close : t -> unit
 
+  val empty : t
+
   val of_chan : in_channel -> t
   (** Make a buffered stream from the given channel. *)
@@ -117,6 +119,16 @@ module Byte_stream : sig
   (** A stream that just returns the slice of bytes starting from [i]
       and of length [len]. *)
 
+  val of_string : string -> t
+
+  val iter : (bytes -> int -> int -> unit) -> t -> unit
+  (** Iterate on the chunks of the stream
+      @since NEXT_RELEASE *)
+
+  val to_chan : out_channel -> t -> unit
+  (** Write the stream to the channel.
+      @since NEXT_RELEASE *)
+
   val with_file : string -> (t -> 'a) -> 'a
   (** Open a file with given name, and obtain an input stream
       on its content. When the function returns, the stream (and file) are closed. *)
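
A tiny usage sketch for the new helpers, not from the commit (it assumes the module is reached as Tiny_httpd.Byte_stream; of_string, to_chan and close are the functions declared above):

  (* copy a string-backed stream to stdout, then close it *)
  let () =
    let bs = Tiny_httpd.Byte_stream.of_string "hello, stream" in
    Tiny_httpd.Byte_stream.to_chan stdout bs;
    Tiny_httpd.Byte_stream.close bs
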
@@ -224,6 +236,12 @@ module Request : sig
   val body : 'b t -> 'b
   (** Request body, possibly empty. *)
 
+  val limit_body_size : max_size:int -> byte_stream t -> byte_stream t
+  (** Limit the body size to [max_size] bytes, or return
+      a [413] error.
+      @since 0.3
+  *)
+
   val read_body_full : byte_stream t -> string t
   (** Read the whole body into a string. Potentially blocking. *)
 end
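
Together these let a streaming handler bound a body before reading it. A hedged sketch, not from the commit (it assumes byte_stream is exposed at the top level of Tiny_httpd, and that Response.make_raw takes ~code and a string as in the http_of_dir example further below):

  (* cap the body at ~10 MiB, then read it; limit_body_size makes the
     stream fail with a 413 error once more than max_size bytes are consumed *)
  let handle_upload (req : Tiny_httpd.byte_stream Tiny_httpd.Request.t) =
    let req = Tiny_httpd.Request.limit_body_size ~max_size:(10 * 1024 * 1024) req in
    let body = Tiny_httpd.Request.body (Tiny_httpd.Request.read_body_full req) in
    Tiny_httpd.Response.make_raw ~code:200
      (Printf.sprintf "received %d bytes" (String.length body))
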
@@ -368,12 +386,12 @@
 *)
 
 val add_encode_response_cb:
-  t -> (string Request.t -> Response.t -> Response.t option) -> unit
+  t -> (unit Request.t -> Response.t -> Response.t option) -> unit
 (** Add a callback for every request/response pair.
     Similarly to {!add_encode_response_cb} the callback can return a new
     response, for example to compress it.
-    The callback is given the fully parsed query as well as the current
-    response.
+    The callback is given the query with only its headers,
+    as well as the current response.
 *)
 
 val set_top_handler : t -> (string Request.t -> Response.t) -> unit
@@ -411,6 +429,19 @@ val add_path_handler :
     filter uploads that are too large before the upload even starts.
 *)
 
+val add_path_handler_stream :
+  ?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
+  ?meth:Meth.t ->
+  t ->
+  ('a, Scanf.Scanning.in_channel,
+   'b, 'c -> byte_stream Request.t -> Response.t, 'a -> 'd, 'd) format6 ->
+  'c -> unit
+(** Similar to {!add_path_handler}, but where the body of the request
+    is a stream of bytes that has not been read yet.
+    This is useful when one wants to stream the body directly into a parser,
+    json decoder (such as [Jsonm]) or into a file.
+    @since 0.3 *)
+
 val stop : t -> unit
 (** Ask the server to stop. This might not have an immediate effect
     as {!run} might currently be waiting on IO. *)

@@ -120,17 +120,14 @@ let serve ~config (dir:string) : _ result =
       (fun _ _ -> S.Response.make_raw ~code:405 "delete not allowed");
   );
   if config.upload then (
-    S.add_path_handler server ~meth:`PUT "/%s"
+    S.add_path_handler_stream server ~meth:`PUT "/%s"
       ~accept:(fun req ->
          match S.Request.get_header_int req "Content-Length" with
          | Some n when n > config.max_upload_size ->
            Error (403, "max upload size is " ^ string_of_int config.max_upload_size)
          | Some _ when contains_dot_dot req.S.Request.path ->
            Error (403, "invalid path (contains '..')")
-         | Some _ -> Ok ()
-         | None ->
-           Error (411, "must know size before hand: max upload size is " ^
-                  string_of_int config.max_upload_size)
+         | _ -> Ok ()
       )
       (fun path req ->
          let fpath = dir // path in
@@ -140,7 +137,8 @@ let serve ~config (dir:string) : _ result =
              S.Response.fail_raise ~code:403 "cannot upload to %S: %s"
                path (Printexc.to_string e)
          in
-         output_string oc req.S.Request.body;
+         let req = S.Request.limit_body_size ~max_size:config.max_upload_size req in
+         S.Byte_stream.to_chan oc req.S.Request.body;
          flush oc;
          close_out oc;
          S.Response.make_raw ~code:201 "upload successful"