From d5f9eacc81331961a813acdebc8a475cda6e582e Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 15 Mar 2022 22:59:34 -0400 Subject: [PATCH] split code into more modules --- src/Tiny_httpd.ml | 1175 +--------------------------- src/Tiny_httpd.mli | 632 +-------------- src/Tiny_httpd_buf.ml | 35 + src/Tiny_httpd_buf.mli | 27 + src/Tiny_httpd_dir.ml | 10 +- src/Tiny_httpd_dir.mli | 6 +- src/Tiny_httpd_server.ml | 893 +++++++++++++++++++++ src/Tiny_httpd_server.mli | 567 ++++++++++++++ src/Tiny_httpd_stream.ml | 306 ++++++++ src/Tiny_httpd_stream.mli | 113 +++ src/camlzip/Tiny_httpd_camlzip.ml | 219 +++--- src/camlzip/Tiny_httpd_camlzip.mli | 4 +- 12 files changed, 2074 insertions(+), 1913 deletions(-) create mode 100644 src/Tiny_httpd_buf.ml create mode 100644 src/Tiny_httpd_buf.mli create mode 100644 src/Tiny_httpd_server.ml create mode 100644 src/Tiny_httpd_server.mli create mode 100644 src/Tiny_httpd_stream.ml create mode 100644 src/Tiny_httpd_stream.mli diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml index 94000ddf..0d61fa95 100644 --- a/src/Tiny_httpd.ml +++ b/src/Tiny_httpd.ml @@ -1,1175 +1,10 @@ -type byte_stream = { - bs_fill_buf: unit -> (bytes * int * int); - bs_consume: int -> unit; - bs_close: unit -> unit; -} -(** A buffer input stream, with a view into the current buffer (or refill if empty), - and a function to consume [n] bytes *) -let _debug_on = ref ( - match String.trim @@ Sys.getenv "HTTP_DBG" with - | "" -> false | _ -> true | exception _ -> false -) -let _enable_debug b = _debug_on := b -let _debug k = - if !_debug_on then ( - k (fun fmt-> - Printf.fprintf stdout "[http.thread %d]: " Thread.(id @@ self()); - Printf.kfprintf (fun oc -> Printf.fprintf oc "\n%!") stdout fmt) - ) +module Buf = Tiny_httpd_buf -module Buf_ = struct - type t = { - mutable bytes: bytes; - mutable i: int; - } +module Byte_stream = Tiny_httpd_stream - let create ?(size=4_096) () : t = - { bytes=Bytes.make size ' '; i=0 } +include Tiny_httpd_server - let size self 
= self.i - let bytes_slice self = self.bytes - let clear self : unit = - if Bytes.length self.bytes > 4_096 * 1_024 then ( - self.bytes <- Bytes.make 4096 ' '; (* free big buffer *) - ); - self.i <- 0 +module Util = Tiny_httpd_util - let resize self new_size : unit = - let new_buf = Bytes.make new_size ' ' in - Bytes.blit self.bytes 0 new_buf 0 self.i; - self.bytes <- new_buf - - let add_bytes (self:t) s i len : unit = - if self.i + len >= Bytes.length self.bytes then ( - resize self (self.i + self.i / 2 + len + 10); - ); - Bytes.blit s i self.bytes self.i len; - self.i <- self.i + len - - let contents (self:t) : string = Bytes.sub_string self.bytes 0 self.i - - let contents_and_clear (self:t) : string = - let x = contents self in - clear self; - x -end - -module Byte_stream = struct - type t = byte_stream - - let close self = self.bs_close() - - let empty = { - bs_fill_buf=(fun () -> Bytes.empty, 0, 0); - bs_consume=(fun _ -> ()); - bs_close=(fun () -> ()); - } - - let of_chan_ ?(buf_size=16 * 1024) ~close ic : t = - let i = ref 0 in - let len = ref 0 in - let buf = Bytes.make buf_size ' ' in - { bs_fill_buf=(fun () -> - if !i >= !len then ( - i := 0; - len := input ic buf 0 (Bytes.length buf); - ); - buf, !i,!len - !i); - bs_consume=(fun n -> i := !i + n); - bs_close=(fun () -> close ic) - } - - let of_chan = of_chan_ ~close:close_in - let of_chan_close_noerr = of_chan_ ~close:close_in_noerr - - let rec iter f (self:t) : unit = - let s, i, len = self.bs_fill_buf () in - if len=0 then ( - self.bs_close(); - ) else ( - f s i len; - self.bs_consume len; - (iter [@tailcall]) f self - ) - - let to_chan (oc:out_channel) (self:t) = - iter (fun s i len -> output oc s i len) self - - let of_bytes ?(i=0) ?len s : t = - (* invariant: !i+!len is constant *) - let len = - ref ( - match len with - | Some n -> - if n > Bytes.length s - i then invalid_arg "Byte_stream.of_bytes"; - n - | None -> Bytes.length s - i - ) - in - let i = ref i in - { bs_fill_buf=(fun () -> s, !i, 
!len); - bs_close=(fun () -> len := 0); - bs_consume=(fun n -> assert (n>=0 && n<= !len); i := !i + n; len := !len - n); - } - - let of_string s : t = - of_bytes (Bytes.unsafe_of_string s) - - let with_file ?buf_size file f = - let ic = open_in file in - try - let x = f (of_chan ?buf_size ic) in - close_in ic; - x - with e -> - close_in_noerr ic; - raise e - - let read_all ?(buf=Buf_.create()) (self:t) : string = - let continue = ref true in - while !continue do - let s, i, len = self.bs_fill_buf () in - _debug (fun k->k "read-all: got i=%d, len=%d, bufsize %d" i len (Buf_.size buf)); - if len > 0 then ( - Buf_.add_bytes buf s i len; - self.bs_consume len; - ); - assert (len >= 0); - if len = 0 then ( - continue := false - ) - done; - Buf_.contents_and_clear buf - - (* put [n] bytes from the input into bytes *) - let read_exactly_ ~too_short (self:t) (bytes:bytes) (n:int) : unit = - assert (Bytes.length bytes >= n); - let offset = ref 0 in - while !offset < n do - let s, i, len = self.bs_fill_buf () in - let n_read = min len (n- !offset) in - Bytes.blit s i bytes !offset n_read; - offset := !offset + n_read; - self.bs_consume n_read; - if n_read=0 then too_short(); - done - - (* read a line into the buffer, after clearing it. *) - let read_line_into (self:t) ~buf : unit = - Buf_.clear buf; - let continue = ref true in - while !continue do - let s, i, len = self.bs_fill_buf () in - if len=0 then ( - continue := false; - if Buf_.size buf = 0 then raise End_of_file; - ); - let j = ref i in - while !j < i+len && Bytes.get s !j <> '\n' do - incr j - done; - if !j-i < len then ( - assert (Bytes.get s !j = '\n'); - Buf_.add_bytes buf s i (!j-i); (* without \n *) - self.bs_consume (!j-i+1); (* remove \n *) - continue := false - ) else ( - Buf_.add_bytes buf s i len; - self.bs_consume len; - ) - done - - (* new stream with maximum size [max_size]. 
- @param close_rec if true, closing this will also close the input stream - @param too_big called with read size if the max size is reached *) - let limit_size_to ~close_rec ~max_size ~too_big (self:t) : t = - let size = ref 0 in - let continue = ref true in - { bs_fill_buf = - (fun () -> - if !continue then self.bs_fill_buf() else Bytes.empty, 0, 0); - bs_close=(fun () -> - if close_rec then self.bs_close ()); - bs_consume = (fun n -> - size := !size + n; - if !size > max_size then ( - continue := false; - too_big !size - ) else ( - self.bs_consume n - )); - } - - (* read exactly [size] bytes from the stream *) - let read_exactly ~close_rec ~size ~too_short (self:t) : t = - if size=0 then ( - empty - ) else ( - let size = ref size in - { bs_fill_buf = (fun () -> - (* must not block on [self] if we're done *) - if !size = 0 then Bytes.empty, 0, 0 - else ( - let buf, i, len = self.bs_fill_buf () in - let len = min len !size in - if len = 0 && !size > 0 then ( - too_short !size; - ); - buf, i, len - ) - ); - bs_close=(fun () -> - (* close underlying stream if [close_rec] *) - if close_rec then self.bs_close(); - size := 0); - bs_consume = (fun n -> - let n = min n !size in - size := !size - n; - self.bs_consume n); - } - ) - - let read_line ?(buf=Buf_.create()) self : string = - read_line_into self ~buf; - Buf_.contents buf -end - -exception Bad_req of int * string -let bad_reqf c fmt = Printf.ksprintf (fun s ->raise (Bad_req (c,s))) fmt - -module Response_code = struct - type t = int - - let ok = 200 - let not_found = 404 - let descr = function - | 100 -> "Continue" - | 200 -> "OK" - | 201 -> "Created" - | 202 -> "Accepted" - | 204 -> "No content" - | 300 -> "Multiple choices" - | 301 -> "Moved permanently" - | 302 -> "Found" - | 304 -> "Not Modified" - | 400 -> "Bad request" - | 403 -> "Forbidden" - | 404 -> "Not found" - | 405 -> "Method not allowed" - | 408 -> "Request timeout" - | 409 -> "Conflict" - | 410 -> "Gone" - | 411 -> "Length required" - | 413 -> 
"Payload too large" - | 417 -> "Expectation failed" - | 500 -> "Internal server error" - | 501 -> "Not implemented" - | 503 -> "Service unavailable" - | n -> "Unknown response code " ^ string_of_int n (* TODO *) -end - -type 'a resp_result = ('a, Response_code.t * string) result -let unwrap_resp_result = function - | Ok x -> x - | Error (c,s) -> raise (Bad_req (c,s)) - -module Meth = struct - type t = [ - | `GET - | `PUT - | `POST - | `HEAD - | `DELETE - ] - - let to_string = function - | `GET -> "GET" - | `PUT -> "PUT" - | `HEAD -> "HEAD" - | `POST -> "POST" - | `DELETE -> "DELETE" - let pp out s = Format.pp_print_string out (to_string s) - - let of_string = function - | "GET" -> `GET - | "PUT" -> `PUT - | "POST" -> `POST - | "HEAD" -> `HEAD - | "DELETE" -> `DELETE - | s -> bad_reqf 400 "unknown method %S" s -end - -module Headers = struct - type t = (string * string) list - let empty = [] - let contains name headers = - let name' = String.lowercase_ascii name in - List.exists (fun (n, _) -> name'=n) headers - let get_exn ?(f=fun x->x) x h = - let x' = String.lowercase_ascii x in - List.assoc x' h |> f - let get ?(f=fun x -> x) x h = - try Some (get_exn ~f x h) with Not_found -> None - let remove x h = - let x' = String.lowercase_ascii x in - List.filter (fun (k,_) -> k<>x') h - let set x y h = - let x' = String.lowercase_ascii x in - (x',y) :: List.filter (fun (k,_) -> k<>x') h - let pp out l = - let pp_pair out (k,v) = Format.fprintf out "@[%s: %s@]" k v in - Format.fprintf out "@[%a@]" (Format.pp_print_list pp_pair) l - - (* token = 1*tchar - tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / "^" / "_" - / "`" / "|" / "~" / DIGIT / ALPHA ; any VCHAR, except delimiters - Reference: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 *) - let is_tchar = function - | '0' .. '9' | 'a' .. 'z' | 'A' .. 'Z' - | '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' 
| '^' - | '_' | '`' | '|' | '~' -> true - | _ -> false - - let for_all pred s = - try String.iter (fun c->if not (pred c) then raise Exit) s; true - with Exit -> false - - let parse_ ~buf (bs:byte_stream) : t = - let rec loop acc = - let line = Byte_stream.read_line ~buf bs in - _debug (fun k->k "parsed header line %S" line); - if line = "\r" then ( - acc - ) else ( - let k,v = - try - let i = String.index line ':' in - let k = String.sub line 0 i in - if not (for_all is_tchar k) then ( - invalid_arg (Printf.sprintf "Invalid header key: %S" k)); - let v = String.sub line (i+1) (String.length line-i-1) |> String.trim in - k,v - with _ -> bad_reqf 400 "invalid header line: %S" line - in - loop ((String.lowercase_ascii k,v)::acc) - ) - in - loop [] -end - -module Request = struct - type 'body t = { - meth: Meth.t; - host: string; - headers: Headers.t; - http_version: int*int; - path: string; - path_components: string list; - query: (string*string) list; - body: 'body; - start_time: float; - } - - let headers self = self.headers - let host self = self.host - let meth self = self.meth - let path self = self.path - let body self = self.body - let start_time self = self.start_time - - let query self = self.query - let get_header ?f self h = Headers.get ?f h self.headers - let get_header_int self h = match get_header self h with - | Some x -> (try Some (int_of_string x) with _ -> None) - | None -> None - let set_header k v self = {self with headers=Headers.set k v self.headers} - let update_headers f self = {self with headers=f self.headers} - let set_body b self = {self with body=b} - - (** Should we close the connection after this request? 
*) - let close_after_req (self:_ t) : bool = - match self.http_version with - | 1, 1 -> get_header self "connection" =Some"close" - | 1, 0 -> not (get_header self "connection"=Some"keep-alive") - | _ -> false - - let pp_comp_ out comp = - Format.fprintf out "[%s]" - (String.concat ";" @@ List.map (Printf.sprintf "%S") comp) - let pp_query out q = - Format.fprintf out "[%s]" - (String.concat ";" @@ - List.map (fun (a,b) -> Printf.sprintf "%S,%S" a b) q) - let pp_ out self : unit = - Format.fprintf out "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ \ - path=%S;@ body=?;@ path_components=%a;@ query=%a@]}" - (Meth.to_string self.meth) self.host Headers.pp self.headers self.path - pp_comp_ self.path_components pp_query self.query - let pp out self : unit = - Format.fprintf out "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ path=%S;@ \ - body=%S;@ path_components=%a;@ query=%a@]}" - (Meth.to_string self.meth) self.host Headers.pp self.headers - self.path self.body pp_comp_ self.path_components pp_query self.query - - (* decode a "chunked" stream into a normal stream *) - let read_stream_chunked_ ?(buf=Buf_.create()) (bs:byte_stream) : byte_stream = - _debug (fun k->k "body: start reading chunked stream..."); - let first = ref true in - let read_next_chunk_len () : int = - if !first then ( - first := false - ) else ( - let line = Byte_stream.read_line ~buf bs in - if String.trim line <> "" then bad_reqf 400 "expected crlf between chunks"; - ); - let line = Byte_stream.read_line ~buf bs in - (* parse chunk length, ignore extensions *) - let chunk_size = ( - if String.trim line = "" then 0 - else - try Scanf.sscanf line "%x %s@\r" (fun n _ext -> n) - with _ -> bad_reqf 400 "cannot read chunk size from line %S" line - ) in - chunk_size - in - let refill = ref true in - let bytes = Bytes.make (16 * 4096) ' ' in (* internal buffer, 16kb *) - let offset = ref 0 in - let len = ref 0 in - let chunk_size = ref 0 in - { bs_fill_buf= - (fun () -> - (* do we need to refill? 
*) - if !offset >= !len then ( - if !chunk_size = 0 && !refill then ( - chunk_size := read_next_chunk_len(); - (* _debug (fun k->k"read next chunk of size %d" !chunk_size); *) - ); - offset := 0; - len := 0; - if !chunk_size > 0 then ( - (* read the whole chunk, or [Bytes.length bytes] of it *) - let to_read = min !chunk_size (Bytes.length bytes) in - Byte_stream.read_exactly_ - ~too_short:(fun () -> bad_reqf 400 "chunk is too short") - bs bytes to_read; - len := to_read; - chunk_size := !chunk_size - to_read; - ) else ( - refill := false; (* stream is finished *) - ) - ); - bytes, !offset, !len - ); - bs_consume=(fun n -> offset := !offset + n); - bs_close=(fun () -> - (* close this overlay, do not close underlying stream *) - len := 0; refill:= false); - } - - let limit_body_size_ ~max_size (bs:byte_stream) : byte_stream = - _debug (fun k->k "limit size of body to max-size=%d" max_size); - Byte_stream.limit_size_to ~max_size ~close_rec:false bs - ~too_big:(fun size -> - (* read too much *) - bad_reqf 413 - "body size was supposed to be %d, but at least %d bytes received" - max_size size - ) - - let limit_body_size ~max_size (req:byte_stream t) : byte_stream t = - { req with body=limit_body_size_ ~max_size req.body } - - (* read exactly [size] bytes from the stream *) - let read_exactly ~size (bs:byte_stream) : byte_stream = - _debug (fun k->k "body: must read exactly %d bytes" size); - Byte_stream.read_exactly bs ~close_rec:false - ~size ~too_short:(fun size -> - bad_reqf 400 "body is too short by %d bytes" size - ) - - (* parse request, but not body (yet) *) - let parse_req_start ~get_time_s ~buf (bs:byte_stream) : unit t option resp_result = - try - let line = Byte_stream.read_line ~buf bs in - let start_time = get_time_s() in - let meth, path, version = - try - let meth, path, version = Scanf.sscanf line "%s %s HTTP/1.%d\r" (fun x y z->x,y,z) in - if version != 0 && version != 1 then raise Exit; - meth, path, version - with _ -> - _debug (fun k->k "invalid 
request line: `%s`" line); - raise (Bad_req (400, "Invalid request line")) - in - let meth = Meth.of_string meth in - _debug (fun k->k "got meth: %s, path %S" (Meth.to_string meth) path); - let headers = Headers.parse_ ~buf bs in - let host = - match Headers.get "Host" headers with - | None -> bad_reqf 400 "No 'Host' header in request" - | Some h -> h - in - let path_components, query = Tiny_httpd_util.split_query path in - let path_components = Tiny_httpd_util.split_on_slash path_components in - let query = - match Tiny_httpd_util.(parse_query query) with - | Ok l -> l - | Error e -> bad_reqf 400 "invalid query: %s" e - in - let req = { - meth; query; host; path; path_components; - headers; http_version=(1, version); body=(); start_time; - } in - Ok (Some req) - with - | End_of_file | Sys_error _ -> Ok None - | Bad_req (c,s) -> Error (c,s) - | e -> - Error (400, Printexc.to_string e) - - (* parse body, given the headers. - @param tr_stream a transformation of the input stream. *) - let parse_body_ ~tr_stream ~buf (req:byte_stream t) : byte_stream t resp_result = - try - let size = - match Headers.get_exn "Content-Length" req.headers |> int_of_string with - | n -> n (* body of fixed size *) - | exception Not_found -> 0 - | exception _ -> bad_reqf 400 "invalid content-length" - in - let body = - match get_header ~f:String.trim req "Transfer-Encoding" with - | None -> read_exactly ~size @@ tr_stream req.body - | Some "chunked" -> - let bs = - read_stream_chunked_ ~buf @@ tr_stream req.body (* body sent by chunks *) - in - if size>0 then limit_body_size_ ~max_size:size bs else bs - | Some s -> bad_reqf 500 "cannot handle transfer encoding: %s" s - in - Ok {req with body} - with - | End_of_file -> Error (400, "unexpected end of file") - | Bad_req (c,s) -> Error (c,s) - | e -> - Error (400, Printexc.to_string e) - - let read_body_full ?buf_size (self:byte_stream t) : string t = - try - let buf = Buf_.create ?size:buf_size () in - let body = Byte_stream.read_all ~buf 
self.body in - { self with body } - with - | Bad_req _ as e -> raise e - | e -> bad_reqf 500 "failed to read body: %s" (Printexc.to_string e) - - module Internal_ = struct - let parse_req_start ?(buf=Buf_.create()) ~get_time_s bs = - parse_req_start ~get_time_s ~buf bs |> unwrap_resp_result - - let parse_body ?(buf=Buf_.create()) req bs : _ t = - parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=bs} |> unwrap_resp_result - end -end - -(*$R - let q = "GET hello HTTP/1.1\r\nHost: coucou\r\nContent-Length: 11\r\n\r\nsalutationsSOMEJUNK" in - let str = Byte_stream.of_string q in - let r = Request.Internal_.parse_req_start ~get_time_s:(fun _ -> 0.) str in - match r with - | None -> assert_failure "should parse" - | Some req -> - assert_equal (Some "coucou") (Headers.get "Host" req.Request.headers); - assert_equal (Some "coucou") (Headers.get "host" req.Request.headers); - assert_equal (Some "11") (Headers.get "content-length" req.Request.headers); - assert_equal "hello" req.Request.path; - let req = Request.Internal_.parse_body req str |> Request.read_body_full in - assert_equal ~printer:(fun s->s) "salutations" req.Request.body; - () -*) - -module Response = struct - type body = [`String of string | `Stream of byte_stream | `Void] - type t = { - code: Response_code.t; - headers: Headers.t; - body: body; - } - - let set_body body self = {self with body} - let set_headers headers self = {self with headers} - let update_headers f self = {self with headers=f self.headers} - let set_header k v self = {self with headers = Headers.set k v self.headers} - let set_code code self = {self with code} - - let make_raw ?(headers=[]) ~code body : t = - (* add content length to response *) - let headers = - Headers.set "Content-Length" (string_of_int (String.length body)) headers - in - { code; headers; body=`String body; } - - let make_raw_stream ?(headers=[]) ~code body : t = - (* add content length to response *) - let headers = Headers.set "Transfer-Encoding" "chunked" headers 
in - { code; headers; body=`Stream body; } - - let make_void ?(headers=[]) ~code () : t = - { code; headers; body=`Void; } - - let make_string ?headers r = match r with - | Ok body -> make_raw ?headers ~code:200 body - | Error (code,msg) -> make_raw ?headers ~code msg - - let make_stream ?headers r = match r with - | Ok body -> make_raw_stream ?headers ~code:200 body - | Error (code,msg) -> make_raw ?headers ~code msg - - let make ?headers r : t = match r with - | Ok (`String body) -> make_raw ?headers ~code:200 body - | Ok (`Stream body) -> make_raw_stream ?headers ~code:200 body - | Ok `Void -> make_void ?headers ~code:200 () - | Error (code,msg) -> make_raw ?headers ~code msg - - let fail ?headers ~code fmt = - Printf.ksprintf (fun msg -> make_raw ?headers ~code msg) fmt - let fail_raise ~code fmt = - Printf.ksprintf (fun msg -> raise (Bad_req (code,msg))) fmt - - let pp out self : unit = - let pp_body out = function - | `String s -> Format.fprintf out "%S" s - | `Stream _ -> Format.pp_print_string out "" - | `Void -> () - in - Format.fprintf out "{@[code=%d;@ headers=[@[%a@]];@ body=%a@]}" - self.code Headers.pp self.headers pp_body self.body - - (* print a stream as a series of chunks *) - let output_stream_chunked_ (oc:out_channel) (str:byte_stream) : unit = - let continue = ref true in - while !continue do - (* next chunk *) - let s, i, len = str.bs_fill_buf () in - Printf.fprintf oc "%x\r\n" len; - output oc s i len; - str.bs_consume len; - if len = 0 then ( - continue := false; - ); - output_string oc "\r\n"; - done; - () - - let output_ (oc:out_channel) (self:t) : unit = - Printf.fprintf oc "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); - let body, is_chunked = match self.body with - | `String s when String.length s > 1024 * 500 -> - (* chunk-encode large bodies *) - `Stream (Byte_stream.of_string s), true - | `String _ as b -> b, false - | `Stream _ as b -> b, true - | `Void as b -> b, false - in - let headers = - if is_chunked then ( - 
self.headers - |> Headers.set "transfer-encoding" "chunked" - |> Headers.remove "content-length" - ) else self.headers - in - let self = {self with headers; body} in - _debug (fun k->k "output response: %s" - (Format.asprintf "%a" pp {self with body=`String "<…>"})); - List.iter (fun (k,v) -> Printf.fprintf oc "%s: %s\r\n" k v) headers; - output_string oc "\r\n"; - begin match body with - | `String "" | `Void -> () - | `String s -> output_string oc s; - | `Stream str -> output_stream_chunked_ oc str; - end; - flush oc -end - -(* semaphore, for limiting concurrency. *) -module Sem_ = struct - type t = { - mutable n : int; - max : int; - mutex : Mutex.t; - cond : Condition.t; - } - - let create n = - if n <= 0 then invalid_arg "Semaphore.create"; - { n; max=n; mutex=Mutex.create(); cond=Condition.create(); } - - let acquire m t = - Mutex.lock t.mutex; - while t.n < m do - Condition.wait t.cond t.mutex; - done; - assert (t.n >= m); - t.n <- t.n - m; - Condition.broadcast t.cond; - Mutex.unlock t.mutex - - let release m t = - Mutex.lock t.mutex; - t.n <- t.n + m; - Condition.broadcast t.cond; - Mutex.unlock t.mutex - - let num_acquired t = t.max - t.n -end - -module Route = struct - type path = string list (* split on '/' *) - - type (_, _) comp = - | Exact : string -> ('a, 'a) comp - | Int : (int -> 'a, 'a) comp - | String : (string -> 'a, 'a) comp - | String_urlencoded : (string -> 'a, 'a) comp - - type (_, _) t = - | Fire : ('b, 'b) t - | Rest : { - url_encoded: bool; - } -> (string -> 'b, 'b) t - | Compose: ('a, 'b) comp * ('b, 'c) t -> ('a, 'c) t - - let return = Fire - let rest_of_path = Rest {url_encoded=false} - let rest_of_path_urlencoded = Rest {url_encoded=true} - let (@/) a b = Compose (a,b) - let string = String - let string_urlencoded = String_urlencoded - let int = Int - let exact (s:string) = Exact s - let exact_path (s:string) tail = - let rec fn = function - | [] -> tail - | ""::ls -> fn ls - | s::ls -> exact s @/ fn ls - in - fn (String.split_on_char 
'/' s) - let rec eval : - type a b. path -> (a,b) t -> a -> b option = - fun path route f -> - begin match path, route with - | [], Fire -> Some f - | _, Fire -> None - | _, Rest {url_encoded} -> - let whole_path = String.concat "/" path in - begin match - if url_encoded - then match Tiny_httpd_util.percent_decode whole_path with - | Some s -> s - | None -> raise_notrace Exit - else whole_path - with - | whole_path -> - Some (f whole_path) - | exception Exit -> None - end - | (c1 :: path'), Compose (comp, route') -> - begin match comp with - | Int -> - begin match int_of_string c1 with - | i -> eval path' route' (f i) - | exception _ -> None - end - | String -> - eval path' route' (f c1) - | String_urlencoded -> - begin match Tiny_httpd_util.percent_decode c1 with - | None -> None - | Some s -> eval path' route' (f s) - end - | Exact s -> - if s = c1 then eval path' route' f else None - end - | [], Compose (String, Fire) -> Some (f "") (* trailing *) - | [], Compose (String_urlencoded, Fire) -> Some (f "") (* trailing *) - | [], Compose _ -> None - end - - let bpf = Printf.bprintf - let rec pp_ - : type a b. Buffer.t -> (a,b) t -> unit - = fun out -> function - | Fire -> bpf out "/" - | Rest {url_encoded} -> - bpf out "" (if url_encoded then "_urlencoded" else "") - | Compose (Exact s, tl) -> bpf out "%s/%a" s pp_ tl - | Compose (Int, tl) -> bpf out "/%a" pp_ tl - | Compose (String, tl) -> bpf out "/%a" pp_ tl - | Compose (String_urlencoded, tl) -> bpf out "/%a" pp_ tl - - let to_string x = - let b = Buffer.create 16 in - pp_ b x; - Buffer.contents b - let pp out x = Format.pp_print_string out (to_string x) -end - -module Middleware = struct - type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit - type t = handler -> handler - - (** Apply a list of middlewares to [h] *) - let apply_l (l:t list) (h:handler) : handler = - List.fold_right (fun m h -> m h) l h - - let[@inline] nil : t = fun h -> h -end - -(* a request handler. 
handles a single request. *) -type cb_path_handler = - out_channel -> - Middleware.handler - -module type SERVER_SENT_GENERATOR = sig - val set_headers : Headers.t -> unit - val send_event : - ?event:string -> - ?id:string -> - ?retry:string -> - data:string -> - unit -> unit - val close : unit -> unit -end -type server_sent_generator = (module SERVER_SENT_GENERATOR) - -type t = { - addr: string; - - port: int; - - sock: Unix.file_descr option; - - timeout: float; - - sem_max_connections: Sem_.t; - (* semaphore to restrict the number of active concurrent connections *) - - new_thread: (unit -> unit) -> unit; - (* a function to run the given callback in a separate thread (or thread pool) *) - - masksigpipe: bool; - - buf_size: int; - - get_time_s : unit -> float; - - mutable handler: (string Request.t -> Response.t); - (* toplevel handler, if any *) - - mutable middlewares : (int * Middleware.t) list; - (** Global middlewares *) - - mutable middlewares_sorted : (int * Middleware.t) list lazy_t; - (* sorted version of {!middlewares} *) - - mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list; - (* path handlers *) - - mutable running: bool; - (* true while the server is running. no need to protect with a mutex, - writes should be atomic enough. 
*) -} - -let addr self = self.addr -let port self = self.port - -let active_connections self = Sem_.num_acquired self.sem_max_connections - 1 - -let add_middleware ~stage self m = - let stage = match stage with - | `Encoding -> 0 - | `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage" - | `Stage n -> n - in - self.middlewares <- (stage,m) :: self.middlewares; - self.middlewares_sorted <- lazy ( - List.stable_sort (fun (s1,_) (s2,_) -> compare s1 s2) self.middlewares - ) - -let add_decode_request_cb self f = - (* turn it into a middleware *) - let m h req ~resp = - (* see if [f] modifies the stream *) - let req0 = {req with Request.body=()} in - match f req0 with - | None -> h req ~resp (* pass through *) - | Some (req1, tr_stream) -> - let req = {req1 with Request.body=tr_stream req.Request.body} in - h req ~resp - in - add_middleware self ~stage:`Encoding m - -let add_encode_response_cb self f = - let m h req ~resp = - h req ~resp:(fun r -> - let req0 = {req with Request.body=()} in - (* now transform [r] if we want to *) - match f req0 r with - | None -> resp r - | Some r' -> resp r') - in - add_middleware self ~stage:`Encoding m - -let set_top_handler self f = self.handler <- f - -(* route the given handler. - @param tr_req wraps the actual concrete function returned by the route - and makes it into a handler. *) -let add_route_handler_ - ?(accept=fun _req -> Ok ()) ?(middlewares=[]) - ?meth ~tr_req self (route:_ Route.t) f = - let ph req : cb_path_handler resp_result option = - match meth with - | Some m when m <> req.Request.meth -> None (* ignore *) - | _ -> - begin match Route.eval req.Request.path_components route f with - | Some handler -> - (* we have a handler, do we accept the request based on its headers? 
*) - begin match accept req with - | Ok () -> - Some (Ok (fun oc -> - Middleware.apply_l middlewares @@ - fun req ~resp -> tr_req oc req ~resp handler)) - | Error _ as e -> Some e - end - | None -> - None (* path didn't match *) - end - in - self.path_handlers <- ph :: self.path_handlers - -let add_route_handler (type a) ?accept ?middlewares ?meth - self (route:(a,_) Route.t) (f:_) : unit = - let tr_req _oc req ~resp f = resp (f (Request.read_body_full ~buf_size:self.buf_size req)) in - add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f - -let add_route_handler_stream ?accept ?middlewares ?meth self route f = - let tr_req _oc req ~resp f = resp (f req) in - add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f - -let[@inline] _opt_iter ~f o = match o with - | None -> () - | Some x -> f x - -let add_route_server_sent_handler ?accept self route f = - let tr_req oc req ~resp f = - let req = Request.read_body_full ~buf_size:self.buf_size req in - let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in - - (* send response once *) - let resp_sent = ref false in - let send_response_idempotent_ () = - if not !resp_sent then ( - resp_sent := true; - (* send 200 response now *) - let initial_resp = Response.make_void ~headers:!headers ~code:200 () in - resp initial_resp; - ) - in - - let send_event ?event ?id ?retry ~data () : unit = - send_response_idempotent_(); - _opt_iter event ~f:(fun e -> Printf.fprintf oc "event: %s\n" e); - _opt_iter id ~f:(fun e -> Printf.fprintf oc "id: %s\n" e); - _opt_iter retry ~f:(fun e -> Printf.fprintf oc "retry: %s\n" e); - let l = String.split_on_char '\n' data in - List.iter (fun s -> Printf.fprintf oc "data: %s\n" s) l; - output_string oc "\n"; (* finish group *) - flush oc - in - let module SSG = struct - let set_headers h = - if not !resp_sent then ( - headers := List.rev_append h !headers; - send_response_idempotent_() - ) - let send_event = send_event - let close () = raise Exit - 
end in - try f req (module SSG : SERVER_SENT_GENERATOR); - with Exit -> close_out oc - in - add_route_handler_ self ?accept ~meth:`GET route ~tr_req f - -let create - ?(masksigpipe=true) - ?(max_connections=32) - ?(timeout=0.0) - ?(buf_size=16 * 1_024) - ?(get_time_s=Unix.gettimeofday) - ?(new_thread=(fun f -> ignore (Thread.create f () : Thread.t))) - ?(addr="127.0.0.1") ?(port=8080) ?sock - ?(middlewares=[]) - () : t = - let handler _req = Response.fail ~code:404 "no top handler" in - let max_connections = max 4 max_connections in - let self = { - new_thread; addr; port; sock; masksigpipe; handler; buf_size; - running= true; sem_max_connections=Sem_.create max_connections; - path_handlers=[]; timeout; get_time_s; - middlewares=[]; middlewares_sorted=lazy []; - } in - List.iter (fun (stage,m) -> add_middleware self ~stage m) middlewares; - self - -let stop s = s.running <- false - -let find_map f l = - let rec aux f = function - | [] -> None - | x::l' -> - match f x with - | Some _ as res -> res - | None -> aux f l' - in aux f l - -let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = - Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); - Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); - let ic = Unix.in_channel_of_descr client_sock in - let oc = Unix.out_channel_of_descr client_sock in - let buf = Buf_.create ~size:self.buf_size () in - let is = Byte_stream.of_chan ~buf_size:self.buf_size ic in - let continue = ref true in - while !continue && self.running do - _debug (fun k->k "read next request"); - match Request.parse_req_start ~get_time_s:self.get_time_s ~buf is with - | Ok None -> - continue := false (* client is done *) - - | Error (c,s) -> - (* connection error, close *) - let res = Response.make_raw ~code:c s in - begin - try Response.output_ oc res - with Sys_error _ -> () - end; - continue := false - - | Ok (Some req) -> - _debug (fun k->k "req: %s" (Format.asprintf "@[%a@]" Request.pp_ req)); - - if 
Request.close_after_req req then continue := false; - - try - (* is there a handler for this path? *) - let handler = - match find_map (fun ph -> ph req) self.path_handlers with - | Some f -> unwrap_resp_result f - | None -> - (fun _oc req ~resp -> - let body_str = Request.read_body_full ~buf_size:self.buf_size req in - resp (self.handler body_str)) - in - - (* handle expect/continue *) - begin match Request.get_header ~f:String.trim req "Expect" with - | Some "100-continue" -> - _debug (fun k->k "send back: 100 CONTINUE"); - Response.output_ oc (Response.make_raw ~code:100 ""); - | Some s -> bad_reqf 417 "unknown expectation %s" s - | None -> () - end; - - (* apply middlewares *) - let handler = - fun oc -> - List.fold_right (fun (_, m) h -> m h) - (Lazy.force self.middlewares_sorted) (handler oc) - in - - (* now actually read request's body into a stream *) - let req = - Request.parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=is} - |> unwrap_resp_result - in - - (* how to reply *) - let resp r = - try - if Headers.get "connection" r.Response.headers = Some"close" then - continue := false; - Response.output_ oc r - with Sys_error _ -> continue := false - in - - (* call handler *) - begin - try handler oc req ~resp - with Sys_error _ -> continue := false - end - with - | Sys_error _ -> - continue := false; (* connection broken somehow *) - | Bad_req (code,s) -> - continue := false; - Response.output_ oc @@ Response.make_raw ~code s - | e -> - continue := false; - Response.output_ oc @@ Response.fail ~code:500 "server error: %s" (Printexc.to_string e) - done; - _debug (fun k->k "done with client, exiting"); - (try Unix.close client_sock - with e -> _debug (fun k->k "error when closing sock: %s" (Printexc.to_string e))); - () - -let is_ipv6 self = String.contains self.addr ':' - -let run (self:t) : (unit,_) result = - try - if self.masksigpipe then ( - ignore (Unix.sigprocmask Unix.SIG_BLOCK [Sys.sigpipe] : _ list); - ); - let sock, should_bind = match 
self.sock with - | Some s -> - s, false (* Because we're getting a socket from the caller (e.g. systemd) *) - | None -> - Unix.socket - (if is_ipv6 self then Unix.PF_INET6 else Unix.PF_INET) - Unix.SOCK_STREAM - 0, - true (* Because we're creating the socket ourselves *) - in - Unix.clear_nonblock sock; - Unix.setsockopt_optint sock Unix.SO_LINGER None; - begin if should_bind then - let inet_addr = Unix.inet_addr_of_string self.addr in - Unix.setsockopt sock Unix.SO_REUSEADDR true; - Unix.bind sock (Unix.ADDR_INET (inet_addr, self.port)); - Unix.listen sock (2 * self.sem_max_connections.Sem_.n) - end; - while self.running do - (* limit concurrency *) - Sem_.acquire 1 self.sem_max_connections; - try - let client_sock, _ = Unix.accept sock in - self.new_thread - (fun () -> - try - handle_client_ self client_sock; - Sem_.release 1 self.sem_max_connections; - with e -> - (try Unix.close client_sock with _ -> ()); - Sem_.release 1 self.sem_max_connections; - raise e - ); - with e -> - Sem_.release 1 self.sem_max_connections; - _debug (fun k -> k - "Unix.accept or Thread.create raised an exception: %s" - (Printexc.to_string e)) - done; - Ok () - with e -> Error e +module Dir = Tiny_httpd_dir diff --git a/src/Tiny_httpd.mli b/src/Tiny_httpd.mli index 158a69fd..7daba83b 100644 --- a/src/Tiny_httpd.mli +++ b/src/Tiny_httpd.mli @@ -81,639 +81,23 @@ echo: processing streams and parsing requests. *) -module Buf_ : sig - type t - val size : t -> int - val clear : t -> unit - val create : ?size:int -> unit -> t - val contents : t -> string - - val bytes_slice : t -> bytes - (** Access underlying slice of bytes. - @since 0.5 *) - - val contents_and_clear : t -> string - (** Get contents of the buffer and clear it. - @since 0.5 *) - - val add_bytes : t -> bytes -> int -> int -> unit - (** Append given bytes slice to the buffer. 
- @since 0.5 *) -end +module Buf = Tiny_httpd_buf (** {2 Generic stream of data} Streams are used to represent a series of bytes that can arrive progressively. For example, an uploaded file will be sent as a series of chunks. *) -type byte_stream = { - bs_fill_buf: unit -> (bytes * int * int); - (** See the current slice of the internal buffer as [bytes, i, len], - where the slice is [bytes[i] .. [bytes[i+len-1]]]. - Can block to refill the buffer if there is currently no content. - If [len=0] then there is no more data. *) - bs_consume: int -> unit; - (** Consume n bytes from the buffer. This should only be called with [n <= len] - after a call to [is_fill_buf] that returns a slice of length [len]. *) - bs_close: unit -> unit; - (** Close the stream. *) -} -(** A buffered stream, with a view into the current buffer (or refill if empty), - and a function to consume [n] bytes. - See {!Byte_stream} for more details. *) +module Byte_stream = Tiny_httpd_stream -module Byte_stream : sig - type t = byte_stream +(** {2 Main Server Type} *) - val close : t -> unit +include module type of struct include Tiny_httpd_server end - val empty : t +(** {2 Utils} *) - val of_chan : ?buf_size:int -> in_channel -> t - (** Make a buffered stream from the given channel. *) +module Util = Tiny_httpd_util - val of_chan_close_noerr : ?buf_size:int -> in_channel -> t - (** Same as {!of_chan} but the [close] method will never fail. *) +(** {2 Static directory serving} *) - val of_bytes : ?i:int -> ?len:int -> bytes -> t - (** A stream that just returns the slice of bytes starting from [i] - and of length [len]. *) - - val of_string : string -> t - - val iter : (bytes -> int -> int -> unit) -> t -> unit - (** Iterate on the chunks of the stream - @since 0.3 *) - - val to_chan : out_channel -> t -> unit - (** Write the stream to the channel. 
- @since 0.3 *) - - val with_file : ?buf_size:int -> string -> (t -> 'a) -> 'a - (** Open a file with given name, and obtain an input stream - on its content. When the function returns, the stream (and file) are closed. *) - - val read_line : ?buf:Buf_.t -> t -> string - (** Read a line from the stream. - @param buf a buffer to (re)use. Its content will be cleared. *) - - val read_all : ?buf:Buf_.t -> t -> string - (** Read the whole stream into a string. - @param buf a buffer to (re)use. Its content will be cleared. *) -end - -(** {2 Methods} *) - -module Meth : sig - type t = [ - | `GET - | `PUT - | `POST - | `HEAD - | `DELETE - ] - (** A HTTP method. - For now we only handle a subset of these. - - See https://tools.ietf.org/html/rfc7231#section-4 *) - - val pp : Format.formatter -> t -> unit - val to_string : t -> string -end - -(** {2 Headers} - - Headers are metadata associated with a request or response. *) - -module Headers : sig - type t = (string * string) list - (** The header files of a request or response. - - Neither the key nor the value can contain ['\r'] or ['\n']. - See https://tools.ietf.org/html/rfc7230#section-3.2 *) - - val empty : t - (** Empty list of headers - @since 0.5 *) - - val get : ?f:(string->string) -> string -> t -> string option - (** [get k headers] looks for the header field with key [k]. - @param f if provided, will transform the value before it is returned. *) - - val set : string -> string -> t -> t - (** [set k v headers] sets the key [k] to value [v]. - It erases any previous entry for [k] *) - - val remove : string -> t -> t - (** Remove the key from the headers, if present. *) - - val contains : string -> t -> bool - (** Is there a header with the given key? *) - - val pp : Format.formatter -> t -> unit - (** Pretty print the headers. *) -end - -(** {2 Requests} - - Requests are sent by a client, e.g. a web browser or cURL. 
*) - -module Request : sig - type 'body t = private { - meth: Meth.t; - host: string; - headers: Headers.t; - http_version: int*int; - path: string; - path_components: string list; - query: (string*string) list; - body: 'body; - start_time: float; - (** Obtained via [get_time_s] in {!create} - @since 0.11 *) - } - (** A request with method, path, host, headers, and a body, sent by a client. - - The body is polymorphic because the request goes through - several transformations. First it has no body, as only the request - and headers are read; then it has a stream body; then the body might be - entirely read as a string via {!read_body_full}. - - @since 0.6 The field [query] was added and contains the query parameters in ["?foo=bar,x=y"] - @since 0.6 The field [path_components] is the part of the path that precedes [query] and is split on ["/"]. - @since 0.11 the type is a private alias - @since 0.11 the field [start_time] was added - *) - - val pp : Format.formatter -> string t -> unit - (** Pretty print the request and its body *) - - val pp_ : Format.formatter -> _ t -> unit - (** Pretty print the request without its body *) - - val headers : _ t -> Headers.t - (** List of headers of the request, including ["Host"] *) - - val get_header : ?f:(string->string) -> _ t -> string -> string option - - val get_header_int : _ t -> string -> int option - - val set_header : string -> string -> 'a t -> 'a t - (** [set_header k v req] sets [k: v] in the request [req]'s headers. *) - - val update_headers : (Headers.t -> Headers.t) -> 'a t -> 'a t - (** Modify headers - @since 0.11 *) - - val set_body : 'a -> _ t -> 'a t - (** [set_body b req] returns a new query whose body is [b]. - @since 0.11 *) - - val host : _ t -> string - (** Host field of the request. It also appears in the headers. *) - - val meth : _ t -> Meth.t - (** Method for the request. *) - - val path : _ t -> string - (** Request path. 
*) - - val query : _ t -> (string*string) list - (** Decode the query part of the {!path} field - @since 0.4 *) - - val body : 'b t -> 'b - (** Request body, possibly empty. *) - - val start_time : _ t -> float - (** time stamp (from {!Unix.gettimeofday}) after parsing the first line of the request - @since 0.11 *) - - val limit_body_size : max_size:int -> byte_stream t -> byte_stream t - (** Limit the body size to [max_size] bytes, or return - a [413] error. - @since 0.3 - *) - - val read_body_full : ?buf_size:int -> byte_stream t -> string t - (** Read the whole body into a string. Potentially blocking. - - @param buf_size initial size of underlying buffer (since 0.11) *) - - (**/**) - (* for testing purpose, do not use *) - module Internal_ : sig - val parse_req_start : ?buf:Buf_.t -> get_time_s:(unit -> float) -> byte_stream -> unit t option - val parse_body : ?buf:Buf_.t -> unit t -> byte_stream -> byte_stream t - end - (**/**) -end - -(** {2 Response Codes} *) - -module Response_code : sig - type t = int - (** A standard HTTP code. - - https://tools.ietf.org/html/rfc7231#section-6 *) - - val ok : t - (** The code [200] *) - - val not_found : t - (** The code [404] *) - - val descr : t -> string - (** A description of some of the error codes. - NOTE: this is not complete (yet). *) -end - -(** {2 Responses} - - Responses are what a http server, such as {!Tiny_httpd}, send back to - the client to answer a {!Request.t}*) - -module Response : sig - type body = [`String of string | `Stream of byte_stream | `Void] - (** Body of a response, either as a simple string, - or a stream of bytes, or nothing (for server-sent events). *) - - type t = private { - code: Response_code.t; (** HTTP response code. See {!Response_code}. *) - headers: Headers.t; (** Headers of the reply. Some will be set by [Tiny_httpd] automatically. *) - body: body; (** Body of the response. Can be empty. *) - } - (** A response to send back to a client. 
*) - - val set_body : body -> t -> t - (** Set the body of the response. - @since 0.11 *) - - val set_header : string -> string -> t -> t - (** Set a header. - @since 0.11 *) - - val update_headers : (Headers.t -> Headers.t) -> t -> t - (** Modify headers - @since 0.11 *) - - val set_headers : Headers.t -> t -> t - (** Set all headers. - @since 0.11 *) - - val set_code : Response_code.t -> t -> t - (** Set the response code. - @since 0.11 *) - - val make_raw : - ?headers:Headers.t -> - code:Response_code.t -> - string -> - t - (** Make a response from its raw components, with a string body. - Use [""] to not send a body at all. *) - - val make_raw_stream : - ?headers:Headers.t -> - code:Response_code.t -> - byte_stream -> - t - (** Same as {!make_raw} but with a stream body. The body will be sent with - the chunked transfer-encoding. *) - - val make : - ?headers:Headers.t -> - (body, Response_code.t * string) result -> t - (** [make r] turns a result into a response. - - - [make (Ok body)] replies with [200] and the body. - - [make (Error (code,msg))] replies with the given error code - and message as body. - *) - - val make_string : - ?headers:Headers.t -> - (string, Response_code.t * string) result -> t - (** Same as {!make} but with a string body. *) - - val make_stream : - ?headers:Headers.t -> - (byte_stream, Response_code.t * string) result -> t - (** Same as {!make} but with a stream body. *) - - val fail : ?headers:Headers.t -> code:int -> - ('a, unit, string, t) format4 -> 'a - (** Make the current request fail with the given code and message. - Example: [fail ~code:404 "oh noes, %s not found" "waldo"]. - *) - - val fail_raise : code:int -> ('a, unit, string, 'b) format4 -> 'a - (** Similar to {!fail} but raises an exception that exits the current handler. - This should not be used outside of a (path) handler. 
- Example: [fail_raise ~code:404 "oh noes, %s not found" "waldo"; never_executed()] - *) - - val pp : Format.formatter -> t -> unit - (** Pretty print the response. *) -end - -(** {2 Routing} - - Basic type-safe routing. - @since 0.6 *) -module Route : sig - type ('a, 'b) comp - (** An atomic component of a path *) - - type ('a, 'b) t - (** A route, composed of path components *) - - val int : (int -> 'a, 'a) comp - (** Matches an integer. *) - - val string : (string -> 'a, 'a) comp - (** Matches a string not containing ['/'] and binds it as is. *) - - val string_urlencoded : (string -> 'a, 'a) comp - (** Matches a URL-encoded string, and decodes it. *) - - val exact : string -> ('a, 'a) comp - (** [exact "s"] matches ["s"] and nothing else. *) - - val return : ('a, 'a) t - (** Matches the empty path. *) - - val rest_of_path : (string -> 'a, 'a) t - (** Matches a string, even containing ['/']. This will match - the entirety of the remaining route. - @since 0.7 *) - - val rest_of_path_urlencoded : (string -> 'a, 'a) t - (** Matches a string, even containing ['/'], an URL-decode it. - This will match the entirety of the remaining route. - @since 0.7 *) - - val (@/) : ('a, 'b) comp -> ('b, 'c) t -> ('a, 'c) t - (** [comp / route] matches ["foo/bar/…"] iff [comp] matches ["foo"], - and [route] matches ["bar/…"]. *) - - val exact_path : string -> ('a,'b) t -> ('a,'b) t - (** [exact_path "foo/bar/..." r] is equivalent to - [exact "foo" @/ exact "bar" @/ ... @/ r] - @since 0.11 **) - - val pp : Format.formatter -> _ t -> unit - (** Print the route. - @since 0.7 *) - - val to_string : _ t -> string - (** Print the route. - @since 0.7 *) -end - -(** {2 Middlewares} - - A middleware can be inserted in a handler to modify or observe - its behavior. - - @since 0.11 -*) -module Middleware : sig - type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit - (** Handlers are functions returning a response to a request. 
- The response can be delayed, hence the use of a continuation - as the [resp] parameter. *) - - type t = handler -> handler - (** A middleware is a handler transformation. - - It takes the existing handler [h], - and returns a new one which, given a query, modify it or log it - before passing it to [h], or fail. It can also log or modify or drop - the response. *) - - val nil : t - (** Trivial middleware that does nothing. *) -end - -(** {2 Main Server type} *) - -type t -(** A HTTP server. See {!create} for more details. *) - -val create : - ?masksigpipe:bool -> - ?max_connections:int -> - ?timeout:float -> - ?buf_size:int -> - ?get_time_s:(unit -> float) -> - ?new_thread:((unit -> unit) -> unit) -> - ?addr:string -> - ?port:int -> - ?sock:Unix.file_descr -> - ?middlewares:([`Encoding | `Stage of int] * Middleware.t) list -> - unit -> - t -(** Create a new webserver. - - The server will not do anything until {!run} is called on it. - Before starting the server, one can use {!add_path_handler} and - {!set_top_handler} to specify how to handle incoming requests. - - @param masksigpipe if true, block the signal {!Sys.sigpipe} which otherwise - tends to kill client threads when they try to write on broken sockets. Default: [true]. - - @param buf_size size for buffers (since 0.11) - - @param new_thread a function used to spawn a new thread to handle a - new client connection. By default it is {!Thread.create} but one - could use a thread pool instead. - - @param middlewares see {!add_middleware} for more details. - - @param max_connections maximum number of simultaneous connections. - @param timeout connection is closed if the socket does not do read or - write for the amount of second. Default: 0.0 which means no timeout. - timeout is not recommended when using proxy. - @param addr address (IPv4 or IPv6) to listen on. Default ["127.0.0.1"]. - @param port to listen on. Default [8080]. - @param sock an existing socket given to the server to listen on, e.g. 
by - systemd on Linux (or launchd on macOS). If passed in, this socket will be - used instead of the [addr] and [port]. If not passed in, those will be - used. This parameter exists since 0.10. - - @param get_time_s obtain the current timestamp in seconds. - This parameter exists since 0.11. -*) - -val addr : t -> string -(** Address on which the server listens. *) - -val is_ipv6 : t -> bool -(** [is_ipv6 server] returns [true] iff the address of the server is an IPv6 address. - @since 0.3 *) - -val port : t -> int -(** Port on which the server listens. *) - -val active_connections : t -> int -(** Number of active connections *) - -val add_decode_request_cb : - t -> - (unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) -> unit -[@@deprecated "use add_middleware"] -(** Add a callback for every request. - The callback can provide a stream transformer and a new request (with - modified headers, typically). - A possible use is to handle decompression by looking for a [Transfer-Encoding] - header and returning a stream transformer that decompresses on the fly. - - @deprecated use {!add_middleware} instead -*) - -val add_encode_response_cb: - t -> (unit Request.t -> Response.t -> Response.t option) -> unit -[@@deprecated "use add_middleware"] -(** Add a callback for every request/response pair. - Similarly to {!add_encode_response_cb} the callback can return a new - response, for example to compress it. - The callback is given the query with only its headers, - as well as the current response. - - @deprecated use {!add_middleware} instead -*) - -val add_middleware : - stage:[`Encoding | `Stage of int] -> - t -> Middleware.t -> unit -(** Add a middleware to every request/response pair. - @param stage specify when middleware applies. - Encoding comes first (outermost layer), then stages in increasing order. 
- @raise Invalid_argument if stage is [`Stage n] where [n < 1] - @since 0.11 -*) - -(** {2 Request handlers} *) - -val set_top_handler : t -> (string Request.t -> Response.t) -> unit -(** Setup a handler called by default. - - This handler is called with any request not accepted by any handler - installed via {!add_path_handler}. - If no top handler is installed, unhandled paths will return a [404] not found. *) - -val add_route_handler : - ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> - ?middlewares:Middleware.t list -> - ?meth:Meth.t -> - t -> - ('a, string Request.t -> Response.t) Route.t -> 'a -> - unit -(** [add_route_handler server Route.(exact "path" @/ string @/ int @/ return) f] - calls [f "foo" 42 request] when a [request] with path "path/foo/42/" - is received. - - Note that the handlers are called in the reverse order of their addition, - so the last registered handler can override previously registered ones. - - @param meth if provided, only accept requests with the given method. - Typically one could react to [`GET] or [`PUT]. - @param accept should return [Ok()] if the given request (before its body - is read) should be accepted, [Error (code,message)] if it's to be rejected (e.g. because - its content is too big, or for some permission error). - See the {!http_of_dir} program for an example of how to use [accept] to - filter uploads that are too large before the upload even starts. - The default always returns [Ok()], i.e. it accepts all requests. - - @since 0.6 -*) - -val add_route_handler_stream : - ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> - ?middlewares:Middleware.t list -> - ?meth:Meth.t -> - t -> - ('a, byte_stream Request.t -> Response.t) Route.t -> 'a -> - unit -(** Similar to {!add_route_handler}, but where the body of the request - is a stream of bytes that has not been read yet. 
- This is useful when one wants to stream the body directly into a parser, - json decoder (such as [Jsonm]) or into a file. - @since 0.6 *) - -(** {2 Server-sent events} - - {b EXPERIMENTAL}: this API is not stable yet. *) - -(** A server-side function to generate of Server-sent events. - - See {{: https://html.spec.whatwg.org/multipage/server-sent-events.html} the w3c page} - and {{: https://jvns.ca/blog/2021/01/12/day-36--server-sent-events-are-cool--and-a-fun-bug/} - this blog post}. - - @since 0.9 - *) -module type SERVER_SENT_GENERATOR = sig - val set_headers : Headers.t -> unit - (** Set headers of the response. - This is not mandatory but if used at all, it must be called before - any call to {!send_event} (once events are sent the response is - already sent too). *) - - val send_event : - ?event:string -> - ?id:string -> - ?retry:string -> - data:string -> - unit -> unit - (** Send an event from the server. - If data is a multiline string, it will be sent on separate "data:" lines. *) - - val close : unit -> unit - (** Close connection. - @since 0.11 *) -end - -type server_sent_generator = (module SERVER_SENT_GENERATOR) -(** Server-sent event generator - @since 0.9 *) - -val add_route_server_sent_handler : - ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> - t -> - ('a, string Request.t -> server_sent_generator -> unit) Route.t -> 'a -> - unit -(** Add a handler on an endpoint, that serves server-sent events. - - The callback is given a generator that can be used to send events - as it pleases. The connection is always closed by the client, - and the accepted method is always [GET]. - This will set the header "content-type" to "text/event-stream" automatically - and reply with a 200 immediately. - See {!server_sent_generator} for more details. - - This handler stays on the original thread (it is synchronous). - - @since 0.9 *) - -(** {2 Run the server} *) - -val stop : t -> unit -(** Ask the server to stop. 
This might not have an immediate effect - as {!run} might currently be waiting on IO. *) - -val run : t -> (unit, exn) result -(** Run the main loop of the server, listening on a socket - described at the server's creation time, using [new_thread] to - start a thread for each new client. - - This returns [Ok ()] if the server exits gracefully, or [Error e] if - it exits with an error. *) - -(**/**) - -val _debug : ((('a, out_channel, unit, unit, unit, unit) format6 -> 'a) -> unit) -> unit -val _enable_debug: bool -> unit - -(**/**) +module Dir = Tiny_httpd_dir diff --git a/src/Tiny_httpd_buf.ml b/src/Tiny_httpd_buf.ml new file mode 100644 index 00000000..dda9b653 --- /dev/null +++ b/src/Tiny_httpd_buf.ml @@ -0,0 +1,35 @@ + +type t = { + mutable bytes: bytes; + mutable i: int; +} + +let create ?(size=4_096) () : t = + { bytes=Bytes.make size ' '; i=0 } + +let size self = self.i +let bytes_slice self = self.bytes +let clear self : unit = + if Bytes.length self.bytes > 4_096 * 1_024 then ( + self.bytes <- Bytes.make 4096 ' '; (* free big buffer *) + ); + self.i <- 0 + +let resize self new_size : unit = + let new_buf = Bytes.make new_size ' ' in + Bytes.blit self.bytes 0 new_buf 0 self.i; + self.bytes <- new_buf + +let add_bytes (self:t) s i len : unit = + if self.i + len >= Bytes.length self.bytes then ( + resize self (self.i + self.i / 2 + len + 10); + ); + Bytes.blit s i self.bytes self.i len; + self.i <- self.i + len + +let contents (self:t) : string = Bytes.sub_string self.bytes 0 self.i + +let contents_and_clear (self:t) : string = + let x = contents self in + clear self; + x diff --git a/src/Tiny_httpd_buf.mli b/src/Tiny_httpd_buf.mli new file mode 100644 index 00000000..24f3c42b --- /dev/null +++ b/src/Tiny_httpd_buf.mli @@ -0,0 +1,27 @@ + +(** Simple buffer. + + These buffers are used to avoid allocating too many byte arrays when + processing streams and parsing requests. 
+ + @since NEXT_RELEASE +*) + +type t +val size : t -> int +val clear : t -> unit +val create : ?size:int -> unit -> t +val contents : t -> string + +val bytes_slice : t -> bytes +(** Access underlying slice of bytes. + @since 0.5 *) + +val contents_and_clear : t -> string +(** Get contents of the buffer and clear it. + @since 0.5 *) + +val add_bytes : t -> bytes -> int -> int -> unit +(** Append given bytes slice to the buffer. + @since 0.5 *) + diff --git a/src/Tiny_httpd_dir.ml b/src/Tiny_httpd_dir.ml index 6fe9a6ac..d391ff0d 100644 --- a/src/Tiny_httpd_dir.ml +++ b/src/Tiny_httpd_dir.ml @@ -1,4 +1,4 @@ -module S = Tiny_httpd +module S = Tiny_httpd_server module U = Tiny_httpd_util module Pf = Printf @@ -66,7 +66,7 @@ module type VFS = sig val list_dir : string -> string array val delete : string -> unit val create : string -> (bytes -> int -> int -> unit) * (unit -> unit) - val read_file_content : string -> Tiny_httpd.Byte_stream.t + val read_file_content : string -> Tiny_httpd_stream.t val file_size : string -> int option val file_mtime : string -> float option end @@ -82,7 +82,7 @@ let vfs_of_dir (top:string) : vfs = let list_dir f = Sys.readdir (top // f) let read_file_content f = let ic = open_in_bin (top // f) in - S.Byte_stream.of_chan ic + Tiny_httpd_stream.of_chan ic let create f = let oc = open_out_bin (top // f) in let write = output oc in @@ -197,7 +197,7 @@ let add_vfs_ ~on_fs ~top ~config ~vfs:((module VFS:VFS) as vfs) ~prefix server : path (Printexc.to_string e) in let req = S.Request.limit_body_size ~max_size:config.max_upload_size req in - S.Byte_stream.iter write req.S.Request.body; + Tiny_httpd_stream.iter write req.S.Request.body; close (); S._debug (fun k->k "done uploading"); S.Response.make_raw ~code:201 "upload successful" @@ -367,7 +367,7 @@ module Embedded_fs = struct | _ -> false let read_file_content p = match find_ self p with - | Some (File {content;_}) -> Tiny_httpd.Byte_stream.of_string content + | Some (File {content;_}) -> 
Tiny_httpd_stream.of_string content | _ -> failwith (Printf.sprintf "no such file: %S" p) let list_dir p = S._debug (fun k->k "list dir %S" p); match find_ self p with diff --git a/src/Tiny_httpd_dir.mli b/src/Tiny_httpd_dir.mli index 16c1491a..67bf5f51 100644 --- a/src/Tiny_httpd_dir.mli +++ b/src/Tiny_httpd_dir.mli @@ -77,7 +77,7 @@ val add_dir_path : config:config -> dir:string -> prefix:string -> - Tiny_httpd.t -> unit + Tiny_httpd_server.t -> unit (** Virtual file system. @@ -105,7 +105,7 @@ module type VFS = sig val create : string -> (bytes -> int -> int -> unit) * (unit -> unit) (** Create a file and obtain a pair [write, close] *) - val read_file_content : string -> Tiny_httpd.Byte_stream.t + val read_file_content : string -> Tiny_httpd_stream.t (** Read content of a file *) val file_size : string -> int option @@ -125,7 +125,7 @@ val add_vfs : config:config -> vfs:(module VFS) -> prefix:string -> - Tiny_httpd.t -> unit + Tiny_httpd_server.t -> unit (** Similar to {!add_dir_path} but using a virtual file system instead. 
@since NEXT_RELEASE *) diff --git a/src/Tiny_httpd_server.ml b/src/Tiny_httpd_server.ml new file mode 100644 index 00000000..7547ee5e --- /dev/null +++ b/src/Tiny_httpd_server.ml @@ -0,0 +1,893 @@ + +type buf = Tiny_httpd_buf.t +type byte_stream = Tiny_httpd_stream.t + +let _debug_on = ref ( + match String.trim @@ Sys.getenv "HTTP_DBG" with + | "" -> false | _ -> true | exception _ -> false +) +let _enable_debug b = _debug_on := b +let _debug k = + if !_debug_on then ( + k (fun fmt-> + Printf.fprintf stdout "[http.thread %d]: " Thread.(id @@ self()); + Printf.kfprintf (fun oc -> Printf.fprintf oc "\n%!") stdout fmt) + ) + +module Buf = Tiny_httpd_buf + +module Byte_stream = Tiny_httpd_stream + +exception Bad_req of int * string +let bad_reqf c fmt = Printf.ksprintf (fun s ->raise (Bad_req (c,s))) fmt + +module Response_code = struct + type t = int + + let ok = 200 + let not_found = 404 + let descr = function + | 100 -> "Continue" + | 200 -> "OK" + | 201 -> "Created" + | 202 -> "Accepted" + | 204 -> "No content" + | 300 -> "Multiple choices" + | 301 -> "Moved permanently" + | 302 -> "Found" + | 304 -> "Not Modified" + | 400 -> "Bad request" + | 403 -> "Forbidden" + | 404 -> "Not found" + | 405 -> "Method not allowed" + | 408 -> "Request timeout" + | 409 -> "Conflict" + | 410 -> "Gone" + | 411 -> "Length required" + | 413 -> "Payload too large" + | 417 -> "Expectation failed" + | 500 -> "Internal server error" + | 501 -> "Not implemented" + | 503 -> "Service unavailable" + | n -> "Unknown response code " ^ string_of_int n (* TODO *) +end + +type 'a resp_result = ('a, Response_code.t * string) result +let unwrap_resp_result = function + | Ok x -> x + | Error (c,s) -> raise (Bad_req (c,s)) + +module Meth = struct + type t = [ + | `GET + | `PUT + | `POST + | `HEAD + | `DELETE + ] + + let to_string = function + | `GET -> "GET" + | `PUT -> "PUT" + | `HEAD -> "HEAD" + | `POST -> "POST" + | `DELETE -> "DELETE" + let pp out s = Format.pp_print_string out (to_string s) + + 
let of_string = function + | "GET" -> `GET + | "PUT" -> `PUT + | "POST" -> `POST + | "HEAD" -> `HEAD + | "DELETE" -> `DELETE + | s -> bad_reqf 400 "unknown method %S" s +end + +module Headers = struct + type t = (string * string) list + let empty = [] + let contains name headers = + let name' = String.lowercase_ascii name in + List.exists (fun (n, _) -> name'=n) headers + let get_exn ?(f=fun x->x) x h = + let x' = String.lowercase_ascii x in + List.assoc x' h |> f + let get ?(f=fun x -> x) x h = + try Some (get_exn ~f x h) with Not_found -> None + let remove x h = + let x' = String.lowercase_ascii x in + List.filter (fun (k,_) -> k<>x') h + let set x y h = + let x' = String.lowercase_ascii x in + (x',y) :: List.filter (fun (k,_) -> k<>x') h + let pp out l = + let pp_pair out (k,v) = Format.fprintf out "@[%s: %s@]" k v in + Format.fprintf out "@[%a@]" (Format.pp_print_list pp_pair) l + + (* token = 1*tchar + tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / "^" / "_" + / "`" / "|" / "~" / DIGIT / ALPHA ; any VCHAR, except delimiters + Reference: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 *) + let is_tchar = function + | '0' .. '9' | 'a' .. 'z' | 'A' .. 'Z' + | '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' 
| '^' + | '_' | '`' | '|' | '~' -> true + | _ -> false + + let for_all pred s = + try String.iter (fun c->if not (pred c) then raise Exit) s; true + with Exit -> false + + let parse_ ~buf (bs:byte_stream) : t = + let rec loop acc = + let line = Byte_stream.read_line ~buf bs in + _debug (fun k->k "parsed header line %S" line); + if line = "\r" then ( + acc + ) else ( + let k,v = + try + let i = String.index line ':' in + let k = String.sub line 0 i in + if not (for_all is_tchar k) then ( + invalid_arg (Printf.sprintf "Invalid header key: %S" k)); + let v = String.sub line (i+1) (String.length line-i-1) |> String.trim in + k,v + with _ -> bad_reqf 400 "invalid header line: %S" line + in + loop ((String.lowercase_ascii k,v)::acc) + ) + in + loop [] +end + +module Request = struct + type 'body t = { + meth: Meth.t; + host: string; + headers: Headers.t; + http_version: int*int; + path: string; + path_components: string list; + query: (string*string) list; + body: 'body; + start_time: float; + } + + let headers self = self.headers + let host self = self.host + let meth self = self.meth + let path self = self.path + let body self = self.body + let start_time self = self.start_time + + let query self = self.query + let get_header ?f self h = Headers.get ?f h self.headers + let get_header_int self h = match get_header self h with + | Some x -> (try Some (int_of_string x) with _ -> None) + | None -> None + let set_header k v self = {self with headers=Headers.set k v self.headers} + let update_headers f self = {self with headers=f self.headers} + let set_body b self = {self with body=b} + + (** Should we close the connection after this request? 
*) + let close_after_req (self:_ t) : bool = + match self.http_version with + | 1, 1 -> get_header self "connection" =Some"close" + | 1, 0 -> not (get_header self "connection"=Some"keep-alive") + | _ -> false + + let pp_comp_ out comp = + Format.fprintf out "[%s]" + (String.concat ";" @@ List.map (Printf.sprintf "%S") comp) + let pp_query out q = + Format.fprintf out "[%s]" + (String.concat ";" @@ + List.map (fun (a,b) -> Printf.sprintf "%S,%S" a b) q) + let pp_ out self : unit = + Format.fprintf out "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ \ + path=%S;@ body=?;@ path_components=%a;@ query=%a@]}" + (Meth.to_string self.meth) self.host Headers.pp self.headers self.path + pp_comp_ self.path_components pp_query self.query + let pp out self : unit = + Format.fprintf out "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ path=%S;@ \ + body=%S;@ path_components=%a;@ query=%a@]}" + (Meth.to_string self.meth) self.host Headers.pp self.headers + self.path self.body pp_comp_ self.path_components pp_query self.query + + (* decode a "chunked" stream into a normal stream *) + let read_stream_chunked_ ?buf (bs:byte_stream) : byte_stream = + _debug (fun k->k "body: start reading chunked stream..."); + Byte_stream.read_chunked ?buf + ~fail:(fun s -> Bad_req (400, s)) + bs + + let limit_body_size_ ~max_size (bs:byte_stream) : byte_stream = + _debug (fun k->k "limit size of body to max-size=%d" max_size); + Byte_stream.limit_size_to ~max_size ~close_rec:false bs + ~too_big:(fun size -> + (* read too much *) + bad_reqf 413 + "body size was supposed to be %d, but at least %d bytes received" + max_size size + ) + + let limit_body_size ~max_size (req:byte_stream t) : byte_stream t = + { req with body=limit_body_size_ ~max_size req.body } + + (* read exactly [size] bytes from the stream *) + let read_exactly ~size (bs:byte_stream) : byte_stream = + _debug (fun k->k "body: must read exactly %d bytes" size); + Byte_stream.read_exactly bs ~close_rec:false + ~size ~too_short:(fun size -> + 
bad_reqf 400 "body is too short by %d bytes" size + ) + + (* parse request, but not body (yet) *) + let parse_req_start ~get_time_s ~buf (bs:byte_stream) : unit t option resp_result = + try + let line = Byte_stream.read_line ~buf bs in + let start_time = get_time_s() in + let meth, path, version = + try + let meth, path, version = Scanf.sscanf line "%s %s HTTP/1.%d\r" (fun x y z->x,y,z) in + if version != 0 && version != 1 then raise Exit; + meth, path, version + with _ -> + _debug (fun k->k "invalid request line: `%s`" line); + raise (Bad_req (400, "Invalid request line")) + in + let meth = Meth.of_string meth in + _debug (fun k->k "got meth: %s, path %S" (Meth.to_string meth) path); + let headers = Headers.parse_ ~buf bs in + let host = + match Headers.get "Host" headers with + | None -> bad_reqf 400 "No 'Host' header in request" + | Some h -> h + in + let path_components, query = Tiny_httpd_util.split_query path in + let path_components = Tiny_httpd_util.split_on_slash path_components in + let query = + match Tiny_httpd_util.(parse_query query) with + | Ok l -> l + | Error e -> bad_reqf 400 "invalid query: %s" e + in + let req = { + meth; query; host; path; path_components; + headers; http_version=(1, version); body=(); start_time; + } in + Ok (Some req) + with + | End_of_file | Sys_error _ -> Ok None + | Bad_req (c,s) -> Error (c,s) + | e -> + Error (400, Printexc.to_string e) + + (* parse body, given the headers. + @param tr_stream a transformation of the input stream. 
*) + let parse_body_ ~tr_stream ~buf (req:byte_stream t) : byte_stream t resp_result = + try + let size = + match Headers.get_exn "Content-Length" req.headers |> int_of_string with + | n -> n (* body of fixed size *) + | exception Not_found -> 0 + | exception _ -> bad_reqf 400 "invalid content-length" + in + let body = + match get_header ~f:String.trim req "Transfer-Encoding" with + | None -> read_exactly ~size @@ tr_stream req.body + | Some "chunked" -> + let bs = + read_stream_chunked_ ~buf @@ tr_stream req.body (* body sent by chunks *) + in + if size>0 then limit_body_size_ ~max_size:size bs else bs + | Some s -> bad_reqf 500 "cannot handle transfer encoding: %s" s + in + Ok {req with body} + with + | End_of_file -> Error (400, "unexpected end of file") + | Bad_req (c,s) -> Error (c,s) + | e -> + Error (400, Printexc.to_string e) + + let read_body_full ?buf_size (self:byte_stream t) : string t = + try + let buf = Buf.create ?size:buf_size () in + let body = Byte_stream.read_all ~buf self.body in + { self with body } + with + | Bad_req _ as e -> raise e + | e -> bad_reqf 500 "failed to read body: %s" (Printexc.to_string e) + + module Internal_ = struct + let parse_req_start ?(buf=Buf.create()) ~get_time_s bs = + parse_req_start ~get_time_s ~buf bs |> unwrap_resp_result + + let parse_body ?(buf=Buf.create()) req bs : _ t = + parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=bs} |> unwrap_resp_result + end +end + +(*$R + let q = "GET hello HTTP/1.1\r\nHost: coucou\r\nContent-Length: 11\r\n\r\nsalutationsSOMEJUNK" in + let str = Tiny_httpd.Byte_stream.of_string q in + let r = Request.Internal_.parse_req_start ~get_time_s:(fun _ -> 0.) 
str in + match r with + | None -> assert_failure "should parse" + | Some req -> + assert_equal (Some "coucou") (Headers.get "Host" req.Request.headers); + assert_equal (Some "coucou") (Headers.get "host" req.Request.headers); + assert_equal (Some "11") (Headers.get "content-length" req.Request.headers); + assert_equal "hello" req.Request.path; + let req = Request.Internal_.parse_body req str |> Request.read_body_full in + assert_equal ~printer:(fun s->s) "salutations" req.Request.body; + () +*) + +module Response = struct + type body = [`String of string | `Stream of byte_stream | `Void] + type t = { + code: Response_code.t; + headers: Headers.t; + body: body; + } + + let set_body body self = {self with body} + let set_headers headers self = {self with headers} + let update_headers f self = {self with headers=f self.headers} + let set_header k v self = {self with headers = Headers.set k v self.headers} + let set_code code self = {self with code} + + let make_raw ?(headers=[]) ~code body : t = + (* add content length to response *) + let headers = + Headers.set "Content-Length" (string_of_int (String.length body)) headers + in + { code; headers; body=`String body; } + + let make_raw_stream ?(headers=[]) ~code body : t = + (* add content length to response *) + let headers = Headers.set "Transfer-Encoding" "chunked" headers in + { code; headers; body=`Stream body; } + + let make_void ?(headers=[]) ~code () : t = + { code; headers; body=`Void; } + + let make_string ?headers r = match r with + | Ok body -> make_raw ?headers ~code:200 body + | Error (code,msg) -> make_raw ?headers ~code msg + + let make_stream ?headers r = match r with + | Ok body -> make_raw_stream ?headers ~code:200 body + | Error (code,msg) -> make_raw ?headers ~code msg + + let make ?headers r : t = match r with + | Ok (`String body) -> make_raw ?headers ~code:200 body + | Ok (`Stream body) -> make_raw_stream ?headers ~code:200 body + | Ok `Void -> make_void ?headers ~code:200 () + | Error 
(code,msg) -> make_raw ?headers ~code msg + + let fail ?headers ~code fmt = + Printf.ksprintf (fun msg -> make_raw ?headers ~code msg) fmt + let fail_raise ~code fmt = + Printf.ksprintf (fun msg -> raise (Bad_req (code,msg))) fmt + + let pp out self : unit = + let pp_body out = function + | `String s -> Format.fprintf out "%S" s + | `Stream _ -> Format.pp_print_string out "" + | `Void -> () + in + Format.fprintf out "{@[code=%d;@ headers=[@[%a@]];@ body=%a@]}" + self.code Headers.pp self.headers pp_body self.body + + let output_ (oc:out_channel) (self:t) : unit = + Printf.fprintf oc "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); + let body, is_chunked = match self.body with + | `String s when String.length s > 1024 * 500 -> + (* chunk-encode large bodies *) + `Stream (Byte_stream.of_string s), true + | `String _ as b -> b, false + | `Stream _ as b -> b, true + | `Void as b -> b, false + in + let headers = + if is_chunked then ( + self.headers + |> Headers.set "transfer-encoding" "chunked" + |> Headers.remove "content-length" + ) else self.headers + in + let self = {self with headers; body} in + _debug (fun k->k "output response: %s" + (Format.asprintf "%a" pp {self with body=`String "<…>"})); + List.iter (fun (k,v) -> Printf.fprintf oc "%s: %s\r\n" k v) headers; + output_string oc "\r\n"; + begin match body with + | `String "" | `Void -> () + | `String s -> output_string oc s; + | `Stream str -> Byte_stream.output_chunked oc str; + end; + flush oc +end + +(* semaphore, for limiting concurrency. 
*) +module Sem_ = struct + type t = { + mutable n : int; + max : int; + mutex : Mutex.t; + cond : Condition.t; + } + + let create n = + if n <= 0 then invalid_arg "Semaphore.create"; + { n; max=n; mutex=Mutex.create(); cond=Condition.create(); } + + let acquire m t = + Mutex.lock t.mutex; + while t.n < m do + Condition.wait t.cond t.mutex; + done; + assert (t.n >= m); + t.n <- t.n - m; + Condition.broadcast t.cond; + Mutex.unlock t.mutex + + let release m t = + Mutex.lock t.mutex; + t.n <- t.n + m; + Condition.broadcast t.cond; + Mutex.unlock t.mutex + + let num_acquired t = t.max - t.n +end + +module Route = struct + type path = string list (* split on '/' *) + + type (_, _) comp = + | Exact : string -> ('a, 'a) comp + | Int : (int -> 'a, 'a) comp + | String : (string -> 'a, 'a) comp + | String_urlencoded : (string -> 'a, 'a) comp + + type (_, _) t = + | Fire : ('b, 'b) t + | Rest : { + url_encoded: bool; + } -> (string -> 'b, 'b) t + | Compose: ('a, 'b) comp * ('b, 'c) t -> ('a, 'c) t + + let return = Fire + let rest_of_path = Rest {url_encoded=false} + let rest_of_path_urlencoded = Rest {url_encoded=true} + let (@/) a b = Compose (a,b) + let string = String + let string_urlencoded = String_urlencoded + let int = Int + let exact (s:string) = Exact s + let exact_path (s:string) tail = + let rec fn = function + | [] -> tail + | ""::ls -> fn ls + | s::ls -> exact s @/ fn ls + in + fn (String.split_on_char '/' s) + let rec eval : + type a b. 
path -> (a,b) t -> a -> b option = + fun path route f -> + begin match path, route with + | [], Fire -> Some f + | _, Fire -> None + | _, Rest {url_encoded} -> + let whole_path = String.concat "/" path in + begin match + if url_encoded + then match Tiny_httpd_util.percent_decode whole_path with + | Some s -> s + | None -> raise_notrace Exit + else whole_path + with + | whole_path -> + Some (f whole_path) + | exception Exit -> None + end + | (c1 :: path'), Compose (comp, route') -> + begin match comp with + | Int -> + begin match int_of_string c1 with + | i -> eval path' route' (f i) + | exception _ -> None + end + | String -> + eval path' route' (f c1) + | String_urlencoded -> + begin match Tiny_httpd_util.percent_decode c1 with + | None -> None + | Some s -> eval path' route' (f s) + end + | Exact s -> + if s = c1 then eval path' route' f else None + end + | [], Compose (String, Fire) -> Some (f "") (* trailing *) + | [], Compose (String_urlencoded, Fire) -> Some (f "") (* trailing *) + | [], Compose _ -> None + end + + let bpf = Printf.bprintf + let rec pp_ + : type a b. Buffer.t -> (a,b) t -> unit + = fun out -> function + | Fire -> bpf out "/" + | Rest {url_encoded} -> + bpf out "" (if url_encoded then "_urlencoded" else "") + | Compose (Exact s, tl) -> bpf out "%s/%a" s pp_ tl + | Compose (Int, tl) -> bpf out "/%a" pp_ tl + | Compose (String, tl) -> bpf out "/%a" pp_ tl + | Compose (String_urlencoded, tl) -> bpf out "/%a" pp_ tl + + let to_string x = + let b = Buffer.create 16 in + pp_ b x; + Buffer.contents b + let pp out x = Format.pp_print_string out (to_string x) +end + +module Middleware = struct + type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit + type t = handler -> handler + + (** Apply a list of middlewares to [h] *) + let apply_l (l:t list) (h:handler) : handler = + List.fold_right (fun m h -> m h) l h + + let[@inline] nil : t = fun h -> h +end + +(* a request handler. handles a single request. 
*) +type cb_path_handler = + out_channel -> + Middleware.handler + +module type SERVER_SENT_GENERATOR = sig + val set_headers : Headers.t -> unit + val send_event : + ?event:string -> + ?id:string -> + ?retry:string -> + data:string -> + unit -> unit + val close : unit -> unit +end +type server_sent_generator = (module SERVER_SENT_GENERATOR) + +type t = { + addr: string; + + port: int; + + sock: Unix.file_descr option; + + timeout: float; + + sem_max_connections: Sem_.t; + (* semaphore to restrict the number of active concurrent connections *) + + new_thread: (unit -> unit) -> unit; + (* a function to run the given callback in a separate thread (or thread pool) *) + + masksigpipe: bool; + + buf_size: int; + + get_time_s : unit -> float; + + mutable handler: (string Request.t -> Response.t); + (* toplevel handler, if any *) + + mutable middlewares : (int * Middleware.t) list; + (** Global middlewares *) + + mutable middlewares_sorted : (int * Middleware.t) list lazy_t; + (* sorted version of {!middlewares} *) + + mutable path_handlers : (unit Request.t -> cb_path_handler resp_result option) list; + (* path handlers *) + + mutable running: bool; + (* true while the server is running. no need to protect with a mutex, + writes should be atomic enough. 
*) +} + +let addr self = self.addr +let port self = self.port + +let active_connections self = Sem_.num_acquired self.sem_max_connections - 1 + +let add_middleware ~stage self m = + let stage = match stage with + | `Encoding -> 0 + | `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage" + | `Stage n -> n + in + self.middlewares <- (stage,m) :: self.middlewares; + self.middlewares_sorted <- lazy ( + List.stable_sort (fun (s1,_) (s2,_) -> compare s1 s2) self.middlewares + ) + +let add_decode_request_cb self f = + (* turn it into a middleware *) + let m h req ~resp = + (* see if [f] modifies the stream *) + let req0 = {req with Request.body=()} in + match f req0 with + | None -> h req ~resp (* pass through *) + | Some (req1, tr_stream) -> + let req = {req1 with Request.body=tr_stream req.Request.body} in + h req ~resp + in + add_middleware self ~stage:`Encoding m + +let add_encode_response_cb self f = + let m h req ~resp = + h req ~resp:(fun r -> + let req0 = {req with Request.body=()} in + (* now transform [r] if we want to *) + match f req0 r with + | None -> resp r + | Some r' -> resp r') + in + add_middleware self ~stage:`Encoding m + +let set_top_handler self f = self.handler <- f + +(* route the given handler. + @param tr_req wraps the actual concrete function returned by the route + and makes it into a handler. *) +let add_route_handler_ + ?(accept=fun _req -> Ok ()) ?(middlewares=[]) + ?meth ~tr_req self (route:_ Route.t) f = + let ph req : cb_path_handler resp_result option = + match meth with + | Some m when m <> req.Request.meth -> None (* ignore *) + | _ -> + begin match Route.eval req.Request.path_components route f with + | Some handler -> + (* we have a handler, do we accept the request based on its headers? 
*) + begin match accept req with + | Ok () -> + Some (Ok (fun oc -> + Middleware.apply_l middlewares @@ + fun req ~resp -> tr_req oc req ~resp handler)) + | Error _ as e -> Some e + end + | None -> + None (* path didn't match *) + end + in + self.path_handlers <- ph :: self.path_handlers + +let add_route_handler (type a) ?accept ?middlewares ?meth + self (route:(a,_) Route.t) (f:_) : unit = + let tr_req _oc req ~resp f = resp (f (Request.read_body_full ~buf_size:self.buf_size req)) in + add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f + +let add_route_handler_stream ?accept ?middlewares ?meth self route f = + let tr_req _oc req ~resp f = resp (f req) in + add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f + +let[@inline] _opt_iter ~f o = match o with + | None -> () + | Some x -> f x + +let add_route_server_sent_handler ?accept self route f = + let tr_req oc req ~resp f = + let req = Request.read_body_full ~buf_size:self.buf_size req in + let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in + + (* send response once *) + let resp_sent = ref false in + let send_response_idempotent_ () = + if not !resp_sent then ( + resp_sent := true; + (* send 200 response now *) + let initial_resp = Response.make_void ~headers:!headers ~code:200 () in + resp initial_resp; + ) + in + + let send_event ?event ?id ?retry ~data () : unit = + send_response_idempotent_(); + _opt_iter event ~f:(fun e -> Printf.fprintf oc "event: %s\n" e); + _opt_iter id ~f:(fun e -> Printf.fprintf oc "id: %s\n" e); + _opt_iter retry ~f:(fun e -> Printf.fprintf oc "retry: %s\n" e); + let l = String.split_on_char '\n' data in + List.iter (fun s -> Printf.fprintf oc "data: %s\n" s) l; + output_string oc "\n"; (* finish group *) + flush oc + in + let module SSG = struct + let set_headers h = + if not !resp_sent then ( + headers := List.rev_append h !headers; + send_response_idempotent_() + ) + let send_event = send_event + let close () = raise Exit + 
end in + try f req (module SSG : SERVER_SENT_GENERATOR); + with Exit -> close_out oc + in + add_route_handler_ self ?accept ~meth:`GET route ~tr_req f + +let create + ?(masksigpipe=true) + ?(max_connections=32) + ?(timeout=0.0) + ?(buf_size=16 * 1_024) + ?(get_time_s=Unix.gettimeofday) + ?(new_thread=(fun f -> ignore (Thread.create f () : Thread.t))) + ?(addr="127.0.0.1") ?(port=8080) ?sock + ?(middlewares=[]) + () : t = + let handler _req = Response.fail ~code:404 "no top handler" in + let max_connections = max 4 max_connections in + let self = { + new_thread; addr; port; sock; masksigpipe; handler; buf_size; + running= true; sem_max_connections=Sem_.create max_connections; + path_handlers=[]; timeout; get_time_s; + middlewares=[]; middlewares_sorted=lazy []; + } in + List.iter (fun (stage,m) -> add_middleware self ~stage m) middlewares; + self + +let stop s = s.running <- false + +let find_map f l = + let rec aux f = function + | [] -> None + | x::l' -> + match f x with + | Some _ as res -> res + | None -> aux f l' + in aux f l + +let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = + Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); + Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); + let ic = Unix.in_channel_of_descr client_sock in + let oc = Unix.out_channel_of_descr client_sock in + let buf = Buf.create ~size:self.buf_size () in + let is = Byte_stream.of_chan ~buf_size:self.buf_size ic in + let continue = ref true in + while !continue && self.running do + _debug (fun k->k "read next request"); + match Request.parse_req_start ~get_time_s:self.get_time_s ~buf is with + | Ok None -> + continue := false (* client is done *) + + | Error (c,s) -> + (* connection error, close *) + let res = Response.make_raw ~code:c s in + begin + try Response.output_ oc res + with Sys_error _ -> () + end; + continue := false + + | Ok (Some req) -> + _debug (fun k->k "req: %s" (Format.asprintf "@[%a@]" Request.pp_ req)); + + if 
Request.close_after_req req then continue := false; + + try + (* is there a handler for this path? *) + let handler = + match find_map (fun ph -> ph req) self.path_handlers with + | Some f -> unwrap_resp_result f + | None -> + (fun _oc req ~resp -> + let body_str = Request.read_body_full ~buf_size:self.buf_size req in + resp (self.handler body_str)) + in + + (* handle expect/continue *) + begin match Request.get_header ~f:String.trim req "Expect" with + | Some "100-continue" -> + _debug (fun k->k "send back: 100 CONTINUE"); + Response.output_ oc (Response.make_raw ~code:100 ""); + | Some s -> bad_reqf 417 "unknown expectation %s" s + | None -> () + end; + + (* apply middlewares *) + let handler = + fun oc -> + List.fold_right (fun (_, m) h -> m h) + (Lazy.force self.middlewares_sorted) (handler oc) + in + + (* now actually read request's body into a stream *) + let req = + Request.parse_body_ ~tr_stream:(fun s->s) ~buf {req with body=is} + |> unwrap_resp_result + in + + (* how to reply *) + let resp r = + try + if Headers.get "connection" r.Response.headers = Some"close" then + continue := false; + Response.output_ oc r + with Sys_error _ -> continue := false + in + + (* call handler *) + begin + try handler oc req ~resp + with Sys_error _ -> continue := false + end + with + | Sys_error _ -> + continue := false; (* connection broken somehow *) + | Bad_req (code,s) -> + continue := false; + Response.output_ oc @@ Response.make_raw ~code s + | e -> + continue := false; + Response.output_ oc @@ Response.fail ~code:500 "server error: %s" (Printexc.to_string e) + done; + _debug (fun k->k "done with client, exiting"); + (try Unix.close client_sock + with e -> _debug (fun k->k "error when closing sock: %s" (Printexc.to_string e))); + () + +let is_ipv6 self = String.contains self.addr ':' + +let run (self:t) : (unit,_) result = + try + if self.masksigpipe then ( + ignore (Unix.sigprocmask Unix.SIG_BLOCK [Sys.sigpipe] : _ list); + ); + let sock, should_bind = match 
self.sock with + | Some s -> + s, false (* Because we're getting a socket from the caller (e.g. systemd) *) + | None -> + Unix.socket + (if is_ipv6 self then Unix.PF_INET6 else Unix.PF_INET) + Unix.SOCK_STREAM + 0, + true (* Because we're creating the socket ourselves *) + in + Unix.clear_nonblock sock; + Unix.setsockopt_optint sock Unix.SO_LINGER None; + begin if should_bind then + let inet_addr = Unix.inet_addr_of_string self.addr in + Unix.setsockopt sock Unix.SO_REUSEADDR true; + Unix.bind sock (Unix.ADDR_INET (inet_addr, self.port)); + Unix.listen sock (2 * self.sem_max_connections.Sem_.n) + end; + while self.running do + (* limit concurrency *) + Sem_.acquire 1 self.sem_max_connections; + try + let client_sock, _ = Unix.accept sock in + self.new_thread + (fun () -> + try + handle_client_ self client_sock; + Sem_.release 1 self.sem_max_connections; + with e -> + (try Unix.close client_sock with _ -> ()); + Sem_.release 1 self.sem_max_connections; + raise e + ); + with e -> + Sem_.release 1 self.sem_max_connections; + _debug (fun k -> k + "Unix.accept or Thread.create raised an exception: %s" + (Printexc.to_string e)) + done; + Ok () + with e -> Error e diff --git a/src/Tiny_httpd_server.mli b/src/Tiny_httpd_server.mli new file mode 100644 index 00000000..59eaacec --- /dev/null +++ b/src/Tiny_httpd_server.mli @@ -0,0 +1,567 @@ + +(** HTTP server. + + This library implements a very simple, basic HTTP/1.1 server using blocking + IOs and threads. + + It is possible to use a thread pool, see {!create}'s argument [new_thread]. +*) + +type buf = Tiny_httpd_buf.t +type byte_stream = Tiny_httpd_stream.t + +(** {2 Methods} *) + +module Meth : sig + type t = [ + | `GET + | `PUT + | `POST + | `HEAD + | `DELETE + ] + (** A HTTP method. + For now we only handle a subset of these. 
+ + See https://tools.ietf.org/html/rfc7231#section-4 *) + + val pp : Format.formatter -> t -> unit + val to_string : t -> string +end + +(** {2 Headers} + + Headers are metadata associated with a request or response. *) + +module Headers : sig + type t = (string * string) list + (** The header fields of a request or response. + + Neither the key nor the value can contain ['\r'] or ['\n']. + See https://tools.ietf.org/html/rfc7230#section-3.2 *) + + val empty : t + (** Empty list of headers + @since 0.5 *) + + val get : ?f:(string->string) -> string -> t -> string option + (** [get k headers] looks for the header field with key [k]. + @param f if provided, will transform the value before it is returned. *) + + val set : string -> string -> t -> t + (** [set k v headers] sets the key [k] to value [v]. + It erases any previous entry for [k] *) + + val remove : string -> t -> t + (** Remove the key from the headers, if present. *) + + val contains : string -> t -> bool + (** Is there a header with the given key? *) + + val pp : Format.formatter -> t -> unit + (** Pretty print the headers. *) +end + +(** {2 Requests} + + Requests are sent by a client, e.g. a web browser or cURL. *) + +module Request : sig + type 'body t = private { + meth: Meth.t; + host: string; + headers: Headers.t; + http_version: int*int; + path: string; + path_components: string list; + query: (string*string) list; + body: 'body; + start_time: float; + (** Obtained via [get_time_s] in {!create} + @since 0.11 *) + } + (** A request with method, path, host, headers, and a body, sent by a client. + + The body is polymorphic because the request goes through + several transformations. First it has no body, as only the request + and headers are read; then it has a stream body; then the body might be + entirely read as a string via {!read_body_full}. 
+ + @since 0.6 The field [query] was added and contains the query parameters in ["?foo=bar&x=y"] + @since 0.6 The field [path_components] is the part of the path that precedes [query] and is split on ["/"]. + @since 0.11 the type is a private alias + @since 0.11 the field [start_time] was added + *) + + val pp : Format.formatter -> string t -> unit + (** Pretty print the request and its body *) + + val pp_ : Format.formatter -> _ t -> unit + (** Pretty print the request without its body *) + + val headers : _ t -> Headers.t + (** List of headers of the request, including ["Host"] *) + + val get_header : ?f:(string->string) -> _ t -> string -> string option + + val get_header_int : _ t -> string -> int option + + val set_header : string -> string -> 'a t -> 'a t + (** [set_header k v req] sets [k: v] in the request [req]'s headers. *) + + val update_headers : (Headers.t -> Headers.t) -> 'a t -> 'a t + (** Modify headers + @since 0.11 *) + + val set_body : 'a -> _ t -> 'a t + (** [set_body b req] returns a new request whose body is [b]. + @since 0.11 *) + + val host : _ t -> string + (** Host field of the request. It also appears in the headers. *) + + val meth : _ t -> Meth.t + (** Method for the request. *) + + val path : _ t -> string + (** Request path. *) + + val query : _ t -> (string*string) list + (** Decode the query part of the {!path} field + @since 0.4 *) + + val body : 'b t -> 'b + (** Request body, possibly empty. *) + + val start_time : _ t -> float + (** time stamp (from {!Unix.gettimeofday}) after parsing the first line of the request + @since 0.11 *) + + val limit_body_size : max_size:int -> byte_stream t -> byte_stream t + (** Limit the body size to [max_size] bytes, or return + a [413] error. + @since 0.3 + *) + + val read_body_full : ?buf_size:int -> byte_stream t -> string t + (** Read the whole body into a string. Potentially blocking. 
+ + @param buf_size initial size of underlying buffer (since 0.11) *) + + (**/**) + (* for testing purposes, do not use *) + module Internal_ : sig + val parse_req_start : ?buf:buf -> get_time_s:(unit -> float) -> byte_stream -> unit t option + val parse_body : ?buf:buf -> unit t -> byte_stream -> byte_stream t + end + (**/**) +end + +(** {2 Response Codes} *) + +module Response_code : sig + type t = int + (** A standard HTTP code. + + https://tools.ietf.org/html/rfc7231#section-6 *) + + val ok : t + (** The code [200] *) + + val not_found : t + (** The code [404] *) + + val descr : t -> string + (** A description of some of the error codes. + NOTE: this is not complete (yet). *) +end + +(** {2 Responses} + + Responses are what an HTTP server, such as {!Tiny_httpd}, sends back to + the client to answer a {!Request.t} *) + +module Response : sig + type body = [`String of string | `Stream of byte_stream | `Void] + (** Body of a response, either as a simple string, + or a stream of bytes, or nothing (for server-sent events). *) + + type t = private { + code: Response_code.t; (** HTTP response code. See {!Response_code}. *) + headers: Headers.t; (** Headers of the reply. Some will be set by [Tiny_httpd] automatically. *) + body: body; (** Body of the response. Can be empty. *) + } + (** A response to send back to a client. *) + + val set_body : body -> t -> t + (** Set the body of the response. + @since 0.11 *) + + val set_header : string -> string -> t -> t + (** Set a header. + @since 0.11 *) + + val update_headers : (Headers.t -> Headers.t) -> t -> t + (** Modify headers + @since 0.11 *) + + val set_headers : Headers.t -> t -> t + (** Set all headers. + @since 0.11 *) + + val set_code : Response_code.t -> t -> t + (** Set the response code. + @since 0.11 *) + + val make_raw : + ?headers:Headers.t -> + code:Response_code.t -> + string -> + t + (** Make a response from its raw components, with a string body. + Use [""] to not send a body at all. 
*) + + val make_raw_stream : + ?headers:Headers.t -> + code:Response_code.t -> + byte_stream -> + t + (** Same as {!make_raw} but with a stream body. The body will be sent with + the chunked transfer-encoding. *) + + val make : + ?headers:Headers.t -> + (body, Response_code.t * string) result -> t + (** [make r] turns a result into a response. + + - [make (Ok body)] replies with [200] and the body. + - [make (Error (code,msg))] replies with the given error code + and message as body. + *) + + val make_string : + ?headers:Headers.t -> + (string, Response_code.t * string) result -> t + (** Same as {!make} but with a string body. *) + + val make_stream : + ?headers:Headers.t -> + (byte_stream, Response_code.t * string) result -> t + (** Same as {!make} but with a stream body. *) + + val fail : ?headers:Headers.t -> code:int -> + ('a, unit, string, t) format4 -> 'a + (** Make the current request fail with the given code and message. + Example: [fail ~code:404 "oh noes, %s not found" "waldo"]. + *) + + val fail_raise : code:int -> ('a, unit, string, 'b) format4 -> 'a + (** Similar to {!fail} but raises an exception that exits the current handler. + This should not be used outside of a (path) handler. + Example: [fail_raise ~code:404 "oh noes, %s not found" "waldo"; never_executed()] + *) + + val pp : Format.formatter -> t -> unit + (** Pretty print the response. *) +end + +(** {2 Routing} + + Basic type-safe routing. + @since 0.6 *) +module Route : sig + type ('a, 'b) comp + (** An atomic component of a path *) + + type ('a, 'b) t + (** A route, composed of path components *) + + val int : (int -> 'a, 'a) comp + (** Matches an integer. *) + + val string : (string -> 'a, 'a) comp + (** Matches a string not containing ['/'] and binds it as is. *) + + val string_urlencoded : (string -> 'a, 'a) comp + (** Matches a URL-encoded string, and decodes it. *) + + val exact : string -> ('a, 'a) comp + (** [exact "s"] matches ["s"] and nothing else. 
*) + + val return : ('a, 'a) t + (** Matches the empty path. *) + + val rest_of_path : (string -> 'a, 'a) t + (** Matches a string, even containing ['/']. This will match + the entirety of the remaining route. + @since 0.7 *) + + val rest_of_path_urlencoded : (string -> 'a, 'a) t + (** Matches a string, even containing ['/'], and URL-decodes it. + This will match the entirety of the remaining route. + @since 0.7 *) + + val (@/) : ('a, 'b) comp -> ('b, 'c) t -> ('a, 'c) t + (** [comp @/ route] matches ["foo/bar/…"] iff [comp] matches ["foo"], + and [route] matches ["bar/…"]. *) + + val exact_path : string -> ('a,'b) t -> ('a,'b) t + (** [exact_path "foo/bar/..." r] is equivalent to + [exact "foo" @/ exact "bar" @/ ... @/ r] + @since 0.11 *) + + val pp : Format.formatter -> _ t -> unit + (** Print the route. + @since 0.7 *) + + val to_string : _ t -> string + (** Render the route as a string. + @since 0.7 *) +end + +(** {2 Middlewares} + + A middleware can be inserted in a handler to modify or observe + its behavior. + + @since 0.11 +*) +module Middleware : sig + type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit + (** Handlers are functions returning a response to a request. + The response can be delayed, hence the use of a continuation + as the [resp] parameter. *) + + type t = handler -> handler + (** A middleware is a handler transformation. + + It takes the existing handler [h], + and returns a new one which, given a query, modifies it or logs it + before passing it to [h], or fails. It can also log or modify or drop + the response. *) + + val nil : t + (** Trivial middleware that does nothing. *) +end + +(** {2 Main Server type} *) + +type t +(** An HTTP server. See {!create} for more details.
*) + +val create : + ?masksigpipe:bool -> + ?max_connections:int -> + ?timeout:float -> + ?buf_size:int -> + ?get_time_s:(unit -> float) -> + ?new_thread:((unit -> unit) -> unit) -> + ?addr:string -> + ?port:int -> + ?sock:Unix.file_descr -> + ?middlewares:([`Encoding | `Stage of int] * Middleware.t) list -> + unit -> + t +(** Create a new webserver. + + The server will not do anything until {!run} is called on it. + Before starting the server, one can use {!add_route_handler} and + {!set_top_handler} to specify how to handle incoming requests. + + @param masksigpipe if true, block the signal {!Sys.sigpipe} which otherwise + tends to kill client threads when they try to write on broken sockets. Default: [true]. + + @param buf_size size for buffers (since 0.11) + + @param new_thread a function used to spawn a new thread to handle a + new client connection. By default it is {!Thread.create} but one + could use a thread pool instead. + + @param middlewares see {!add_middleware} for more details. + + @param max_connections maximum number of simultaneous connections. + @param timeout the connection is closed if the socket does not read or + write for the given number of seconds. Default: 0.0 which means no timeout. + A timeout is not recommended when using a proxy. + @param addr address (IPv4 or IPv6) to listen on. Default ["127.0.0.1"]. + @param port to listen on. Default [8080]. + @param sock an existing socket given to the server to listen on, e.g. by + systemd on Linux (or launchd on macOS). If passed in, this socket will be + used instead of the [addr] and [port]. If not passed in, those will be + used. This parameter exists since 0.10. + + @param get_time_s obtain the current timestamp in seconds. + This parameter exists since 0.11. +*) + +val addr : t -> string +(** Address on which the server listens. *) + +val is_ipv6 : t -> bool +(** [is_ipv6 server] returns [true] iff the address of the server is an IPv6 address.
+ @since 0.3 *) + +val port : t -> int +(** Port on which the server listens. *) + +val active_connections : t -> int +(** Number of active connections *) + +val add_decode_request_cb : + t -> + (unit Request.t -> (unit Request.t * (byte_stream -> byte_stream)) option) -> unit +[@@deprecated "use add_middleware"] +(** Add a callback for every request. + The callback can provide a stream transformer and a new request (with + modified headers, typically). + A possible use is to handle decompression by looking for a [Transfer-Encoding] + header and returning a stream transformer that decompresses on the fly. + + @deprecated use {!add_middleware} instead +*) + +val add_encode_response_cb: + t -> (unit Request.t -> Response.t -> Response.t option) -> unit +[@@deprecated "use add_middleware"] +(** Add a callback for every request/response pair. + Similarly to {!add_decode_request_cb} the callback can return a new + response, for example to compress it. + The callback is given the query with only its headers, + as well as the current response. + + @deprecated use {!add_middleware} instead +*) + +val add_middleware : + stage:[`Encoding | `Stage of int] -> + t -> Middleware.t -> unit +(** Add a middleware to every request/response pair. + @param stage specifies when the middleware applies. + Encoding comes first (outermost layer), then stages in increasing order. + @raise Invalid_argument if stage is [`Stage n] where [n < 1] + @since 0.11 +*) + +(** {2 Request handlers} *) + +val set_top_handler : t -> (string Request.t -> Response.t) -> unit +(** Setup a handler called by default. + + This handler is called with any request not accepted by any handler + installed via {!add_route_handler}. + If no top handler is installed, unhandled paths will return a [404] not found.
*) + +val add_route_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> + ?meth:Meth.t -> + t -> + ('a, string Request.t -> Response.t) Route.t -> 'a -> + unit +(** [add_route_handler server Route.(exact "path" @/ string @/ int @/ return) f] + calls [f "foo" 42 request] when a [request] with path "path/foo/42/" + is received. + + Note that the handlers are called in the reverse order of their addition, + so the last registered handler can override previously registered ones. + + @param meth if provided, only accept requests with the given method. + Typically one could react to [`GET] or [`PUT]. + @param accept should return [Ok()] if the given request (before its body + is read) should be accepted, [Error (code,message)] if it's to be rejected (e.g. because + its content is too big, or for some permission error). + See the [http_of_dir] program for an example of how to use [accept] to + filter uploads that are too large before the upload even starts. + The default always returns [Ok()], i.e. it accepts all requests. + + @since 0.6 +*) + +val add_route_handler_stream : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> + ?meth:Meth.t -> + t -> + ('a, byte_stream Request.t -> Response.t) Route.t -> 'a -> + unit +(** Similar to {!add_route_handler}, but where the body of the request + is a stream of bytes that has not been read yet. + This is useful when one wants to stream the body directly into a parser, + json decoder (such as [Jsonm]) or into a file. + @since 0.6 *) + +(** {2 Server-sent events} + + {b EXPERIMENTAL}: this API is not stable yet. *) + +(** A server-side function to generate Server-sent events. + + See {{: https://html.spec.whatwg.org/multipage/server-sent-events.html} the WHATWG specification} + and {{: https://jvns.ca/blog/2021/01/12/day-36--server-sent-events-are-cool--and-a-fun-bug/} + this blog post}.
+ + @since 0.9 + *) +module type SERVER_SENT_GENERATOR = sig + val set_headers : Headers.t -> unit + (** Set headers of the response. + This is not mandatory but if used at all, it must be called before + any call to {!send_event} (once events are sent the response is + already sent too). *) + + val send_event : + ?event:string -> + ?id:string -> + ?retry:string -> + data:string -> + unit -> unit + (** Send an event from the server. + If data is a multiline string, it will be sent on separate "data:" lines. *) + + val close : unit -> unit + (** Close connection. + @since 0.11 *) +end + +type server_sent_generator = (module SERVER_SENT_GENERATOR) +(** Server-sent event generator + @since 0.9 *) + +val add_route_server_sent_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + t -> + ('a, string Request.t -> server_sent_generator -> unit) Route.t -> 'a -> + unit +(** Add a handler on an endpoint, that serves server-sent events. + + The callback is given a generator that can be used to send events + as it pleases. The connection is always closed by the client, + and the accepted method is always [GET]. + This will set the header "content-type" to "text/event-stream" automatically + and reply with a 200 immediately. + See {!server_sent_generator} for more details. + + This handler stays on the original thread (it is synchronous). + + @since 0.9 *) + +(** {2 Run the server} *) + +val stop : t -> unit +(** Ask the server to stop. This might not have an immediate effect + as {!run} might currently be waiting on IO. *) + +val run : t -> (unit, exn) result +(** Run the main loop of the server, listening on a socket + described at the server's creation time, using [new_thread] to + start a thread for each new client. + + This returns [Ok ()] if the server exits gracefully, or [Error e] if + it exits with an error. 
*) + +(**/**) + +val _debug : ((('a, out_channel, unit, unit, unit, unit) format6 -> 'a) -> unit) -> unit +val _enable_debug: bool -> unit + +(**/**) diff --git a/src/Tiny_httpd_stream.ml b/src/Tiny_httpd_stream.ml new file mode 100644 index 00000000..75a5375c --- /dev/null +++ b/src/Tiny_httpd_stream.ml @@ -0,0 +1,306 @@ + +module Buf = Tiny_httpd_buf + +let spf = Printf.sprintf + +type hidden = unit +type t = { + mutable bs: bytes; + mutable off : int; + mutable len : int; + fill_buf: unit -> unit; + consume: int -> unit; + close: unit -> unit; + _rest: hidden; +} + +let[@inline] close self = self.close() + +let empty = { + bs=Bytes.empty; + off=0; + len=0; + fill_buf=ignore; + consume=ignore; + close=ignore; + _rest=(); +} + +let make ?(bs=Bytes.create @@ 16 * 1024) ?(close=ignore) ~consume ~fill () : t = + let rec self = { + bs; + off=0; + len=0; + close=(fun () -> close self); + fill_buf=(fun () -> fill self); + consume= + (fun n -> + assert (n <= self.len); + consume self n + ); + _rest=(); + } in + self + +let of_chan_ ?(buf_size=16 * 1024) ~close ic : t = + make + ~bs:(Bytes.create buf_size) + ~close:(fun _ -> close ic) + ~consume:(fun buf n -> buf.off <- buf.off + n) + ~fill:(fun self -> + if self.off >= self.len then ( + self.off <- 0; + self.len <- input ic self.bs 0 (Bytes.length self.bs); + ) + ) + () + +let of_chan = of_chan_ ~close:close_in +let of_chan_close_noerr = of_chan_ ~close:close_in_noerr + +let rec iter f (self:t) : unit = + self.fill_buf(); + if self.len=0 then ( + self.close(); + ) else ( + f self.bs self.off self.len; + self.consume self.len; + (iter [@tailcall]) f self + ) + +let to_chan (oc:out_channel) (self:t) = + iter (fun s i len -> output oc s i len) self + +let of_bytes ?(i=0) ?len (bs:bytes) : t = + (* invariant: !i+!len is constant *) + let len = + match len with + | Some n -> + if n > Bytes.length bs - i then invalid_arg "Byte_stream.of_bytes"; + n + | None -> Bytes.length bs - i + in + let self = + make + ~bs ~fill:ignore + 
~close:(fun self -> self.len <- 0) + ~consume:(fun self n -> + assert (n>=0 && n<= self.len); + self.off <- n + self.off; + self.len <- self.len - n + ) + () + in + self.off <- i; + self.len <- len; + self + +let of_string s : t = + of_bytes (Bytes.unsafe_of_string s) + +let with_file ?buf_size file f = + let ic = open_in file in + try + let x = f (of_chan ?buf_size ic) in + close_in ic; + x + with e -> + close_in_noerr ic; + raise e + +let read_all ?(buf=Buf.create()) (self:t) : string = + let continue = ref true in + while !continue do + self.fill_buf(); + if self.len > 0 then ( + Buf.add_bytes buf self.bs self.off self.len; + self.consume self.len; + ); + assert (self.len >= 0); + if self.len = 0 then ( + continue := false + ) + done; + Buf.contents_and_clear buf + +(* put [n] bytes from the input into bytes *) +let read_exactly_ ~too_short (self:t) (bytes:bytes) (n:int) : unit = + assert (Bytes.length bytes >= n); + let offset = ref 0 in + while !offset < n do + self.fill_buf(); + let n_read = min self.len (n - !offset) in + Bytes.blit self.bs self.off bytes !offset n_read; + offset := !offset + n_read; + self.consume n_read; + if n_read=0 then too_short(); + done + +(* read a line into the buffer, after clearing it. *) +let read_line_into (self:t) ~buf : unit = + Buf.clear buf; + let continue = ref true in + while !continue do + self.fill_buf(); + if self.len=0 then ( + continue := false; + if Buf.size buf = 0 then raise End_of_file; + ); + let j = ref self.off in + while !j < self.off + self.len && Bytes.get self.bs !j <> '\n' do + incr j + done; + if !j-self.off < self.len then ( + assert (Bytes.get self.bs !j = '\n'); + Buf.add_bytes buf self.bs self.off (!j-self.off); (* without \n *) + self.consume (!j-self.off+1); (* remove \n *) + continue := false + ) else ( + Buf.add_bytes buf self.bs self.off self.len; + self.consume self.len; + ) + done + +(* new stream with maximum size [max_size]. 
+ @param close_rec if true, closing this will also close the input stream + @param too_big called with read size if the max size is reached *) +let limit_size_to ~close_rec ~max_size ~too_big (self:t) : t = + let size = ref 0 in + let continue = ref true in + make + ~bs:Bytes.empty + ~close:(fun _ -> + if close_rec then self.close ()) + ~fill:(fun buf -> + if buf.len = 0 && !continue then ( + self.fill_buf(); + buf.bs <- self.bs; + buf.off <- self.off; + buf.len <- self.len; + ) else ( + self.bs <- Bytes.empty; + self.off <- 0; + self.len <- 0; + ) + ) + ~consume:(fun buf n -> + size := !size + n; + if !size > max_size then ( + continue := false; + too_big !size + ) else ( + self.consume n; + buf.len <- buf.len - n; + )) + () + +(* read exactly [size] bytes from the stream *) +let read_exactly ~close_rec ~size ~too_short (self:t) : t = + if size=0 then ( + empty + ) else ( + let size = ref size in + make ~bs:Bytes.empty + ~fill:(fun buf -> + (* must not block on [self] if we're done *) + if !size = 0 then ( + buf.bs <- Bytes.empty; + buf.off <- 0; + buf.len <- 0; + ) else ( + self.fill_buf(); + buf.bs <- self.bs; + buf.off <- self.off; + let len = min self.len !size in + if len = 0 && !size > 0 then ( + too_short !size; + ); + buf.len <- len; + )) + ~close:(fun _buf -> + (* close underlying stream if [close_rec] *) + if close_rec then self.close(); + size := 0 + ) + ~consume:(fun buf n -> + let n = min n !size in + size := !size - n; + buf.len <- buf.len - n; + self.consume n + ) + () + ) + +let read_line ?(buf=Buf.create()) self : string = + read_line_into self ~buf; + Buf.contents buf + +let read_chunked ?(buf=Buf.create()) ~fail (bs:t) : t= + let first = ref true in + let read_next_chunk_len () : int = + if !first then ( + first := false + ) else ( + let line = read_line ~buf bs in + if String.trim line <> "" then raise (fail "expected crlf between chunks";) + ); + let line = read_line ~buf bs in + (* parse chunk length, ignore extensions *) + let chunk_size = ( 
+ if String.trim line = "" then 0 + else + try Scanf.sscanf line "%x %s@\r" (fun n _ext -> n) + with _ -> raise (fail (spf "cannot read chunk size from line %S" line)) + ) in + chunk_size + in + let refill = ref true in + let chunk_size = ref 0 in + make + ~bs:(Bytes.create (16 * 4096)) + ~fill:(fun self -> + (* do we need to refill? *) + if self.off >= self.len then ( + if !chunk_size = 0 && !refill then ( + chunk_size := read_next_chunk_len(); + (* _debug (fun k->k"read next chunk of size %d" !chunk_size); *) + ); + self.off <- 0; + self.len <- 0; + if !chunk_size > 0 then ( + (* read the whole chunk, or [Bytes.length bytes] of it *) + let to_read = min !chunk_size (Bytes.length self.bs) in + read_exactly_ + ~too_short:(fun () -> raise (fail "chunk is too short")) + bs self.bs to_read; + self.len <- to_read; + chunk_size := !chunk_size - to_read; + ) else ( + refill := false; (* stream is finished *) + ) + ); + ) + ~consume:(fun self n -> self.off <- self.off + n) + ~close:(fun self -> + (* close this overlay, do not close underlying stream *) + self.len <- 0; + refill:= false + ) + () + +(* print a stream as a series of chunks *) +let output_chunked (oc:out_channel) (self:t) : unit = + let continue = ref true in + while !continue do + (* next chunk *) + self.fill_buf(); + let n = self.len in + Printf.fprintf oc "%x\r\n" n; + output oc self.bs self.off n; + self.consume n; + if n = 0 then ( + continue := false; + ); + output_string oc "\r\n"; + done; + () diff --git a/src/Tiny_httpd_stream.mli b/src/Tiny_httpd_stream.mli new file mode 100644 index 00000000..13ccb168 --- /dev/null +++ b/src/Tiny_httpd_stream.mli @@ -0,0 +1,113 @@ + +type hidden + +type t = { + mutable bs: bytes; + (** The bytes *) + + mutable off : int; + (** Beginning of valid slice in {!bs} *) + + mutable len : int; + (** Length of valid slice in {!bs}. If [len = 0] after + a call to {!fill}, then the stream is finished. 
*) + + fill_buf: unit -> unit; + (** Refill the buffer if needed, so that the current slice is + [bs[off] .. bs[off+len-1]]. + Can block to refill the buffer if there is currently no content. + If [len=0] then there is no more data. *) + + consume: int -> unit; + (** Consume [n] bytes from the buffer. + This should only be called with [n <= len]. *) + + close: unit -> unit; + (** Close the stream. *) + + _rest: hidden; + (** Use {!make} to build a stream. *) +} +(** A buffered stream, with a view into the current buffer (or refill if empty), + and a function to consume [n] bytes. + See {!Byte_stream} for more details. *) + +val close : t -> unit + +val empty : t + +val of_chan : ?buf_size:int -> in_channel -> t +(** Make a buffered stream from the given channel. *) + +val of_chan_close_noerr : ?buf_size:int -> in_channel -> t +(** Same as {!of_chan} but the [close] method will never fail. *) + +val of_bytes : ?i:int -> ?len:int -> bytes -> t +(** A stream that just returns the slice of bytes starting from [i] + and of length [len]. *) + +val of_string : string -> t + +val iter : (bytes -> int -> int -> unit) -> t -> unit +(** Iterate on the chunks of the stream + @since 0.3 *) + +val to_chan : out_channel -> t -> unit +(** Write the stream to the channel. + @since 0.3 *) + +val make : + ?bs:bytes -> + ?close:(t -> unit) -> + consume:(t -> int -> unit) -> + fill:(t -> unit) -> + unit -> t +(** [make ~consume ~fill ()] creates a byte stream. + @param fill is used to refill the buffer; it is invoked by each call to [fill_buf]. + @param close optional closing. + @param bs the underlying buffer (default: a fresh 16kB buffer). +*) + +val with_file : ?buf_size:int -> string -> (t -> 'a) -> 'a +(** Open a file with given name, and obtain an input stream + on its content. When the function returns, the stream (and file) are closed. *) + +val read_line : ?buf:Tiny_httpd_buf.t -> t -> string +(** Read a line from the stream. + @param buf a buffer to (re)use. Its content will be cleared.
*) + +val read_all : ?buf:Tiny_httpd_buf.t -> t -> string +(** Read the whole stream into a string. + @param buf a buffer to (re)use. Its content will be cleared. *) + +val limit_size_to : + close_rec:bool -> + max_size:int -> + too_big:(int -> unit) -> + t -> t +(** New stream with maximum size [max_size]. + @param close_rec if true, closing this will also close the input stream + @param too_big called with the total read size once it exceeds [max_size] *) + +val read_chunked : + ?buf:Tiny_httpd_buf.t -> + fail:(string -> exn) -> + t -> t +(** Decode a stream that uses the chunked encoding into a stream + of its payload bytes. The size of chunks is not specified. + @param buf buffer used for intermediate storage. + @param fail used to build an exception if reading fails. +*) + +val read_exactly : + close_rec:bool -> size:int -> too_short:(int -> unit) -> + t -> t +(** [read_exactly ~size bs] returns a new stream that reads exactly + [size] bytes from [bs], and then closes. + @param close_rec if true, closing the resulting stream also closes + [bs] + @param too_short is called if [bs] closes with still [n] bytes remaining +*) + +val output_chunked : out_channel -> t -> unit +(** Write the stream into the channel, using the chunked encoding.
*) diff --git a/src/camlzip/Tiny_httpd_camlzip.ml b/src/camlzip/Tiny_httpd_camlzip.ml index f9a394fc..35aa9569 100644 --- a/src/camlzip/Tiny_httpd_camlzip.ml +++ b/src/camlzip/Tiny_httpd_camlzip.ml @@ -1,124 +1,125 @@ -module S = Tiny_httpd -module BS = Tiny_httpd.Byte_stream +module S = Tiny_httpd_server +module BS = Tiny_httpd_stream let decode_deflate_stream_ ~buf_size (is:S.byte_stream) : S.byte_stream = S._debug (fun k->k "wrap stream with deflate.decode"); - let buf = Bytes.make buf_size ' ' in - let buf_len = ref 0 in - let write_offset = ref 0 in let zlib_str = Zlib.inflate_init false in let is_done = ref false in - let bs_close () = - Zlib.inflate_end zlib_str; - BS.close is - in - let bs_consume len : unit = - if len > !buf_len then ( - S.Response.fail_raise ~code:400 - "inflate: error during decompression: invalid consume len %d (max %d)" - len !buf_len - ); - write_offset := !write_offset + len; - in - let bs_fill_buf () : _*_*_ = - (* refill [buf] if needed *) - if !write_offset >= !buf_len && not !is_done then ( - let ib, ioff, ilen = is.S.bs_fill_buf () in - begin - try - let finished, used_in, used_out = - Zlib.inflate zlib_str - buf 0 (Bytes.length buf) - ib ioff ilen Zlib.Z_SYNC_FLUSH - in - is.S.bs_consume used_in; - write_offset := 0; - buf_len := used_out; - if finished then is_done := true; - S._debug (fun k->k "decode %d bytes as %d bytes from inflate (finished: %b)" - used_in used_out finished); - with Zlib.Error (e1,e2) -> + BS.make + ~bs:(Bytes.create buf_size) + ~close:(fun _ -> + Zlib.inflate_end zlib_str; + BS.close is + ) + ~consume:(fun self len -> + if len > self.len then ( S.Response.fail_raise ~code:400 - "inflate: error during decompression:\n%s %s" e1 e2 - end; - S._debug (fun k->k "inflate: refill %d bytes into internal buf" !buf_len); - ); - buf, !write_offset, !buf_len - !write_offset - in - {S.bs_fill_buf; bs_consume; bs_close} + "inflate: error during decompression: invalid consume len %d (max %d)" + len self.len + ); + 
self.off <- self.off + len; + self.len <- self.len - len; + ) + ~fill:(fun self -> + (* refill [buf] if needed *) + if self.len = 0 && not !is_done then ( + is.fill_buf(); + begin + try + let finished, used_in, used_out = + Zlib.inflate zlib_str + self.bs 0 (Bytes.length self.bs) + is.bs is.off is.len Zlib.Z_SYNC_FLUSH + in + is.consume used_in; + self.off <- 0; + self.len <- used_out; + if finished then is_done := true; + S._debug (fun k->k "decode %d bytes as %d bytes from inflate (finished: %b)" + used_in used_out finished); + with Zlib.Error (e1,e2) -> + S.Response.fail_raise ~code:400 + "inflate: error during decompression:\n%s %s" e1 e2 + end; + S._debug (fun k->k "inflate: refill %d bytes into internal buf" self.len); + ); + ) + () + +;; let encode_deflate_stream_ ~buf_size (is:S.byte_stream) : S.byte_stream = S._debug (fun k->k "wrap stream with deflate.encode"); let refill = ref true in - let buf = Bytes.make buf_size ' ' in - let buf_len = ref 0 in - let write_offset = ref 0 in let zlib_str = Zlib.deflate_init 4 false in - let bs_close () = - S._debug (fun k->k "deflate: close"); - Zlib.deflate_end zlib_str; - BS.close is - in - let bs_consume n = - write_offset := n + !write_offset - in - let bs_fill_buf () = - let rec loop() = - S._debug (fun k->k "deflate.fill.iter out_off=%d out_len=%d" - !write_offset !buf_len); - if !write_offset < !buf_len then ( - (* still the same slice, not consumed entirely by output *) - buf, !write_offset, !buf_len - !write_offset - ) else if not !refill then ( - (* empty slice, no refill *) - buf, !write_offset, !buf_len - !write_offset - ) else ( - (* the output was entirely consumed, we need to do more work *) - write_offset := 0; - buf_len := 0; - let in_s, in_i, in_len = is.S.bs_fill_buf () in - if in_len>0 then ( - (* try to decompress from input buffer *) - let _finished, used_in, used_out = - Zlib.deflate zlib_str - in_s in_i in_len - buf 0 (Bytes.length buf) - Zlib.Z_NO_FLUSH - in - buf_len := used_out; - 
is.S.bs_consume used_in; - S._debug - (fun k->k "encode %d bytes as %d bytes using deflate (finished: %b)" - used_in used_out _finished); - if _finished then ( - S._debug (fun k->k "deflate: finished"); - refill := false; - ); - loop() - ) else ( - (* finish sending the internal state *) - let _finished, used_in, used_out = - Zlib.deflate zlib_str - in_s in_i in_len - buf 0 (Bytes.length buf) - Zlib.Z_FULL_FLUSH - in - assert (used_in = 0); - buf_len := used_out; - if used_out = 0 then ( - refill := false; - ); - loop() - ) + BS.make + ~bs:(Bytes.create buf_size) + ~close:(fun _self -> + S._debug (fun k->k "deflate: close"); + Zlib.deflate_end zlib_str; + BS.close is ) - in - try loop() - with Zlib.Error (e1,e2) -> - S.Response.fail_raise ~code:400 - "deflate: error during compression:\n%s %s" e1 e2 - in - {S.bs_fill_buf; bs_consume; bs_close} + ~consume:(fun self n -> + self.off <- self.off + n; + self.len <- self.len - n + ) + ~fill:(fun self -> + let rec loop() = + S._debug (fun k->k "deflate.fill.iter out_off=%d out_len=%d" + self.off self.len); + if self.len > 0 then ( + () (* still the same slice, not consumed entirely by output *) + ) else if not !refill then ( + () (* empty slice, no refill *) + ) else ( + (* the output was entirely consumed, we need to do more work *) + is.BS.fill_buf(); + if is.len > 0 then ( + (* try to decompress from input buffer *) + let _finished, used_in, used_out = + Zlib.deflate zlib_str + is.bs is.off is.len + self.bs 0 (Bytes.length self.bs) + Zlib.Z_NO_FLUSH + in + self.off <- 0; + self.len <- used_out; + is.consume used_in; + S._debug + (fun k->k "encode %d bytes as %d bytes using deflate (finished: %b)" + used_in used_out _finished); + if _finished then ( + S._debug (fun k->k "deflate: finished"); + refill := false; + ); + loop() + ) else ( + (* [is] is done, finish sending the data in current buffer *) + let _finished, used_in, used_out = + Zlib.deflate zlib_str + is.bs is.off is.len + self.bs 0 (Bytes.length self.bs) + 
Zlib.Z_FULL_FLUSH + in + assert (used_in = 0); + self.off <- 0; + self.len <- used_out; + if used_out = 0 then ( + refill := false; + ); + loop() + ) + ) + in + try loop() + with Zlib.Error (e1,e2) -> + S.Response.fail_raise ~code:400 + "deflate: error during compression:\n%s %s" e1 e2 + ) + + () +;; let split_on_char ?(f=fun x->x) c s : string list = let rec loop acc i = @@ -184,7 +185,7 @@ let compress_resp_stream_ (fun k->k "encode str response with deflate (size %d, threshold %d)" (String.length s) compress_above); let body = - encode_deflate_stream_ ~buf_size @@ S.Byte_stream.of_string s + encode_deflate_stream_ ~buf_size @@ BS.of_string s in resp |> S.Response.update_headers update_headers diff --git a/src/camlzip/Tiny_httpd_camlzip.mli b/src/camlzip/Tiny_httpd_camlzip.mli index 9fc75267..52f17cd8 100644 --- a/src/camlzip/Tiny_httpd_camlzip.mli +++ b/src/camlzip/Tiny_httpd_camlzip.mli @@ -2,13 +2,13 @@ val middleware : ?compress_above:int -> ?buf_size:int -> unit -> - Tiny_httpd.Middleware.t + Tiny_httpd_server.Middleware.t (** Middleware responsible for deflate compression/decompression. @since 0.11 *) val setup : ?compress_above:int -> - ?buf_size:int -> Tiny_httpd.t -> unit + ?buf_size:int -> Tiny_httpd_server.t -> unit (** Install middleware for tiny_httpd to be able to encode/decode compressed streams @param compress_above threshold above with string responses are compressed