diff --git a/src/core/IO.ml b/src/core/IO.ml new file mode 100644 index 00000000..ba00b6ef --- /dev/null +++ b/src/core/IO.ml @@ -0,0 +1,486 @@ +(** IO abstraction. + + We abstract IO so we can support classic unix blocking IOs + with threads, and modern async IO with Eio. + + {b NOTE}: experimental. + + @since 0.14 +*) + +open Common_ +module Buf = Buf +module Slice = Iostream.Slice + +(** Output channel (byte sink) *) +module Output = struct + include Iostream.Out_buf + + class of_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) + (fd : Unix.file_descr) : t = + object + inherit t_from_output ~bytes:buf.bytes () + + method private output_underlying bs i len0 = + let i = ref i in + let len = ref len0 in + while !len > 0 do + match Unix.write fd bs !i !len with + | 0 -> failwith "write failed" + | n -> + i := !i + n; + len := !len - n + | exception + Unix.Unix_error + ( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN + | Unix.ECONNRESET | Unix.EPIPE ), + _, + _ ) -> + failwith "write failed" + | exception + Unix.Unix_error + ((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) -> + ignore (Unix.select [] [ fd ] [] 1.) + done + + method private close_underlying () = + if not !closed then ( + closed := true; + if close_noerr then ( + try Unix.close fd with _ -> () + ) else + Unix.close fd + ) + end + + let output_buf (self : t) (buf : Buf.t) : unit = + let b = Buf.bytes_slice buf in + output self b 0 (Buf.size buf) + + (** [chunk_encoding oc] makes a new channel that outputs its content into [oc] + in chunk encoding form. + @param close_rec if true, closing the result will also close [oc] + @param buf a buffer used to accumulate data into chunks. + Chunks are emitted when [buf]'s size gets over a certain threshold, + or when [flush] is called. + *) + let chunk_encoding ?(buf = Buf.create ()) ~close_rec (oc : #t) : t = + (* write content of [buf] as a chunk if it's big enough. + If [force=true] then write content of [buf] if it's simply non empty. *) + let write_buf ~force () = + let n = Buf.size buf in + if (force && n > 0) || n >= 4_096 then ( + output_string oc (Printf.sprintf "%x\r\n" n); + output oc (Buf.bytes_slice buf) 0 n; + output_string oc "\r\n"; + Buf.clear buf + ) + in + + object (self) + method flush () = + write_buf ~force:true (); + flush self + + method close () = + write_buf ~force:true (); + (* write an empty chunk to close the stream *) + output_string self "0\r\n"; + (* write another crlf after the stream (see #56) *) + output_string self "\r\n"; + self#flush (); + if close_rec then close oc + + method output b i n = + Buf.add_bytes buf b i n; + write_buf ~force:false () + + method output_char c = + Buf.add_char buf c; + write_buf ~force:false () + end +end + +(** Input channel (byte source) *) +module Input = struct + include Iostream.In_buf + + let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) + (fd : Unix.file_descr) : t = + let eof = ref false in + object + inherit Iostream.In_buf.t_from_refill ~buf () + + method private refill (slice : Slice.t) = + if not !eof then ( + slice.off <- 0; + let continue = ref true in + while !continue do + match Unix.read fd slice.bytes 0 (Bytes.length slice.bytes) with + | n -> + slice.len <- n; + continue := false + | exception + Unix.Unix_error + ( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN + | Unix.ECONNRESET | Unix.EPIPE ), + _, + _ ) -> + eof := true; + continue := false + | exception + Unix.Unix_error + ((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) -> + ignore (Unix.select [ fd ] [] [] 1.) + done; + (* Printf.eprintf "read returned %d B\n%!" !n; *) + if slice.len = 0 then eof := true + ) + + method! close () = + if not !closed then ( + closed := true; + eof := true; + if close_noerr then ( + try Unix.close fd with _ -> () + ) else + Unix.close fd + ) + end + + let of_slice (slice : Slice.t) : t = + object + inherit Iostream.In_buf.t_from_refill ~buf:slice () + + method private refill (slice : Slice.t) = + slice.off <- 0; + slice.len <- 0 + + method! close () = () + end + + (** Read into the given slice. + @return the number of bytes read, [0] means end of input. *) + let[@inline] input (self : t) buf i len = self#input buf i len + + (** Close the channel. *) + let[@inline] close self : unit = self#close () + + (** Read exactly [len] bytes. + @raise End_of_file if the input did not contain enough data. *) + let really_input (self : t) buf i len : unit = + let i = ref i in + let len = ref len in + while !len > 0 do + let n = input self buf !i !len in + if n = 0 then raise End_of_file; + i := !i + n; + len := !len - n + done + + let append (i1 : #t) (i2 : #t) : t = + let use_i1 = ref true in + let rec input_rec (slice : Slice.t) = + if !use_i1 then ( + slice.len <- input i1 slice.bytes 0 (Bytes.length slice.bytes); + if slice.len = 0 then ( + use_i1 := false; + input_rec slice + ) + ) else + slice.len <- input i1 slice.bytes 0 (Bytes.length slice.bytes) + in + + object + inherit Iostream.In_buf.t_from_refill () + + method private refill (slice : Slice.t) = + slice.off <- 0; + input_rec slice + + method! close () = + close i1; + close i2 + end + + let iter (f : Slice.t -> unit) (self : #t) : unit = + let continue = ref true in + while !continue do + let slice = self#fill_buf () in + if slice.len = 0 then ( + continue := false; + close self + ) else ( + f slice; + Slice.consume slice slice.len + ) + done + + let to_chan oc (self : #t) = + iter + (fun (slice : Slice.t) -> + Stdlib.output oc slice.bytes slice.off slice.len) + self + + let to_chan' (oc : #Iostream.Out_buf.t) (self : #t) : unit = + iter + (fun (slice : Slice.t) -> + Iostream.Out_buf.output oc slice.bytes slice.off slice.len) + self + + let read_all_using ~buf (self : #t) : string = + Buf.clear buf; + let continue = ref true in + while !continue do + let slice = self#fill_buf () in + if slice.len = 0 then + continue := false + else ( + assert (slice.len > 0); + Buf.add_bytes buf slice.bytes slice.off slice.len; + Slice.consume slice slice.len + ) + done; + Buf.contents_and_clear buf + + (** put [n] bytes from the input into bytes *) + let read_exactly_ ~too_short (self : #t) (bytes : bytes) (n : int) : unit = + assert (Bytes.length bytes >= n); + let offset = ref 0 in + while !offset < n do + let slice = self#fill_buf () in + let n_read = min slice.len (n - !offset) in + Bytes.blit slice.bytes slice.off bytes !offset n_read; + offset := !offset + n_read; + Slice.consume slice n_read; + if n_read = 0 then too_short () + done + + (** read a line into the buffer, after clearing it. *) + let read_line_into (self : t) ~buf : unit = + Buf.clear buf; + let continue = ref true in + while !continue do + let slice = self#fill_buf () in + if slice.len = 0 then ( + continue := false; + if Buf.size buf = 0 then raise End_of_file + ); + let j = ref slice.off in + while !j < slice.off + slice.len && Bytes.get slice.bytes !j <> '\n' do + incr j + done; + if !j - slice.off < slice.len then ( + assert (Bytes.get slice.bytes !j = '\n'); + (* line without '\n' *) + Buf.add_bytes buf slice.bytes slice.off (!j - slice.off); + (* consume line + '\n' *) + Slice.consume slice (!j - slice.off + 1); + continue := false + ) else ( + Buf.add_bytes buf slice.bytes slice.off slice.len; + Slice.consume slice slice.len + ) + done + + let read_line_using ~buf (self : #t) : string = + read_line_into self ~buf; + Buf.contents_and_clear buf + + let read_line_using_opt ~buf (self : #t) : string option = + match read_line_into self ~buf with + | () -> Some (Buf.contents_and_clear buf) + | exception End_of_file -> None + + (** new stream with maximum size [max_size]. + @param close_rec if true, closing this will also close the input stream *) + let limit_size_to ~close_rec ~max_size ~(bytes : bytes) (arg : t) : t = + let remaining_size = ref max_size in + let slice = Slice.of_bytes bytes in + + object + inherit Iostream.In_buf.t_from_refill ~buf:slice () + method! close () = if close_rec then close arg + + method private refill slice = + if slice.len = 0 then + if !remaining_size > 0 then ( + let sub = fill_buf arg in + let len = min sub.len !remaining_size in + + Bytes.blit sub.bytes sub.off slice.bytes 0 len; + slice.off <- 0; + slice.len <- len; + Slice.consume sub len + ) + end + + (** New stream that consumes exactly [size] bytes from the input. + If fewer bytes are read before [close] is called, we read and discard + the remaining quota of bytes before [close] returns. + @param close_rec if true, closing this will also close the input stream *) + let reading_exactly ~close_rec ~size ~(bytes : bytes) (arg : t) : t = + let remaining_size = ref size in + let slice = Slice.of_bytes bytes in + + object + inherit Iostream.In_buf.t_from_refill ~buf:slice () + + method! close () = + if !remaining_size > 0 then skip arg !remaining_size; + if close_rec then close arg + + method private refill slice = + if slice.len = 0 then + if !remaining_size > 0 then ( + let sub = fill_buf arg in + let len = min sub.len !remaining_size in + + Bytes.blit sub.bytes sub.off slice.bytes 0 len; + slice.off <- 0; + slice.len <- len; + Slice.consume sub len + ) + end + + let read_chunked ~(buf : Slice.t) ~fail (bs : #t) : t = + let first = ref true in + let line_buf = Buf.create ~size:128 () in + let read_next_chunk_len () : int = + if !first then + first := false + else ( + let line = read_line_using ~buf:line_buf bs in + if String.trim line <> "" then + raise (fail "expected crlf between chunks") + ); + let line = read_line_using ~buf:line_buf bs in + (* parse chunk length, ignore extensions *) + let chunk_size = + if String.trim line = "" then + 0 + else ( + try + let off = ref 0 in + let n = Parse_.pos_hex line off in + n + with _ -> + raise (fail (spf "cannot read chunk size from line %S" line)) + ) + in + chunk_size + in + let eof = ref false in + let chunk_size = ref 0 in + + object + inherit t_from_refill ~buf () + + method private refill (slice : Slice.t) : unit = + if !chunk_size = 0 && not !eof then chunk_size := read_next_chunk_len (); + slice.off <- 0; + slice.len <- 0; + if !chunk_size > 0 then ( + (* read the whole chunk, or [Bytes.length bytes] of it *) + let to_read = min !chunk_size (Bytes.length slice.bytes) in + read_exactly_ + ~too_short:(fun () -> raise (fail "chunk is too short")) + bs slice.bytes to_read; + slice.len <- to_read; + chunk_size := !chunk_size - to_read + ) else + (* stream is finished *) + eof := true + + method! close () = + (* do not close underlying stream *) + eof := true; + () + end + + (** Output a stream using chunked encoding *) + let output_chunked' ?buf (oc : #Iostream.Out_buf.t) (self : #t) : unit = + let oc' = Output.chunk_encoding ?buf oc ~close_rec:false in + match to_chan' oc' self with + | () -> Output.close oc' + | exception e -> + let bt = Printexc.get_raw_backtrace () in + Output.close oc'; + Printexc.raise_with_backtrace e bt + + (** print a stream as a series of chunks *) + let output_chunked ?buf (oc : out_channel) (self : #t) : unit = + output_chunked' ?buf (Output.of_out_channel oc) self +end + +(** A writer abstraction. *) +module Writer = struct + type t = { write: Output.t -> unit } [@@unboxed] + (** Writer. + + A writer is a push-based stream of bytes. + Give it an output channel and it will write the bytes in it. + + This is useful for responses: an http endpoint can return a writer + as its response's body; the writer is given access to the connection + to the client and can write into it as if it were a regular + [out_channel], including controlling calls to [flush]. + Tiny_httpd will convert these writes into valid HTTP chunks. + @since 0.14 + *) + + let[@inline] make ~write () : t = { write } + + (** Write into the channel. *) + let[@inline] write (oc : Output.t) (self : t) : unit = self.write oc + + (** Empty writer, will output 0 bytes. *) + let empty : t = { write = ignore } + + (** A writer that just emits the bytes from the given string. *) + let[@inline] of_string (str : string) : t = + let write oc = Output.output_string oc str in + { write } + + let[@inline] of_input (ic : #Input.t) : t = + { write = (fun oc -> Input.to_chan' oc ic) } +end + +(** A TCP server abstraction. *) +module TCP_server = struct + type conn_handler = { + handle: client_addr:Unix.sockaddr -> Input.t -> Output.t -> unit; + (** Handle client connection *) + } + + type t = { + endpoint: unit -> string * int; + (** Endpoint we listen on. This can only be called from within [serve]. *) + active_connections: unit -> int; + (** Number of connections currently active *) + running: unit -> bool; (** Is the server currently running? *) + stop: unit -> unit; + (** Ask the server to stop. This might not take effect immediately, + and is idempotent. After this [server.running()] must return [false]. *) + } + (** A running TCP server. + + This contains some functions that provide information about the running + server, including whether it's active (as opposed to stopped), a function + to stop it, and statistics about the number of connections. *) + + type builder = { + serve: after_init:(t -> unit) -> handle:conn_handler -> unit -> unit; + (** Blocking call to listen for incoming connections and handle them. + Uses the connection handler [handle] to handle individual client + connections in individual threads/fibers/tasks. + @param after_init is called once with the server after the server + has started. *) + } + (** A TCP server builder implementation. + + Calling [builder.serve ~after_init ~handle ()] starts a new TCP server on + an unspecified endpoint + (most likely coming from the function returning this builder) + and returns the running server. *) +end diff --git a/src/core/buf.ml b/src/core/buf.ml new file mode 100644 index 00000000..fcc89933 --- /dev/null +++ b/src/core/buf.ml @@ -0,0 +1,55 @@ +type t = { mutable bytes: bytes; mutable i: int; original: bytes } + +let create ?(size = 4_096) () : t = + let bytes = Bytes.make size ' ' in + { bytes; i = 0; original = bytes } + +let[@inline] size self = self.i +let[@inline] bytes_slice self = self.bytes + +let clear self : unit = + if + Bytes.length self.bytes > 500 * 1_024 + && Bytes.length self.bytes > Bytes.length self.original + then + (* free big buffer *) + self.bytes <- self.original; + self.i <- 0 + +let clear_and_zero self = + clear self; + Bytes.fill self.bytes 0 (Bytes.length self.bytes) '\x00' + +let resize self new_size : unit = + let new_buf = Bytes.make new_size ' ' in + Bytes.blit self.bytes 0 new_buf 0 self.i; + self.bytes <- new_buf + +let add_char self c : unit = + if self.i + 1 >= Bytes.length self.bytes then + resize self (self.i + (self.i / 2) + 10); + Bytes.set self.bytes self.i c; + self.i <- 1 + self.i + +let add_bytes (self : t) s i len : unit = + if self.i + len >= Bytes.length self.bytes then + resize self (self.i + (self.i / 2) + len + 10); + Bytes.blit s i self.bytes self.i len; + self.i <- self.i + len + +let[@inline] add_string self str : unit = + add_bytes self (Bytes.unsafe_of_string str) 0 (String.length str) + +let add_buffer (self : t) (buf : Buffer.t) : unit = + let len = Buffer.length buf in + if self.i + len >= Bytes.length self.bytes then + resize self (self.i + (self.i / 2) + len + 10); + Buffer.blit buf 0 self.bytes self.i len; + self.i <- self.i + len + +let contents (self : t) : string = Bytes.sub_string self.bytes 0 self.i + +let contents_and_clear (self : t) : string = + let x = contents self in + clear self; + x diff --git a/src/core/buf.mli b/src/core/buf.mli new file mode 100644 index 00000000..e5ca90c1 --- /dev/null +++ b/src/core/buf.mli @@ -0,0 +1,42 @@ +(** Simple buffer. + + These buffers are used to avoid allocating too many byte arrays when + processing streams and parsing requests. + + @since 0.12 +*) + +type t + +val size : t -> int +val clear : t -> unit +val create : ?size:int -> unit -> t +val contents : t -> string + +val clear_and_zero : t -> unit +(** Clear the buffer and zero out its storage. + @since 0.15 *) + +val bytes_slice : t -> bytes +(** Access underlying slice of bytes. + @since 0.5 *) + +val contents_and_clear : t -> string +(** Get contents of the buffer and clear it. + @since 0.5 *) + +val add_char : t -> char -> unit +(** Add a single char. + @since 0.14 *) + +val add_bytes : t -> bytes -> int -> int -> unit +(** Append given bytes slice to the buffer. + @since 0.5 *) + +val add_string : t -> string -> unit +(** Add string. + @since 0.14 *) + +val add_buffer : t -> Buffer.t -> unit +(** Append bytes from buffer. + @since 0.14 *) diff --git a/src/core/common_.ml b/src/core/common_.ml new file mode 100644 index 00000000..1058feea --- /dev/null +++ b/src/core/common_.ml @@ -0,0 +1,10 @@ +exception Bad_req of int * string + +let spf = Printf.sprintf +let bad_reqf c fmt = Printf.ksprintf (fun s -> raise (Bad_req (c, s))) fmt + +type 'a resp_result = ('a, int * string) result + +let unwrap_resp_result = function + | Ok x -> x + | Error (c, s) -> raise (Bad_req (c, s)) diff --git a/src/core/dune b/src/core/dune new file mode 100644 index 00000000..36bcf2d8 --- /dev/null +++ b/src/core/dune @@ -0,0 +1,21 @@ + +(library + (name tiny_httpd_core) + (public_name tiny_httpd.core) + (private_modules mime_ parse_) + (libraries threads seq hmap iostream + (select mime_.ml from + (magic-mime -> mime_.magic.ml) + ( -> mime_.dummy.ml)) + (select log.ml from + (logs -> log.logs.ml) + (-> log.default.ml)))) + +(rule + (targets Tiny_httpd_atomic_.ml) + (deps + (:bin ./gen/mkshims.exe)) + (action + (with-stdout-to + %{targets} + (run %{bin})))) diff --git a/src/core/gen/dune b/src/core/gen/dune new file mode 100644 index 00000000..cf54b00e --- /dev/null +++ b/src/core/gen/dune @@ -0,0 +1,2 @@ +(executables + (names mkshims)) diff --git a/src/core/gen/mkshims.ml b/src/core/gen/mkshims.ml new file mode 100644 index 00000000..a49f1ab7 --- /dev/null +++ b/src/core/gen/mkshims.ml @@ -0,0 +1,41 @@ +let atomic_before_412 = + {| + type 'a t = {mutable x: 'a} + let[@inline] make x = {x} + let[@inline] get {x} = x + let[@inline] set r x = r.x <- x + let[@inline] exchange r x = + let y = r.x in + r.x <- x; + y + + let[@inline] compare_and_set r seen v = + if r.x == seen then ( + r.x <- v; + true + ) else false + + let[@inline] fetch_and_add r x = + let v = r.x in + r.x <- x + r.x; + v + + let[@inline] incr r = r.x <- 1 + r.x + let[@inline] decr r = r.x <- r.x - 1 + |} + +let atomic_after_412 = {|include Atomic|} + +let write_file file s = + let oc = open_out file in + output_string oc s; + close_out oc + +let () = + let version = Scanf.sscanf Sys.ocaml_version "%d.%d.%s" (fun x y _ -> x, y) in + print_endline + (if version >= (4, 12) then + atomic_after_412 + else + atomic_before_412); + () diff --git a/src/core/headers.ml b/src/core/headers.ml new file mode 100644 index 00000000..a06a6439 --- /dev/null +++ b/src/core/headers.ml @@ -0,0 +1,70 @@ +open Common_ + +type t = (string * string) list + +let empty = [] + +let contains name headers = + let name' = String.lowercase_ascii name in + List.exists (fun (n, _) -> name' = n) headers + +let get_exn ?(f = fun x -> x) x h = + let x' = String.lowercase_ascii x in + List.assoc x' h |> f + +let get ?(f = fun x -> x) x h = + try Some (get_exn ~f x h) with Not_found -> None + +let remove x h = + let x' = String.lowercase_ascii x in + List.filter (fun (k, _) -> k <> x') h + +let set x y h = + let x' = String.lowercase_ascii x in + (x', y) :: List.filter (fun (k, _) -> k <> x') h + +let pp out l = + let pp_pair out (k, v) = Format.fprintf out "@[%s: %s@]" k v in + Format.fprintf out "@[%a@]" (Format.pp_print_list pp_pair) l + +(* token = 1*tchar + tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / "^" / "_" + / "`" / "|" / "~" / DIGIT / ALPHA ; any VCHAR, except delimiters + Reference: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 *) +let is_tchar = function + | '0' .. '9' + | 'a' .. 'z' + | 'A' .. 'Z' + | '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' + | '|' | '~' -> + true + | _ -> false + +let for_all pred s = + try + String.iter (fun c -> if not (pred c) then raise Exit) s; + true + with Exit -> false + +let parse_ ~(buf : Buf.t) (bs : IO.Input.t) : t = + let rec loop acc = + match IO.Input.read_line_using_opt ~buf bs with + | None -> raise End_of_file + | Some "\r" -> acc + | Some line -> + Log.debug (fun k -> k "parsed header line %S" line); + let k, v = + try + let i = String.index line ':' in + let k = String.sub line 0 i in + if not (for_all is_tchar k) then + invalid_arg (Printf.sprintf "Invalid header key: %S" k); + let v = + String.sub line (i + 1) (String.length line - i - 1) |> String.trim + in + k, v + with _ -> bad_reqf 400 "invalid header line: %S" line + in + loop ((String.lowercase_ascii k, v) :: acc) + in + loop [] diff --git a/src/core/headers.mli b/src/core/headers.mli new file mode 100644 index 00000000..b46b5d54 --- /dev/null +++ b/src/core/headers.mli @@ -0,0 +1,35 @@ +(** Headers + + Headers are metadata associated with a request or response. *) + +type t = (string * string) list +(** The header files of a request or response. + + Neither the key nor the value can contain ['\r'] or ['\n']. + See https://tools.ietf.org/html/rfc7230#section-3.2 *) + +val empty : t +(** Empty list of headers. + @since 0.5 *) + +val get : ?f:(string -> string) -> string -> t -> string option +(** [get k headers] looks for the header field with key [k]. + @param f if provided, will transform the value before it is returned. *) + +val get_exn : ?f:(string -> string) -> string -> t -> string +(** @raise Not_found *) + +val set : string -> string -> t -> t +(** [set k v headers] sets the key [k] to value [v]. + It erases any previous entry for [k] *) + +val remove : string -> t -> t +(** Remove the key from the headers, if present. *) + +val contains : string -> t -> bool +(** Is there a header with the given key? *) + +val pp : Format.formatter -> t -> unit +(** Pretty print the headers. *) + +val parse_ : buf:Buf.t -> IO.Input.t -> t diff --git a/src/core/log.default.ml b/src/core/log.default.ml new file mode 100644 index 00000000..5340578b --- /dev/null +++ b/src/core/log.default.ml @@ -0,0 +1,7 @@ +(* default: no logging *) + +let info _ = () +let debug _ = () +let error _ = () +let setup ~debug:_ () = () +let dummy = true diff --git a/src/core/log.logs.ml b/src/core/log.logs.ml new file mode 100644 index 00000000..f2cc8f56 --- /dev/null +++ b/src/core/log.logs.ml @@ -0,0 +1,22 @@ +(* Use Logs *) + +module Log = (val Logs.(src_log @@ Src.create "tiny_httpd")) + +let info k = Log.info (fun fmt -> k (fun x -> fmt ?header:None ?tags:None x)) +let debug k = Log.debug (fun fmt -> k (fun x -> fmt ?header:None ?tags:None x)) +let error k = Log.err (fun fmt -> k (fun x -> fmt ?header:None ?tags:None x)) + +let setup ~debug () = + let mutex = Mutex.create () in + Logs.set_reporter_mutex + ~lock:(fun () -> Mutex.lock mutex) + ~unlock:(fun () -> Mutex.unlock mutex); + Logs.set_reporter @@ Logs.format_reporter (); + Logs.set_level ~all:true + (Some + (if debug then + Logs.Debug + else + Logs.Info)) + +let dummy = false diff --git a/src/core/log.mli b/src/core/log.mli new file mode 100644 index 00000000..5944e125 --- /dev/null +++ b/src/core/log.mli @@ -0,0 +1,12 @@ +(** Logging for tiny_httpd *) + +val info : ((('a, Format.formatter, unit, unit) format4 -> 'a) -> unit) -> unit +val debug : ((('a, Format.formatter, unit, unit) format4 -> 'a) -> unit) -> unit +val error : ((('a, Format.formatter, unit, unit) format4 -> 'a) -> unit) -> unit + +val setup : debug:bool -> unit -> unit +(** Setup and enable logging. This should only ever be used in executables, + not libraries. + @param debug if true, set logging to debug (otherwise info) *) + +val dummy : bool diff --git a/src/core/meth.ml b/src/core/meth.ml new file mode 100644 index 00000000..94e6bb3a --- /dev/null +++ b/src/core/meth.ml @@ -0,0 +1,22 @@ +open Common_ + +type t = [ `GET | `PUT | `POST | `HEAD | `DELETE | `OPTIONS ] + +let to_string = function + | `GET -> "GET" + | `PUT -> "PUT" + | `HEAD -> "HEAD" + | `POST -> "POST" + | `DELETE -> "DELETE" + | `OPTIONS -> "OPTIONS" + +let pp out s = Format.pp_print_string out (to_string s) + +let of_string = function + | "GET" -> `GET + | "PUT" -> `PUT + | "POST" -> `POST + | "HEAD" -> `HEAD + | "DELETE" -> `DELETE + | "OPTIONS" -> `OPTIONS + | s -> bad_reqf 400 "unknown method %S" s diff --git a/src/core/meth.mli b/src/core/meth.mli new file mode 100644 index 00000000..76b2c942 --- /dev/null +++ b/src/core/meth.mli @@ -0,0 +1,11 @@ +(** HTTP Methods *) + +type t = [ `GET | `PUT | `POST | `HEAD | `DELETE | `OPTIONS ] +(** A HTTP method. + For now we only handle a subset of these. + + See https://tools.ietf.org/html/rfc7231#section-4 *) + +val pp : Format.formatter -> t -> unit +val to_string : t -> string +val of_string : string -> t diff --git a/src/core/mime_.dummy.ml b/src/core/mime_.dummy.ml new file mode 100644 index 00000000..dc944b1c --- /dev/null +++ b/src/core/mime_.dummy.ml @@ -0,0 +1 @@ +let mime_of_path _ = "application/octet-stream" diff --git a/src/core/mime_.magic.ml b/src/core/mime_.magic.ml new file mode 100644 index 00000000..72fcd345 --- /dev/null +++ b/src/core/mime_.magic.ml @@ -0,0 +1 @@ +let mime_of_path s = Magic_mime.lookup s diff --git a/src/core/mime_.mli b/src/core/mime_.mli new file mode 100644 index 00000000..1831c02d --- /dev/null +++ b/src/core/mime_.mli @@ -0,0 +1,2 @@ + +val mime_of_path : string -> string diff --git a/src/core/parse_.ml b/src/core/parse_.ml new file mode 100644 index 00000000..39430889 --- /dev/null +++ b/src/core/parse_.ml @@ -0,0 +1,77 @@ +(** Basic parser for lines *) + +type 'a t = string -> int ref -> 'a + +open struct + let spf = Printf.sprintf +end + +let[@inline] eof s off = !off = String.length s + +let[@inline] skip_space : unit t = + fun s off -> + while !off < String.length s && String.unsafe_get s !off = ' ' do + incr off + done + +let pos_int : int t = + fun s off : int -> + skip_space s off; + let n = ref 0 in + let continue = ref true in + while !off < String.length s && !continue do + match String.unsafe_get s !off with + | '0' .. '9' as c -> n := (!n * 10) + Char.code c - Char.code '0' + | ' ' | '\t' | '\n' -> continue := false + | c -> failwith @@ spf "expected int, got %C" c + done; + !n + +let pos_hex : int t = + fun s off : int -> + skip_space s off; + let n = ref 0 in + let continue = ref true in + while !off < String.length s && !continue do + match String.unsafe_get s !off with + | 'a' .. 'f' as c -> + incr off; + n := (!n * 16) + Char.code c - Char.code 'a' + 10 + | 'A' .. 'F' as c -> + incr off; + n := (!n * 16) + Char.code c - Char.code 'A' + 10 + | '0' .. '9' as c -> + incr off; + n := (!n * 16) + Char.code c - Char.code '0' + | ' ' | '\r' -> continue := false + | c -> failwith @@ spf "expected int, got %C" c + done; + !n + +(** Parse a word without spaces *) +let word : string t = + fun s off -> + skip_space s off; + let start = !off in + let continue = ref true in + while !off < String.length s && !continue do + match String.unsafe_get s !off with + | ' ' | '\r' -> continue := false + | _ -> incr off + done; + if !off = start then failwith "expected word"; + String.sub s start (!off - start) + +let exact str : unit t = + fun s off -> + skip_space s off; + let len = String.length str in + if !off + len > String.length s then + failwith @@ spf "unexpected EOF, expected %S" str; + for i = 0 to len - 1 do + let expected = String.unsafe_get str i in + let c = String.unsafe_get s (!off + i) in + if c <> expected then + failwith @@ spf "expected %S, got %C at position %d" str c i + done; + off := !off + len diff --git a/src/core/pool.ml b/src/core/pool.ml new file mode 100644 index 00000000..1a441944 --- /dev/null +++ b/src/core/pool.ml @@ -0,0 +1,51 @@ +module A = Tiny_httpd_atomic_ + +type 'a list_ = Nil | Cons of int * 'a * 'a list_ + +type 'a t = { + mk_item: unit -> 'a; + clear: 'a -> unit; + max_size: int; (** Max number of items *) + items: 'a list_ A.t; +} + +let create ?(clear = ignore) ~mk_item ?(max_size = 512) () : _ t = + { mk_item; clear; max_size; items = A.make Nil } + +let rec acquire_ self = + match A.get self.items with + | Nil -> self.mk_item () + | Cons (_, x, tl) as l -> + if A.compare_and_set self.items l tl then + x + else + acquire_ self + +let[@inline] size_ = function + | Cons (sz, _, _) -> sz + | Nil -> 0 + +let release_ self x : unit = + let rec loop () = + match A.get self.items with + | Cons (sz, _, _) when sz >= self.max_size -> + (* forget the item *) + () + | l -> + if not (A.compare_and_set self.items l (Cons (size_ l + 1, x, l))) then + loop () + in + + self.clear x; + loop () + +let with_resource (self : _ t) f = + let x = acquire_ self in + try + let res = f x in + release_ self x; + res + with e -> + let bt = Printexc.get_raw_backtrace () in + release_ self x; + Printexc.raise_with_backtrace e bt diff --git a/src/core/pool.mli b/src/core/pool.mli new file mode 100644 index 00000000..a2418e11 --- /dev/null +++ b/src/core/pool.mli @@ -0,0 +1,25 @@ +(** Resource pool. + + This pool is used for buffers. It can be used for other resources + but do note that it assumes resources are still reasonably + cheap to produce and discard, and will never block waiting for + a resource — it's not a good pool for DB connections. + + @since 0.14. *) + +type 'a t +(** Pool of values of type ['a] *) + +val create : + ?clear:('a -> unit) -> mk_item:(unit -> 'a) -> ?max_size:int -> unit -> 'a t +(** Create a new pool. + @param mk_item produce a new item in case the pool is empty + @param max_size maximum number of item in the pool before we start + dropping resources on the floor. This controls resource consumption. + @param clear a function called on items before recycling them. + *) + +val with_resource : 'a t -> ('a -> 'b) -> 'b +(** [with_resource pool f] runs [f x] with [x] a resource; + when [f] fails or returns, [x] is returned to the pool for + future reuse. *) diff --git a/src/core/request.ml b/src/core/request.ml new file mode 100644 index 00000000..1a5c19b7 --- /dev/null +++ b/src/core/request.ml @@ -0,0 +1,204 @@ +open Common_ + +type 'body t = { + meth: Meth.t; + host: string; + client_addr: Unix.sockaddr; + headers: Headers.t; + mutable meta: Hmap.t; + http_version: int * int; + path: string; + path_components: string list; + query: (string * string) list; + body: 'body; + start_time: float; +} + +let headers self = self.headers +let host self = self.host +let client_addr self = self.client_addr +let meth self = self.meth +let path self = self.path +let body self = self.body +let start_time self = self.start_time +let query self = self.query +let get_header ?f self h = Headers.get ?f h self.headers +let remove_header k self = { self with headers = Headers.remove k self.headers } + +let get_header_int self h = + match get_header self h with + | Some x -> (try Some (int_of_string x) with _ -> None) + | None -> None + +let set_header k v self = { self with headers = Headers.set k v self.headers } +let update_headers f self = { self with headers = f self.headers } +let set_body b self = { self with body = b } + +(** Should we close the connection after this request? *) +let close_after_req (self : _ t) : bool = + match self.http_version with + | 1, 1 -> get_header self "connection" = Some "close" + | 1, 0 -> not (get_header self "connection" = Some "keep-alive") + | _ -> false + +let pp_comp_ out comp = + Format.fprintf out "[%s]" + (String.concat ";" @@ List.map (Printf.sprintf "%S") comp) + +let pp_query out q = + Format.fprintf out "[%s]" + (String.concat ";" @@ List.map (fun (a, b) -> Printf.sprintf "%S,%S" a b) q) + +let pp_ out self : unit = + Format.fprintf out + "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ path=%S;@ body=?;@ \ + path_components=%a;@ query=%a@]}" + (Meth.to_string self.meth) self.host Headers.pp self.headers self.path + pp_comp_ self.path_components pp_query self.query + +let pp out self : unit = + Format.fprintf out + "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ path=%S;@ body=%S;@ \ + path_components=%a;@ query=%a@]}" + (Meth.to_string self.meth) self.host Headers.pp self.headers self.path + self.body pp_comp_ self.path_components pp_query self.query + +(* decode a "chunked" stream into a normal stream *) +let read_stream_chunked_ ~buf (bs : #IO.Input.t) : IO.Input.t = + Log.debug (fun k -> k "body: start reading chunked stream..."); + IO.Input.read_chunked ~buf ~fail:(fun s -> Bad_req (400, s)) bs + +let limit_body_size_ ~max_size ~bytes (bs : #IO.Input.t) : IO.Input.t = + Log.debug (fun k -> k "limit size of body to max-size=%d" max_size); + IO.Input.limit_size_to ~max_size ~close_rec:false ~bytes bs + +let limit_body_size ~max_size ?(bytes = Bytes.create 4096) (req : IO.Input.t t) + : IO.Input.t t = + { req with body = limit_body_size_ ~max_size ~bytes req.body } + +(** read exactly [size] bytes from the stream *) +let read_exactly ~size ~bytes (bs : #IO.Input.t) : IO.Input.t = + Log.debug (fun k -> k "body: must read exactly %d bytes" size); + IO.Input.reading_exactly bs ~close_rec:false ~size ~bytes + +(* parse request, but not body (yet) *) +let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) : + unit t option resp_result = + try + let line = IO.Input.read_line_using ~buf bs in + let start_time = get_time_s () in + let meth, path, version = + try + let off = ref 0 in + let meth = Parse_.word line off in + let path = Parse_.word line off in + let http_version = Parse_.word line off in + let version = + match http_version with + | "HTTP/1.1" -> 1 + | "HTTP/1.0" -> 0 + | v -> invalid_arg (spf "unsupported HTTP version: %s" v) + in + meth, path, version + with + | Invalid_argument msg -> + Log.error (fun k -> k "invalid request line: `%s`: %s" line msg); + raise (Bad_req (400, "Invalid request line")) + | _ -> + Log.error (fun k -> k "invalid request line: `%s`" line); + raise (Bad_req (400, "Invalid request line")) + in + let meth = Meth.of_string meth in + Log.debug (fun k -> k "got meth: %s, path %S" (Meth.to_string meth) path); + let headers = Headers.parse_ ~buf bs in + let host = + match Headers.get "Host" headers with + | None -> bad_reqf 400 "No 'Host' header in request" + | Some h -> h + in + let path_components, query = Util.split_query path in + let path_components = Util.split_on_slash path_components in + let query = + match Util.parse_query query with + | Ok l -> l + | Error e -> bad_reqf 400 "invalid query: %s" e + in + let req = + { + meth; + query; + host; + meta = Hmap.empty; + client_addr; + path; + path_components; + headers; + http_version = 1, version; + body = (); + start_time; + } + in + Ok (Some req) + with + | End_of_file | Sys_error _ | Unix.Unix_error _ -> Ok None + | Bad_req (c, s) -> Error (c, s) + | e -> Error (400, Printexc.to_string e) + +(* parse body, given the headers. + @param tr_stream a transformation of the input stream. *) +let parse_body_ ~tr_stream ~buf (req : IO.Input.t t) : IO.Input.t t resp_result + = + try + let size = + match Headers.get_exn "Content-Length" req.headers |> int_of_string with + | n -> n (* body of fixed size *) + | exception Not_found -> 0 + | exception _ -> bad_reqf 400 "invalid content-length" + in + let body = + match get_header ~f:String.trim req "Transfer-Encoding" with + | None -> + let bytes = Bytes.create 4096 in + read_exactly ~size ~bytes @@ tr_stream req.body + | Some "chunked" -> + (* body sent by chunks *) + let bs : IO.Input.t = read_stream_chunked_ ~buf @@ tr_stream req.body in + if size > 0 then ( + let bytes = Bytes.create 4096 in + limit_body_size_ ~max_size:size ~bytes bs + ) else + bs + | Some s -> bad_reqf 500 "cannot handle transfer encoding: %s" s + in + Ok { req with body } + with + | End_of_file -> Error (400, "unexpected end of file") + | Bad_req (c, s) -> Error (c, s) + | e -> Error (400, Printexc.to_string e) + +let read_body_full ?buf ?buf_size (self : IO.Input.t t) : string t = + try + let buf = + match buf with + | Some b -> b + | None -> Buf.create ?size:buf_size () + in + let body = IO.Input.read_all_using ~buf self.body in + { self with body } + with + | Bad_req _ as e -> raise e + | e -> bad_reqf 500 "failed to read body: %s" (Printexc.to_string e) + +module Private_ = struct + let close_after_req = close_after_req + let parse_req_start = parse_req_start + + let parse_req_start_exn ?(buf = Buf.create ()) ~client_addr ~get_time_s bs = + parse_req_start ~client_addr ~get_time_s ~buf bs |> unwrap_resp_result + + let parse_body ?(buf = IO.Slice.create 4096) req bs : _ t = + parse_body_ ~tr_stream:(fun s -> s) ~buf { req with body = bs } + |> unwrap_resp_result + + let[@inline] set_body body self = { self with body } +end diff --git a/src/core/request.mli b/src/core/request.mli new file mode 100644 index 00000000..cf4f8368 --- /dev/null +++ b/src/core/request.mli @@ -0,0 +1,135 @@ +(** Requests + + Requests are sent by a client, e.g. a web browser or cURL. + From the point of view of the server, they're inputs. *) + +open Common_ + +type 'body t = private { + meth: Meth.t; (** HTTP method for this request. *) + host: string; + (** Host header, mandatory. It can also be found in {!headers}. *) + client_addr: Unix.sockaddr; (** Client address. Available since 0.14. *) + headers: Headers.t; (** List of headers. *) + mutable meta: Hmap.t; (** Metadata. @since NEXT_RELEASE *) + http_version: int * int; + (** HTTP version. This should be either [1, 0] or [1, 1]. *) + path: string; (** Full path of the requested URL. *) + path_components: string list; + (** Components of the path of the requested URL. *) + query: (string * string) list; (** Query part of the requested URL. *) + body: 'body; (** Body of the request. *) + start_time: float; + (** Obtained via [get_time_s] in {!create} + @since 0.11 *) +} +(** A request with method, path, host, headers, and a body, sent by a client. + + The body is polymorphic because the request goes through + several transformations. First it has no body, as only the request + and headers are read; then it has a stream body; then the body might be + entirely read as a string via {!read_body_full}. + + @since 0.6 The field [query] was added and contains the query parameters in ["?foo=bar,x=y"] + @since 0.6 The field [path_components] is the part of the path that precedes [query] and is split on ["/"]. + @since 0.11 the type is a private alias + @since 0.11 the field [start_time] was added + *) + +val pp : Format.formatter -> string t -> unit +(** Pretty print the request and its body. The exact format of this printing + is not specified. *) + +val pp_ : Format.formatter -> _ t -> unit +(** Pretty print the request without its body. The exact format of this printing + is not specified. *) + +val headers : _ t -> Headers.t +(** List of headers of the request, including ["Host"]. *) + +val get_header : ?f:(string -> string) -> _ t -> string -> string option +(** [get_header req h] looks up header [h] in [req]. It returns [None] if the + header is not present. This is case insensitive and should be used + rather than looking up [h] verbatim in [headers]. *) + +val get_header_int : _ t -> string -> int option +(** Same as {!get_header} but also performs a string to integer conversion. *) + +val set_header : string -> string -> 'a t -> 'a t +(** [set_header k v req] sets [k: v] in the request [req]'s headers. *) + +val remove_header : string -> 'a t -> 'a t +(** Remove one instance of this header. + @since NEXT_RELEASE *) + +val update_headers : (Headers.t -> Headers.t) -> 'a t -> 'a t +(** Modify headers using the given function. + @since 0.11 *) + +val set_body : 'a -> _ t -> 'a t +(** [set_body b req] returns a new query whose body is [b]. + @since 0.11 *) + +val host : _ t -> string +(** Host field of the request. It also appears in the headers. *) + +val client_addr : _ t -> Unix.sockaddr +(** Client address of the request. + @since 0.16 *) + +val meth : _ t -> Meth.t +(** Method for the request. *) + +val path : _ t -> string +(** Request path. *) + +val query : _ t -> (string * string) list +(** Decode the query part of the {!path} field. + @since 0.4 *) + +val body : 'b t -> 'b +(** Request body, possibly empty. *) + +val start_time : _ t -> float +(** time stamp (from {!Unix.gettimeofday}) after parsing the first line of the request + @since 0.11 *) + +val limit_body_size : + max_size:int -> ?bytes:bytes -> IO.Input.t t -> IO.Input.t t +(** Limit the body size to [max_size] bytes, or return + a [413] error. + @param bytes intermediate buffer + @since 0.3 + *) + +val read_body_full : ?buf:Buf.t -> ?buf_size:int -> IO.Input.t t -> string t +(** Read the whole body into a string. Potentially blocking. + + @param buf_size initial size of underlying buffer (since 0.11) + @param buf the initial buffer (since 0.14) + *) + +(**/**) + +(* for internal usage, do not use. There is no guarantee of stability. *) +module Private_ : sig + val parse_req_start : + client_addr:Unix.sockaddr -> + get_time_s:(unit -> float) -> + buf:Buf.t -> + IO.Input.t -> + unit t option resp_result + + val parse_req_start_exn : + ?buf:Buf.t -> + client_addr:Unix.sockaddr -> + get_time_s:(unit -> float) -> + IO.Input.t -> + unit t option + + val close_after_req : _ t -> bool + val parse_body : ?buf:IO.Slice.t -> unit t -> IO.Input.t -> IO.Input.t t + val set_body : 'a -> _ t -> 'a t +end + +(**/**) diff --git a/src/core/response.ml b/src/core/response.ml new file mode 100644 index 00000000..1bd7af66 --- /dev/null +++ b/src/core/response.ml @@ -0,0 +1,162 @@ +open Common_ + +type body = + [ `String of string | `Stream of IO.Input.t | `Writer of IO.Writer.t | `Void ] + +type t = { code: Response_code.t; headers: Headers.t; body: body } + +let set_body body self = { self with body } +let set_headers headers self = { self with headers } +let update_headers f self = { self with headers = f self.headers } +let set_header k v self = { self with headers = Headers.set k v self.headers } +let remove_header k self = { self with headers = Headers.remove k self.headers } +let set_code code self = { self with code } + +let make_raw ?(headers = []) ~code body : t = + (* add content length to response *) + let headers = + Headers.set "Content-Length" (string_of_int (String.length body)) headers + in + { code; headers; body = `String body } + +let make_raw_stream ?(headers = []) ~code body : t = + let headers = Headers.set "Transfer-Encoding" "chunked" headers in + { code; headers; body = `Stream body } + +let make_raw_writer ?(headers = []) ~code body : t = + let headers = Headers.set "Transfer-Encoding" "chunked" headers in + { code; headers; body = `Writer body } + +let make_void_force_ ?(headers = []) ~code () : t = + { code; headers; body = `Void } + +let make_void ?(headers = []) ~code () : t = + let is_ok = code < 200 || code = 204 || code = 304 in + if is_ok then + make_void_force_ ~headers ~code () + else + make_raw ~headers ~code "" (* invalid to not have a body *) + +let make_string ?headers ?(code = 200) r = + match r with + | Ok body -> make_raw ?headers ~code body + | Error (code, msg) -> make_raw ?headers ~code msg + +let make_stream ?headers ?(code = 200) r = + match r with + | Ok body -> make_raw_stream ?headers ~code body + | Error (code, msg) -> make_raw ?headers ~code msg + +let make_writer ?headers ?(code = 200) r : t = + match r with + | Ok body -> make_raw_writer ?headers ~code body + | Error (code, msg) -> make_raw ?headers ~code msg + +let make ?headers ?(code = 200) r : t = + match r with + | Ok (`String body) -> make_raw ?headers ~code body + | Ok (`Stream body) -> make_raw_stream ?headers ~code body + | Ok `Void -> make_void ?headers ~code () + | Ok (`Writer f) -> make_raw_writer ?headers ~code f + | Error (code, msg) -> make_raw ?headers ~code msg + +let fail ?headers ~code fmt = + Printf.ksprintf (fun msg -> make_raw ?headers ~code msg) fmt + +let fail_raise ~code fmt = + Printf.ksprintf (fun msg -> raise (Bad_req (code, msg))) fmt + +let pp out self : unit = + let pp_body out = function + | `String s -> Format.fprintf out "%S" s + | `Stream _ -> Format.pp_print_string out "" + | `Writer _ -> Format.pp_print_string out "" + | `Void -> () + in + Format.fprintf out "{@[code=%d;@ headers=[@[%a@]];@ body=%a@]}" self.code + Headers.pp self.headers pp_body self.body + +let output_ ~buf (oc : IO.Output.t) (self : t) : unit = + (* double indirection: + - print into [buffer] using [bprintf] + - transfer to [buf_] so we can output from there *) + let tmp_buffer = Buffer.create 32 in + Buf.clear buf; + + (* write start of reply *) + Printf.bprintf tmp_buffer "HTTP/1.1 %d %s\r\n" self.code + (Response_code.descr self.code); + Buf.add_buffer buf tmp_buffer; + Buffer.clear tmp_buffer; + + let body, is_chunked = + match self.body with + | `String s when String.length s > 1024 * 500 -> + (* chunk-encode large bodies *) + `Writer (IO.Writer.of_string s), true + | `String _ as b -> b, false + | `Stream _ as b -> b, true + | `Writer _ as b -> b, true + | `Void as b -> b, false + in + let headers = + if is_chunked then + self.headers + |> Headers.set "transfer-encoding" "chunked" + |> Headers.remove "content-length" + else + self.headers + in + let self = { self with headers; body } in + Log.debug (fun k -> + k "t[%d]: output response: %s" + (Thread.id @@ Thread.self ()) + (Format.asprintf "%a" pp { self with body = `String "<...>" })); + + (* write headers, using [buf] to batch writes *) + List.iter + (fun (k, v) -> + Printf.bprintf tmp_buffer "%s: %s\r\n" k v; + Buf.add_buffer buf tmp_buffer; + Buffer.clear tmp_buffer) + headers; + + IO.Output.output_buf oc buf; + IO.Output.output_string oc "\r\n"; + Buf.clear buf; + + (match body with + | `String "" | `Void -> () + | `String s -> IO.Output.output_string oc s + | `Writer w -> + (* use buffer to chunk encode [w] *) + let oc' = IO.Output.chunk_encoding ~buf ~close_rec:false oc in + (try + IO.Writer.write oc' w; + IO.Output.close oc' + with e -> + let bt = Printexc.get_raw_backtrace () in + IO.Output.close oc'; + IO.Output.flush oc; + Printexc.raise_with_backtrace e bt) + | `Stream str -> + (match IO.Input.output_chunked' ~buf oc str with + | () -> + Log.debug (fun k -> + k "t[%d]: done outputing stream" (Thread.id @@ Thread.self ())); + IO.Input.close str + | exception e -> + let bt = Printexc.get_raw_backtrace () in + Log.error (fun k -> + k "t[%d]: outputing stream failed with %s" + (Thread.id @@ Thread.self ()) + (Printexc.to_string e)); + IO.Input.close str; + IO.Output.flush oc; + Printexc.raise_with_backtrace e bt)); + IO.Output.flush oc + +module Private_ = struct + let make_void_force_ = make_void_force_ + let output_ = output_ +end diff --git a/src/core/response.mli b/src/core/response.mli new file mode 100644 index 00000000..610faed5 --- /dev/null +++ b/src/core/response.mli @@ -0,0 +1,118 @@ +(** Responses + + Responses are what a http server, such as {!Tiny_httpd}, send back to + the client to answer a {!Request.t}*) + +type body = + [ `String of string | `Stream of IO.Input.t | `Writer of IO.Writer.t | `Void ] +(** Body of a response, either as a simple string, + or a stream of bytes, or nothing (for server-sent events notably). + + - [`String str] replies with a body set to this string, and a known content-length. + - [`Stream str] replies with a body made from this string, using chunked encoding. + - [`Void] replies with no body. + - [`Writer w] replies with a body created by the writer [w], using + a chunked encoding. + It is available since 0.14. + *) + +type t = private { + code: Response_code.t; (** HTTP response code. See {!Response_code}. *) + headers: Headers.t; + (** Headers of the reply. Some will be set by [Tiny_httpd] automatically. *) + body: body; (** Body of the response. Can be empty. *) +} +(** A response to send back to a client. *) + +val set_body : body -> t -> t +(** Set the body of the response. + @since 0.11 *) + +val set_header : string -> string -> t -> t +(** Set a header. + @since 0.11 *) + +val update_headers : (Headers.t -> Headers.t) -> t -> t +(** Modify headers. + @since 0.11 *) + +val remove_header : string -> t -> t +(** Remove one instance of this header. + @since NEXT_RELEASE *) + +val set_headers : Headers.t -> t -> t +(** Set all headers. + @since 0.11 *) + +val set_code : Response_code.t -> t -> t +(** Set the response code. + @since 0.11 *) + +val make_raw : ?headers:Headers.t -> code:Response_code.t -> string -> t +(** Make a response from its raw components, with a string body. + Use [""] to not send a body at all. *) + +val make_raw_stream : + ?headers:Headers.t -> code:Response_code.t -> IO.Input.t -> t +(** Same as {!make_raw} but with a stream body. The body will be sent with + the chunked transfer-encoding. *) + +val make_void : ?headers:Headers.t -> code:int -> unit -> t +(** Return a response without a body at all. + @since 0.13 *) + +val make : + ?headers:Headers.t -> + ?code:int -> + (body, Response_code.t * string) result -> + t +(** [make r] turns a result into a response. + + - [make (Ok body)] replies with [200] and the body. + - [make (Error (code,msg))] replies with the given error code + and message as body. + *) + +val make_string : + ?headers:Headers.t -> + ?code:int -> + (string, Response_code.t * string) result -> + t +(** Same as {!make} but with a string body. *) + +val make_writer : + ?headers:Headers.t -> + ?code:int -> + (IO.Writer.t, Response_code.t * string) result -> + t +(** Same as {!make} but with a writer body. *) + +val make_stream : + ?headers:Headers.t -> + ?code:int -> + (IO.Input.t, Response_code.t * string) result -> + t +(** Same as {!make} but with a stream body. *) + +val fail : ?headers:Headers.t -> code:int -> ('a, unit, string, t) format4 -> 'a +(** Make the current request fail with the given code and message. + Example: [fail ~code:404 "oh noes, %s not found" "waldo"]. + *) + +val fail_raise : code:int -> ('a, unit, string, 'b) format4 -> 'a +(** Similar to {!fail} but raises an exception that exits the current handler. + This should not be used outside of a (path) handler. + Example: [fail_raise ~code:404 "oh noes, %s not found" "waldo"; never_executed()] + *) + +val pp : Format.formatter -> t -> unit +(** Pretty print the response. The exact format is not specified. *) + +(**/**) + +module Private_ : sig + val make_void_force_ : ?headers:Headers.t -> code:int -> unit -> t + val output_ : buf:Buf.t -> IO.Output.t -> t -> unit +end + +(**/**) diff --git a/src/core/response_code.ml b/src/core/response_code.ml new file mode 100644 index 00000000..cc97380d --- /dev/null +++ b/src/core/response_code.ml @@ -0,0 +1,32 @@ +type t = int + +let ok = 200 +let not_found = 404 + +let descr = function + | 100 -> "Continue" + | 200 -> "OK" + | 201 -> "Created" + | 202 -> "Accepted" + | 204 -> "No content" + | 300 -> "Multiple choices" + | 301 -> "Moved permanently" + | 302 -> "Found" + | 304 -> "Not Modified" + | 400 -> "Bad request" + | 401 -> "Unauthorized" + | 403 -> "Forbidden" + | 404 -> "Not found" + | 405 -> "Method not allowed" + | 408 -> "Request timeout" + | 409 -> "Conflict" + | 410 -> "Gone" + | 411 -> "Length required" + | 413 -> "Payload too large" + | 417 -> "Expectation failed" + | 500 -> "Internal server error" + | 501 -> "Not implemented" + | 503 -> "Service unavailable" + | n -> "Unknown response code " ^ string_of_int n (* TODO *) + +let[@inline] is_success n = n >= 200 && n < 400 diff --git a/src/core/response_code.mli b/src/core/response_code.mli new file mode 100644 index 00000000..fd0663d4 --- /dev/null +++ b/src/core/response_code.mli @@ -0,0 +1,20 @@ +(** Response Codes *) + +type t = int +(** A standard HTTP code. + + https://tools.ietf.org/html/rfc7231#section-6 *) + +val ok : t +(** The code [200] *) + +val not_found : t +(** The code [404] *) + +val descr : t -> string +(** A description of some of the error codes. + NOTE: this is not complete (yet). *) + +val is_success : t -> bool +(** [is_success code] is true iff [code] is in the [2xx] or [3xx] range. + @since NEXT_RELEASE *) diff --git a/src/core/route.ml b/src/core/route.ml new file mode 100644 index 00000000..f2e52f08 --- /dev/null +++ b/src/core/route.ml @@ -0,0 +1,93 @@ +type path = string list (* split on '/' *) + +type (_, _) comp = + | Exact : string -> ('a, 'a) comp + | Int : (int -> 'a, 'a) comp + | String : (string -> 'a, 'a) comp + | String_urlencoded : (string -> 'a, 'a) comp + +type (_, _) t = + | Fire : ('b, 'b) t + | Rest : { url_encoded: bool } -> (string -> 'b, 'b) t + | Compose : ('a, 'b) comp * ('b, 'c) t -> ('a, 'c) t + +let return = Fire +let rest_of_path = Rest { url_encoded = false } +let rest_of_path_urlencoded = Rest { url_encoded = true } +let ( @/ ) a b = Compose (a, b) +let string = String +let string_urlencoded = String_urlencoded +let int = Int +let exact (s : string) = Exact s + +let exact_path (s : string) tail = + let rec fn = function + | [] -> tail + | "" :: ls -> fn ls + | s :: ls -> exact s @/ fn ls + in + fn (String.split_on_char '/' s) + +let rec eval : type a b. path -> (a, b) t -> a -> b option = + fun path route f -> + match path, route with + | [], Fire -> Some f + | _, Fire -> None + | _, Rest { url_encoded } -> + let whole_path = String.concat "/" path in + (match + if url_encoded then ( + match Util.percent_decode whole_path with + | Some s -> s + | None -> raise_notrace Exit + ) else + whole_path + with + | whole_path -> Some (f whole_path) + | exception Exit -> None) + | c1 :: path', Compose (comp, route') -> + (match comp with + | Int -> + (match int_of_string c1 with + | i -> eval path' route' (f i) + | exception _ -> None) + | String -> eval path' route' (f c1) + | String_urlencoded -> + (match Util.percent_decode c1 with + | None -> None + | Some s -> eval path' route' (f s)) + | Exact s -> + if s = c1 then + eval path' route' f + else + None) + | [], Compose (String, Fire) -> Some (f "") (* trailing *) + | [], Compose (String_urlencoded, Fire) -> Some (f "") (* trailing *) + | [], Compose _ -> None + +let bpf = Printf.bprintf + +let rec pp_ : type a b. Buffer.t -> (a, b) t -> unit = + fun out -> function + | Fire -> bpf out "/" + | Rest { url_encoded } -> + bpf out "" + (if url_encoded then + "_urlencoded" + else + "") + | Compose (Exact s, tl) -> bpf out "%s/%a" s pp_ tl + | Compose (Int, tl) -> bpf out "/%a" pp_ tl + | Compose (String, tl) -> bpf out "/%a" pp_ tl + | Compose (String_urlencoded, tl) -> bpf out "/%a" pp_ tl + +let to_string x = + let b = Buffer.create 16 in + pp_ b x; + Buffer.contents b + +module Private_ = struct + let eval = eval +end + +let pp out x = Format.pp_print_string out (to_string x) diff --git a/src/core/route.mli b/src/core/route.mli new file mode 100644 index 00000000..4df45aba --- /dev/null +++ b/src/core/route.mli @@ -0,0 +1,58 @@ +(** Routing + + Basic type-safe routing of handlers based on URL paths. This is optional, + it is possible to only define the root handler with something like + {{: https://github.com/anuragsoni/routes/} Routes}. + @since 0.6 *) + +type ('a, 'b) comp +(** An atomic component of a path *) + +type ('a, 'b) t +(** A route, composed of path components *) + +val int : (int -> 'a, 'a) comp +(** Matches an integer. *) + +val string : (string -> 'a, 'a) comp +(** Matches a string not containing ['/'] and binds it as is. *) + +val string_urlencoded : (string -> 'a, 'a) comp +(** Matches a URL-encoded string, and decodes it. *) + +val exact : string -> ('a, 'a) comp +(** [exact "s"] matches ["s"] and nothing else. *) + +val return : ('a, 'a) t +(** Matches the empty path. *) + +val rest_of_path : (string -> 'a, 'a) t +(** Matches a string, even containing ['/']. This will match + the entirety of the remaining route. + @since 0.7 *) + +val rest_of_path_urlencoded : (string -> 'a, 'a) t +(** Matches a string, even containing ['/'], an URL-decode it. + This will match the entirety of the remaining route. + @since 0.7 *) + +val ( @/ ) : ('a, 'b) comp -> ('b, 'c) t -> ('a, 'c) t +(** [comp / route] matches ["foo/bar/…"] iff [comp] matches ["foo"], + and [route] matches ["bar/…"]. *) + +val exact_path : string -> ('a, 'b) t -> ('a, 'b) t +(** [exact_path "foo/bar/..." r] is equivalent to + [exact "foo" @/ exact "bar" @/ ... @/ r] + @since 0.11 **) + +val pp : Format.formatter -> _ t -> unit +(** Print the route. + @since 0.7 *) + +val to_string : _ t -> string +(** Print the route. + @since 0.7 *) + +module Private_ : sig + val eval : string list -> ('a, 'b) t -> 'a -> 'b option +end diff --git a/src/core/server.ml b/src/core/server.ml new file mode 100644 index 00000000..33bacc1a --- /dev/null +++ b/src/core/server.ml @@ -0,0 +1,515 @@ +open Common_ + +type resp_error = Response_code.t * string + +module Middleware = struct + type handler = IO.Input.t Request.t -> resp:(Response.t -> unit) -> unit + type t = handler -> handler + + (** Apply a list of middlewares to [h] *) + let apply_l (l : t list) (h : handler) : handler = + List.fold_right (fun m h -> m h) l h + + let[@inline] nil : t = fun h -> h +end + +(* a request handler. handles a single request. *) +type cb_path_handler = IO.Output.t -> Middleware.handler + +module type SERVER_SENT_GENERATOR = sig + val set_headers : Headers.t -> unit + + val send_event : + ?event:string -> ?id:string -> ?retry:string -> data:string -> unit -> unit + + val close : unit -> unit +end + +type server_sent_generator = (module SERVER_SENT_GENERATOR) + +(** Handler that upgrades to another protocol *) +module type UPGRADE_HANDLER = sig + type handshake_state + (** Some specific state returned after handshake *) + + val name : string + (** Name in the "upgrade" header *) + + val handshake : unit Request.t -> (Headers.t * handshake_state, string) result + (** Perform the handshake and upgrade the connection. The returned + code is [101] alongside these headers. *) + + val handle_connection : + Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit + (** Take control of the connection and take it from there *) +end + +type upgrade_handler = (module UPGRADE_HANDLER) + +exception Upgrade of unit Request.t * upgrade_handler + +module type IO_BACKEND = sig + val init_addr : unit -> string + val init_port : unit -> int + + val get_time_s : unit -> float + (** obtain the current timestamp in seconds. *) + + val tcp_server : unit -> IO.TCP_server.builder + (** Server that can listen on a port and handle clients. *) +end + +type handler_result = + | Handle of cb_path_handler + | Fail of resp_error + | Upgrade of upgrade_handler + +let unwrap_handler_result req = function + | Handle x -> x + | Fail (c, s) -> raise (Bad_req (c, s)) + | Upgrade up -> raise (Upgrade (req, up)) + +type t = { + backend: (module IO_BACKEND); + mutable tcp_server: IO.TCP_server.t option; + buf_size: int; + mutable handler: IO.Input.t Request.t -> Response.t; + (** toplevel handler, if any *) + mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) + mutable middlewares_sorted: (int * Middleware.t) list lazy_t; + (** sorted version of {!middlewares} *) + mutable path_handlers: (unit Request.t -> handler_result option) list; + (** path handlers *) + buf_pool: Buf.t Pool.t; +} + +let get_addr_ sock = + match Unix.getsockname sock with + | Unix.ADDR_INET (addr, port) -> addr, port + | _ -> invalid_arg "httpd: address is not INET" + +let addr (self : t) = + match self.tcp_server with + | None -> + let (module B) = self.backend in + B.init_addr () + | Some s -> fst @@ s.endpoint () + +let port (self : t) = + match self.tcp_server with + | None -> + let (module B) = self.backend in + B.init_port () + | Some s -> snd @@ s.endpoint () + +let active_connections (self : t) = + match self.tcp_server with + | None -> 0 + | Some s -> s.active_connections () + +let add_middleware ~stage self m = + let stage = + match stage with + | `Encoding -> 0 + | `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage" + | `Stage n -> n + in + self.middlewares <- (stage, m) :: self.middlewares; + self.middlewares_sorted <- + lazy + (List.stable_sort (fun (s1, _) (s2, _) -> compare s1 s2) self.middlewares) + +let add_decode_request_cb self f = + (* turn it into a middleware *) + let m h req ~resp = + (* see if [f] modifies the stream *) + let req0 = Request.Private_.set_body () req in + match f req0 with + | None -> h req ~resp (* pass through *) + | Some (req1, tr_stream) -> + let body = tr_stream req.Request.body in + let req = Request.set_body body req1 in + h req ~resp + in + add_middleware self ~stage:`Encoding m + +let add_encode_response_cb self f = + let m h req ~resp = + h req ~resp:(fun r -> + let req0 = Request.Private_.set_body () req in + (* now transform [r] if we want to *) + match f req0 r with + | None -> resp r + | Some r' -> resp r') + in + add_middleware self ~stage:`Encoding m + +let set_top_handler self f = self.handler <- f + +(* route the given handler. + @param tr_req wraps the actual concrete function returned by the route + and makes it into a handler. *) +let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth + ~tr_req self (route : _ Route.t) f = + let ph req : handler_result option = + match meth with + | Some m when m <> req.Request.meth -> None (* ignore *) + | _ -> + (match Route.Private_.eval req.Request.path_components route f with + | Some handler -> + (* we have a handler, do we accept the request based on its headers? *) + (match accept req with + | Ok () -> + Some + (Handle + (fun oc -> + Middleware.apply_l middlewares @@ fun req ~resp -> + tr_req oc req ~resp handler)) + | Error err -> Some (Fail err)) + | None -> None (* path didn't match *)) + in + self.path_handlers <- ph :: self.path_handlers + +let add_route_handler (type a) ?accept ?middlewares ?meth self + (route : (a, _) Route.t) (f : _) : unit = + let tr_req _oc req ~resp f = + let req = + Pool.with_resource self.buf_pool @@ fun buf -> + Request.read_body_full ~buf req + in + resp (f req) + in + add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f + +let add_route_handler_stream ?accept ?middlewares ?meth self route f = + let tr_req _oc req ~resp f = resp (f req) in + add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f + +let[@inline] _opt_iter ~f o = + match o with + | None -> () + | Some x -> f x + +exception Exit_SSE + +let add_route_server_sent_handler ?accept self route f = + let tr_req (oc : IO.Output.t) req ~resp f = + let req = + Pool.with_resource self.buf_pool @@ fun buf -> + Request.read_body_full ~buf req + in + let headers = + ref Headers.(empty |> set "content-type" "text/event-stream") + in + + (* send response once *) + let resp_sent = ref false in + let send_response_idempotent_ () = + if not !resp_sent then ( + resp_sent := true; + (* send 200 response now *) + let initial_resp = + Response.Private_.make_void_force_ ~headers:!headers ~code:200 () + in + resp initial_resp + ) + in + + let[@inline] writef fmt = + Printf.ksprintf (IO.Output.output_string oc) fmt + in + + let send_event ?event ?id ?retry ~data () : unit = + send_response_idempotent_ (); + _opt_iter event ~f:(fun e -> writef "event: %s\n" e); + _opt_iter id ~f:(fun e -> writef "id: %s\n" e); + _opt_iter retry ~f:(fun e -> writef "retry: %s\n" e); + let l = String.split_on_char '\n' data in + List.iter (fun s -> writef "data: %s\n" s) l; + IO.Output.output_string oc "\n"; + (* finish group *) + IO.Output.flush oc + in + let module SSG = struct + let set_headers h = + if not !resp_sent then ( + headers := List.rev_append h !headers; + send_response_idempotent_ () + ) + + let send_event = send_event + let close () = raise Exit_SSE + end in + (try f req (module SSG : SERVER_SENT_GENERATOR) + with Exit_SSE -> IO.Output.close oc); + Log.info (fun k -> k "closed SSE connection") + in + add_route_handler_ self ?accept ~meth:`GET route ~tr_req f + +let add_upgrade_handler ?(accept = fun _ -> Ok ()) (self : t) route f : unit = + let ph req : handler_result option = + if req.Request.meth <> `GET then + None + else ( + match accept req with + | Ok () -> + (match Route.Private_.eval req.Request.path_components route f with + | Some up -> Some (Upgrade up) + | None -> None (* path didn't match *)) + | Error err -> Some (Fail err) + ) + in + self.path_handlers <- ph :: self.path_handlers + +let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = + let handler _req = Response.fail ~code:404 "no top handler" in + let self = + { + backend; + tcp_server = None; + handler; + buf_size; + path_handlers = []; + middlewares = []; + middlewares_sorted = lazy []; + buf_pool = + Pool.create ~clear:Buf.clear_and_zero + ~mk_item:(fun () -> Buf.create ~size:buf_size ()) + (); + } + in + List.iter (fun (stage, m) -> add_middleware self ~stage m) middlewares; + self + +let is_ipv6_str addr : bool = String.contains addr ':' + +let stop (self : t) = + match self.tcp_server with + | None -> () + | Some s -> s.stop () + +let running (self : t) = + match self.tcp_server with + | None -> false + | Some s -> s.running () + +let find_map f l = + let rec aux f = function + | [] -> None + | x :: l' -> + (match f x with + | Some _ as res -> res + | None -> aux f l') + in + aux f l + +let string_as_list_contains_ (s : string) (sub : string) : bool = + let fragments = String.split_on_char ',' s in + List.exists (fun fragment -> String.trim fragment = sub) fragments + +(* handle client on [ic] and [oc] *) +let client_handle_for (self : t) ~client_addr ic oc : unit = + Pool.with_resource self.buf_pool @@ fun buf -> + Pool.with_resource self.buf_pool @@ fun buf_res -> + let (module B) = self.backend in + + (* how to log the response to this query *) + let log_response (req : _ Request.t) (resp : Response.t) = + if not Log.dummy then ( + let msgf k = + let elapsed = B.get_time_s () -. req.start_time in + k + ("response to=%s code=%d time=%.3fs path=%S" : _ format4) + (Util.show_sockaddr client_addr) + resp.code elapsed req.path + in + if Response_code.is_success resp.code then + Log.info msgf + else + Log.error msgf + ) + in + + let log_exn msg bt = + Log.error (fun k -> + k "error while processing response for %s msg=%s@.%s" + (Util.show_sockaddr client_addr) + msg + (Printexc.raw_backtrace_to_string bt)) + in + + (* handle generic exception *) + let handle_exn e bt : unit = + let msg = Printexc.to_string e in + let resp = Response.fail ~code:500 "server error: %s" msg in + if not Log.dummy then log_exn msg bt; + Response.Private_.output_ ~buf:buf_res oc resp + in + + let handle_bad_req req e bt = + let msg = Printexc.to_string e in + let resp = Response.fail ~code:500 "server error: %s" msg in + if not Log.dummy then ( + log_exn msg bt; + log_response req resp + ); + Response.Private_.output_ ~buf:buf_res oc resp + in + + let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit = + Log.debug (fun k -> k "upgrade connection"); + try + (* check headers *) + (match Request.get_header req "connection" with + | Some str when string_as_list_contains_ str "Upgrade" -> () + | _ -> bad_reqf 426 "connection header must contain 'Upgrade'"); + (match Request.get_header req "upgrade" with + | Some u when u = UP.name -> () + | Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u + | None -> bad_reqf 426 "expected 'connection: upgrade' header"); + + (* ok, this is the upgrade we expected *) + match UP.handshake req with + | Error msg -> + (* fail the upgrade *) + Log.error (fun k -> k "upgrade failed: %s" msg); + let resp = Response.make_raw ~code:429 "upgrade required" in + log_response req resp; + Response.Private_.output_ ~buf:buf_res oc resp + | Ok (headers, handshake_st) -> + (* send the upgrade reply *) + let headers = + [ "connection", "upgrade"; "upgrade", UP.name ] @ headers + in + let resp = Response.make_string ~code:101 ~headers (Ok "") in + log_response req resp; + Response.Private_.output_ ~buf:buf_res oc resp; + + UP.handle_connection client_addr handshake_st ic oc + with e -> + let bt = Printexc.get_raw_backtrace () in + handle_bad_req req e bt + in + + let continue = ref true in + + let handle_one_req () = + match + Request.Private_.parse_req_start ~client_addr ~get_time_s:B.get_time_s + ~buf ic + with + | Ok None -> continue := false (* client is done *) + | Error (c, s) -> + (* connection error, close *) + let res = Response.make_raw ~code:c s in + (try Response.Private_.output_ ~buf:buf_res oc res + with Sys_error _ -> ()); + continue := false + | Ok (Some req) -> + Log.debug (fun k -> + k "t[%d]: parsed request: %s" + (Thread.id @@ Thread.self ()) + (Format.asprintf "@[%a@]" Request.pp_ req)); + + if Request.Private_.close_after_req req then continue := false; + + (try + (* is there a handler for this path? *) + let base_handler = + match find_map (fun ph -> ph req) self.path_handlers with + | Some f -> unwrap_handler_result req f + | None -> fun _oc req ~resp -> resp (self.handler req) + in + + (* handle expect/continue *) + (match Request.get_header ~f:String.trim req "Expect" with + | Some "100-continue" -> + Log.debug (fun k -> k "send back: 100 CONTINUE"); + Response.Private_.output_ ~buf:buf_res oc + (Response.make_raw ~code:100 "") + | Some s -> bad_reqf 417 "unknown expectation %s" s + | None -> ()); + + (* apply middlewares *) + let handler oc = + List.fold_right + (fun (_, m) h -> m h) + (Lazy.force self.middlewares_sorted) + (base_handler oc) + in + + (* now actually read request's body into a stream *) + let req = + Request.Private_.parse_body + ~buf:(IO.Slice.of_bytes (Buf.bytes_slice buf)) + req ic + in + + (* how to reply *) + let resp r = + try + if Headers.get "connection" r.Response.headers = Some "close" then + continue := false; + log_response req r; + Response.Private_.output_ ~buf:buf_res oc r + with Sys_error e -> + Log.debug (fun k -> + k "error when writing response: %s@.connection broken" e); + continue := false + in + + (* call handler *) + try handler oc req ~resp + with Sys_error e -> + Log.debug (fun k -> + k "error while handling request: %s@.connection broken" e); + continue := false + with + | Sys_error e -> + (* connection broken somehow *) + Log.debug (fun k -> k "error: %s@. connection broken" e); + continue := false + | Bad_req (code, s) -> + continue := false; + let resp = Response.make_raw ~code s in + log_response req resp; + Response.Private_.output_ ~buf:buf_res oc resp + | Upgrade _ as e -> raise e + | e -> + let bt = Printexc.get_raw_backtrace () in + handle_bad_req req e bt) + in + + try + while !continue && running self do + Log.debug (fun k -> + k "t[%d]: read next request" (Thread.id @@ Thread.self ())); + handle_one_req () + done + with + | Upgrade (req, up) -> + (* upgrades take over the whole connection, we won't process + any further request *) + handle_upgrade req up + | e -> + let bt = Printexc.get_raw_backtrace () in + handle_exn e bt + +let client_handler (self : t) : IO.TCP_server.conn_handler = + { IO.TCP_server.handle = client_handle_for self } + +let is_ipv6 (self : t) = + let (module B) = self.backend in + is_ipv6_str (B.init_addr ()) + +let run_exn ?(after_init = ignore) (self : t) : unit = + let (module B) = self.backend in + let server = B.tcp_server () in + server.serve + ~after_init:(fun tcp_server -> + self.tcp_server <- Some tcp_server; + after_init ()) + ~handle:(client_handler self) () + +let run ?after_init self : _ result = + try Ok (run_exn ?after_init self) with e -> Error e diff --git a/src/core/server.mli b/src/core/server.mli new file mode 100644 index 00000000..e856c7e4 --- /dev/null +++ b/src/core/server.mli @@ -0,0 +1,298 @@ +(** HTTP server. + + This module implements a very simple, basic HTTP/1.1 server using blocking + IOs and threads. + + It is possible to use a thread pool, see {!create}'s argument [new_thread]. + + @since 0.13 +*) + +(** {2 Middlewares} + + A middleware can be inserted in a handler to modify or observe + its behavior. + + @since 0.11 +*) + +module Middleware : sig + type handler = IO.Input.t Request.t -> resp:(Response.t -> unit) -> unit + (** Handlers are functions returning a response to a request. + The response can be delayed, hence the use of a continuation + as the [resp] parameter. *) + + type t = handler -> handler + (** A middleware is a handler transformation. + + It takes the existing handler [h], + and returns a new one which, given a query, modify it or log it + before passing it to [h], or fail. It can also log or modify or drop + the response. *) + + val nil : t + (** Trivial middleware that does nothing. *) +end + +(** {2 Main Server type} *) + +type t +(** A HTTP server. See {!create} for more details. *) + +(** A backend that provides IO operations, network operations, etc. + + This is used to decouple tiny_httpd from the scheduler/IO library used to + actually open a TCP server and talk to clients. The classic way is + based on {!Unix} and blocking IOs, but it's also possible to + use an OCaml 5 library using effects and non blocking IOs. *) +module type IO_BACKEND = sig + val init_addr : unit -> string + (** Initial TCP address *) + + val init_port : unit -> int + (** Initial port *) + + val get_time_s : unit -> float + (** Obtain the current timestamp in seconds. *) + + val tcp_server : unit -> IO.TCP_server.builder + (** TCP server builder, to create servers that can listen + on a port and handle clients. *) +end + +val create_from : + ?buf_size:int -> + ?middlewares:([ `Encoding | `Stage of int ] * Middleware.t) list -> + backend:(module IO_BACKEND) -> + unit -> + t +(** Create a new webserver using provided backend. + + The server will not do anything until {!run} is called on it. + Before starting the server, one can use {!add_path_handler} and + {!set_top_handler} to specify how to handle incoming requests. + + @param buf_size size for buffers (since 0.11) + @param middlewares see {!add_middleware} for more details. + + @since 0.14 +*) + +val addr : t -> string +(** Address on which the server listens. *) + +val is_ipv6 : t -> bool +(** [is_ipv6 server] returns [true] iff the address of the server is an IPv6 address. + @since 0.3 *) + +val port : t -> int +(** Port on which the server listens. Note that this might be different than + the port initially given if the port was [0] (meaning that the OS picks a + port for us). *) + +val active_connections : t -> int +(** Number of currently active connections. *) + +val add_decode_request_cb : + t -> + (unit Request.t -> (unit Request.t * (IO.Input.t -> IO.Input.t)) option) -> + unit + [@@deprecated "use add_middleware"] +(** Add a callback for every request. + The callback can provide a stream transformer and a new request (with + modified headers, typically). + A possible use is to handle decompression by looking for a [Transfer-Encoding] + header and returning a stream transformer that decompresses on the fly. + + @deprecated use {!add_middleware} instead +*) + +val add_encode_response_cb : + t -> (unit Request.t -> Response.t -> Response.t option) -> unit + [@@deprecated "use add_middleware"] +(** Add a callback for every request/response pair. + Similarly to {!add_encode_response_cb} the callback can return a new + response, for example to compress it. + The callback is given the query with only its headers, + as well as the current response. + + @deprecated use {!add_middleware} instead +*) + +val add_middleware : + stage:[ `Encoding | `Stage of int ] -> t -> Middleware.t -> unit +(** Add a middleware to every request/response pair. + @param stage specify when middleware applies. + Encoding comes first (outermost layer), then stages in increasing order. + @raise Invalid_argument if stage is [`Stage n] where [n < 1] + @since 0.11 +*) + +(** {2 Request handlers} *) + +val set_top_handler : t -> (IO.Input.t Request.t -> Response.t) -> unit +(** Setup a handler called by default. + + This handler is called with any request not accepted by any handler + installed via {!add_path_handler}. + If no top handler is installed, unhandled paths will return a [404] not found + + This used to take a [string Request.t] but it now takes a [byte_stream Request.t] + since 0.14 . Use {!Request.read_body_full} to read the body into + a string if needed. +*) + +val add_route_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> + ?meth:Meth.t -> + t -> + ('a, string Request.t -> Response.t) Route.t -> + 'a -> + unit +(** [add_route_handler server Route.(exact "path" @/ string @/ int @/ return) f] + calls [f "foo" 42 request] when a [request] with path "path/foo/42/" + is received. + + Note that the handlers are called in the reverse order of their addition, + so the last registered handler can override previously registered ones. + + @param meth if provided, only accept requests with the given method. + Typically one could react to [`GET] or [`PUT]. + @param accept should return [Ok()] if the given request (before its body + is read) should be accepted, [Error (code,message)] if it's to be rejected (e.g. because + its content is too big, or for some permission error). + See the {!http_of_dir} program for an example of how to use [accept] to + filter uploads that are too large before the upload even starts. + The default always returns [Ok()], i.e. it accepts all requests. + + @since 0.6 +*) + +val add_route_handler_stream : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> + ?meth:Meth.t -> + t -> + ('a, IO.Input.t Request.t -> Response.t) Route.t -> + 'a -> + unit +(** Similar to {!add_route_handler}, but where the body of the request + is a stream of bytes that has not been read yet. + This is useful when one wants to stream the body directly into a parser, + json decoder (such as [Jsonm]) or into a file. + @since 0.6 *) + +(** {2 Server-sent events} + + {b EXPERIMENTAL}: this API is not stable yet. *) + +(** A server-side function to generate of Server-sent events. + + See {{: https://html.spec.whatwg.org/multipage/server-sent-events.html} the w3c page} + and {{: https://jvns.ca/blog/2021/01/12/day-36--server-sent-events-are-cool--and-a-fun-bug/} + this blog post}. + + @since 0.9 + *) +module type SERVER_SENT_GENERATOR = sig + val set_headers : Headers.t -> unit + (** Set headers of the response. + This is not mandatory but if used at all, it must be called before + any call to {!send_event} (once events are sent the response is + already sent too). *) + + val send_event : + ?event:string -> ?id:string -> ?retry:string -> data:string -> unit -> unit + (** Send an event from the server. + If data is a multiline string, it will be sent on separate "data:" lines. *) + + val close : unit -> unit + (** Close connection. + @since 0.11 *) +end + +type server_sent_generator = (module SERVER_SENT_GENERATOR) +(** Server-sent event generator. This generates events that are forwarded to + the client (e.g. the browser). + @since 0.9 *) + +val add_route_server_sent_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + t -> + ('a, string Request.t -> server_sent_generator -> unit) Route.t -> + 'a -> + unit +(** Add a handler on an endpoint, that serves server-sent events. + + The callback is given a generator that can be used to send events + as it pleases. The connection is always closed by the client, + and the accepted method is always [GET]. + This will set the header "content-type" to "text/event-stream" automatically + and reply with a 200 immediately. + See {!server_sent_generator} for more details. + + This handler stays on the original thread (it is synchronous). + + @since 0.9 *) + +(** {2 Upgrade handlers} + + These handlers upgrade the connection to another protocol. + @since NEXT_RELEASE *) + +(** Handler that upgrades to another protocol. + @since NEXT_RELEASE *) +module type UPGRADE_HANDLER = sig + type handshake_state + (** Some specific state returned after handshake *) + + val name : string + (** Name in the "upgrade" header *) + + val handshake : unit Request.t -> (Headers.t * handshake_state, string) result + (** Perform the handshake and upgrade the connection. The returned + code is [101] alongside these headers. + In case the handshake fails, this only returns [Error log_msg]. + The connection is closed without further ado. *) + + val handle_connection : + Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit + (** Take control of the connection and take it from ther.e *) +end + +type upgrade_handler = (module UPGRADE_HANDLER) +(** @since NEXT_RELEASE *) + +val add_upgrade_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + t -> + ('a, upgrade_handler) Route.t -> + 'a -> + unit + +(** {2 Run the server} *) + +val running : t -> bool +(** Is the server running? + @since 0.14 *) + +val stop : t -> unit +(** Ask the server to stop. This might not have an immediate effect + as {!run} might currently be waiting on IO. *) + +val run : ?after_init:(unit -> unit) -> t -> (unit, exn) result +(** Run the main loop of the server, listening on a socket + described at the server's creation time, using [new_thread] to + start a thread for each new client. + + This returns [Ok ()] if the server exits gracefully, or [Error e] if + it exits with an error. + + @param after_init is called after the server starts listening. since 0.13 . +*) + +val run_exn : ?after_init:(unit -> unit) -> t -> unit +(** [run_exn s] is like [run s] but re-raises an exception if the server exits + with an error. + @since 0.14 *) diff --git a/src/core/stream.ml.tmp b/src/core/stream.ml.tmp new file mode 100644 index 00000000..607a4f5f --- /dev/null +++ b/src/core/stream.ml.tmp @@ -0,0 +1,294 @@ +(* +module Buf = Tiny_httpd_buf +module IO = Tiny_httpd_io + +let spf = Printf.sprintf + +type hidden = unit + +type t = { + mutable bs: bytes; + mutable off: int; + mutable len: int; + fill_buf: unit -> unit; + consume: int -> unit; + close: unit -> unit; + _rest: hidden; +} + +let[@inline] close self = self.close () + +let empty = + { + bs = Bytes.empty; + off = 0; + len = 0; + fill_buf = ignore; + consume = ignore; + close = ignore; + _rest = (); + } + +let make ?(bs = Bytes.create @@ (16 * 1024)) ?(close = ignore) ~consume ~fill () + : t = + let rec self = + { + bs; + off = 0; + len = 0; + close = (fun () -> close self); + fill_buf = (fun () -> if self.len = 0 then fill self); + consume = + (fun n -> + assert (n <= self.len); + consume self n); + _rest = (); + } + in + self + +let of_input ?(buf_size = 16 * 1024) (ic : IO.Input.t) : t = + make ~bs:(Bytes.create buf_size) + ~close:(fun _ -> IO.Input.close ic) + ~consume:(fun self n -> + assert (self.len >= n); + self.off <- self.off + n; + self.len <- self.len - n) + ~fill:(fun self -> + if self.len = 0 then ( + self.off <- 0; + self.len <- IO.Input.input ic self.bs 0 (Bytes.length self.bs) + )) + () + +let of_chan_ ?buf_size ic ~close_noerr : t = + let inc = IO.Input.of_in_channel ~close_noerr ic in + of_input ?buf_size inc + +let of_chan ?buf_size ic = of_chan_ ?buf_size ic ~close_noerr:false +let of_chan_close_noerr ?buf_size ic = of_chan_ ?buf_size ic ~close_noerr:true + +let of_fd_ ?buf_size ~close_noerr ~closed ic : t = + let inc = IO.Input.of_unix_fd ~close_noerr ~closed ic in + of_input ?buf_size inc + +let of_fd ?buf_size ~closed fd : t = + of_fd_ ?buf_size ~closed ~close_noerr:false fd + +let of_fd_close_noerr ?buf_size ~closed fd : t = + of_fd_ ?buf_size ~closed ~close_noerr:true fd + +let iter f (self : t) : unit = + let continue = ref true in + while !continue do + self.fill_buf (); + if self.len = 0 then ( + continue := false; + self.close () + ) else ( + f self.bs self.off self.len; + self.consume self.len + ) + done + +let to_chan (oc : out_channel) (self : t) = iter (output oc) self +let to_chan' (oc : IO.Output.t) (self : t) = iter (IO.Output.output oc) self + +let to_writer (self : t) : Tiny_httpd_io.Writer.t = + { write = (fun oc -> to_chan' oc self) } + +let of_bytes ?(i = 0) ?len (bs : bytes) : t = + (* invariant: !i+!len is constant *) + let len = + match len with + | Some n -> + if n > Bytes.length bs - i then invalid_arg "Byte_stream.of_bytes"; + n + | None -> Bytes.length bs - i + in + let self = + make ~bs ~fill:ignore + ~close:(fun self -> self.len <- 0) + ~consume:(fun self n -> + assert (n >= 0 && n <= self.len); + self.off <- n + self.off; + self.len <- self.len - n) + () + in + self.off <- i; + self.len <- len; + self + +let of_string s : t = of_bytes (Bytes.unsafe_of_string s) + +let with_file ?buf_size file f = + let ic = Unix.(openfile file [ O_RDONLY ] 0) in + try + let x = f (of_fd ?buf_size ~closed:(ref false) ic) in + Unix.close ic; + x + with e -> + Unix.close ic; + raise e + +let read_all ?(buf = Buf.create ()) (self : t) : string = + let continue = ref true in + while !continue do + self.fill_buf (); + if self.len = 0 then + continue := false + else ( + assert (self.len > 0); + Buf.add_bytes buf self.bs self.off self.len; + self.consume self.len + ) + done; + Buf.contents_and_clear buf + +(* put [n] bytes from the input into bytes *) +let read_exactly_ ~too_short (self : t) (bytes : bytes) (n : int) : unit = + assert (Bytes.length bytes >= n); + let offset = ref 0 in + while !offset < n do + self.fill_buf (); + let n_read = min self.len (n - !offset) in + Bytes.blit self.bs self.off bytes !offset n_read; + offset := !offset + n_read; + self.consume n_read; + if n_read = 0 then too_short () + done + +(* read a line into the buffer, after clearing it. *) +let read_line_into (self : t) ~buf : unit = + Buf.clear buf; + let continue = ref true in + while !continue do + self.fill_buf (); + if self.len = 0 then ( + continue := false; + if Buf.size buf = 0 then raise End_of_file + ); + let j = ref self.off in + while !j < self.off + self.len && Bytes.get self.bs !j <> '\n' do + incr j + done; + if !j - self.off < self.len then ( + assert (Bytes.get self.bs !j = '\n'); + (* line without '\n' *) + Buf.add_bytes buf self.bs self.off (!j - self.off); + (* consume line + '\n' *) + self.consume (!j - self.off + 1); + continue := false + ) else ( + Buf.add_bytes buf self.bs self.off self.len; + self.consume self.len + ) + done + + +(* read exactly [size] bytes from the stream *) +let read_exactly ~close_rec ~size ~too_short (arg : t) : t = + if size = 0 then + empty + else ( + let size = ref size in + make ~bs:Bytes.empty + ~fill:(fun res -> + (* must not block on [arg] if we're done *) + if !size = 0 then ( + res.bs <- Bytes.empty; + res.off <- 0; + res.len <- 0 + ) else ( + arg.fill_buf (); + res.bs <- arg.bs; + res.off <- arg.off; + let len = min arg.len !size in + if len = 0 && !size > 0 then too_short !size; + res.len <- len + )) + ~close:(fun _res -> + (* close underlying stream if [close_rec] *) + if close_rec then arg.close (); + size := 0) + ~consume:(fun res n -> + let n = min n !size in + size := !size - n; + arg.consume n; + res.off <- res.off + n; + res.len <- res.len - n) + () + ) + +let read_line ?(buf = Buf.create ()) self : string = + read_line_into self ~buf; + Buf.contents buf + +let read_chunked ?(buf = Buf.create ()) ~fail (bs : t) : t = + let first = ref true in + let read_next_chunk_len () : int = + if !first then + first := false + else ( + let line = read_line ~buf bs in + if String.trim line <> "" then raise (fail "expected crlf between chunks") + ); + let line = read_line ~buf bs in + (* parse chunk length, ignore extensions *) + let chunk_size = + if String.trim line = "" then + 0 + else ( + try + let off = ref 0 in + let n = Tiny_httpd_parse_.pos_hex line off in + n + with _ -> + raise (fail (spf "cannot read chunk size from line %S" line)) + ) + in + chunk_size + in + let refill = ref true in + let chunk_size = ref 0 in + make + ~bs:(Bytes.create (16 * 4096)) + ~fill:(fun self -> + (* do we need to refill? *) + if self.len = 0 then ( + if !chunk_size = 0 && !refill then chunk_size := read_next_chunk_len (); + self.off <- 0; + self.len <- 0; + if !chunk_size > 0 then ( + (* read the whole chunk, or [Bytes.length bytes] of it *) + let to_read = min !chunk_size (Bytes.length self.bs) in + read_exactly_ + ~too_short:(fun () -> raise (fail "chunk is too short")) + bs self.bs to_read; + self.len <- to_read; + chunk_size := !chunk_size - to_read + ) else + refill := false (* stream is finished *) + )) + ~consume:(fun self n -> + self.off <- self.off + n; + self.len <- self.len - n) + ~close:(fun self -> + (* close this overlay, do not close underlying stream *) + self.len <- 0; + refill := false) + () + +let output_chunked' ?buf (oc : IO.Output.t) (self : t) : unit = + let oc' = IO.Output.chunk_encoding ?buf oc ~close_rec:false in + match to_chan' oc' self with + | () -> IO.Output.close oc' + | exception e -> + let bt = Printexc.get_raw_backtrace () in + IO.Output.close oc'; + Printexc.raise_with_backtrace e bt + +(* print a stream as a series of chunks *) +let output_chunked ?buf (oc : out_channel) (self : t) : unit = + output_chunked' ?buf (IO.Output.of_out_channel oc) self + *) diff --git a/src/core/stream.mli.tmp b/src/core/stream.mli.tmp new file mode 100644 index 00000000..5d2facad --- /dev/null +++ b/src/core/stream.mli.tmp @@ -0,0 +1,159 @@ +(** Byte streams. + + Streams are used to represent a series of bytes that can arrive progressively. + For example, an uploaded file will be sent as a series of chunks. + + These used to live in {!Tiny_httpd} but are now in their own module. + @since 0.12 *) + +type hidden +(** Type used to make {!t} unbuildable via a record literal. Use {!make} instead. *) + +type t = { + mutable bs: bytes; (** The bytes *) + mutable off: int; (** Beginning of valid slice in {!bs} *) + mutable len: int; + (** Length of valid slice in {!bs}. If [len = 0] after + a call to {!fill}, then the stream is finished. *) + fill_buf: unit -> unit; + (** See the current slice of the internal buffer as [bytes, i, len], + where the slice is [bytes[i] .. [bytes[i+len-1]]]. + Can block to refill the buffer if there is currently no content. + If [len=0] then there is no more data. *) + consume: int -> unit; + (** Consume [n] bytes from the buffer. + This should only be called with [n <= len]. *) + close: unit -> unit; (** Close the stream. *) + _rest: hidden; (** Use {!make} to build a stream. *) +} +(** A buffered stream, with a view into the current buffer (or refill if empty), + and a function to consume [n] bytes. + + The point of this type is that it gives the caller access to its internal buffer + ([bs], with the slice [off,len]). This is convenient for things like line + reading where one needs to peek ahead. + + Some core invariant for this type of stream are: + - [off,len] delimits a valid slice in [bs] (indices: [off, off+1, … off+len-1]) + - if [fill_buf()] was just called, then either [len=0] which indicates the end + of stream; or [len>0] and the slice contains some data. + + To actually move forward in the stream, you can call [consume n] + to consume [n] bytes (where [n <= len]). If [len] gets to [0], calling + [fill_buf()] is required, so it can try to obtain a new slice. + + To emulate a classic OCaml reader with a [read: bytes -> int -> int -> int] function, + the simplest is: + + {[ + let read (self:t) buf offset max_len : int = + self.fill_buf(); + let len = min max_len self.len in + if len > 0 then ( + Bytes.blit self.bs self.off buf offset len; + self.consume len; + ); + len + + ]} +*) + +val close : t -> unit +(** Close stream *) + +val empty : t +(** Stream with 0 bytes inside *) + +val of_input : ?buf_size:int -> Io.Input.t -> t +(** Make a buffered stream from the given channel. + @since 0.14 *) + +val of_chan : ?buf_size:int -> in_channel -> t +(** Make a buffered stream from the given channel. *) + +val of_chan_close_noerr : ?buf_size:int -> in_channel -> t +(** Same as {!of_chan} but the [close] method will never fail. *) + +val of_fd : ?buf_size:int -> closed:bool ref -> Unix.file_descr -> t +(** Make a buffered stream from the given file descriptor. *) + +val of_fd_close_noerr : ?buf_size:int -> closed:bool ref -> Unix.file_descr -> t +(** Same as {!of_fd} but the [close] method will never fail. *) + +val of_bytes : ?i:int -> ?len:int -> bytes -> t +(** A stream that just returns the slice of bytes starting from [i] + and of length [len]. *) + +val of_string : string -> t + +val iter : (bytes -> int -> int -> unit) -> t -> unit +(** Iterate on the chunks of the stream + @since 0.3 *) + +val to_chan : out_channel -> t -> unit +(** Write the stream to the channel. + @since 0.3 *) + +val to_chan' : Tiny_httpd_io.Output.t -> t -> unit +(** Write to the IO channel. + @since 0.14 *) + +val to_writer : t -> Tiny_httpd_io.Writer.t +(** Turn this stream into a writer. + @since 0.14 *) + +val make : + ?bs:bytes -> + ?close:(t -> unit) -> + consume:(t -> int -> unit) -> + fill:(t -> unit) -> + unit -> + t +(** [make ~fill ()] creates a byte stream. + @param fill is used to refill the buffer, and is called initially. + @param close optional closing. + @param init_size size of the buffer. +*) + +val with_file : ?buf_size:int -> string -> (t -> 'a) -> 'a +(** Open a file with given name, and obtain an input stream + on its content. When the function returns, the stream (and file) are closed. *) + +val read_line : ?buf:Tiny_httpd_buf.t -> t -> string +(** Read a line from the stream. + @param buf a buffer to (re)use. Its content will be cleared. *) + +val read_all : ?buf:Tiny_httpd_buf.t -> t -> string +(** Read the whole stream into a string. + @param buf a buffer to (re)use. Its content will be cleared. *) + +val limit_size_to : + close_rec:bool -> max_size:int -> too_big:(int -> unit) -> t -> t +(* New stream with maximum size [max_size]. + @param close_rec if true, closing this will also close the input stream + @param too_big called with read size if the max size is reached *) + +val read_chunked : ?buf:Tiny_httpd_buf.t -> fail:(string -> exn) -> t -> t +(** Convert a stream into a stream of byte chunks using + the chunked encoding. The size of chunks is not specified. + @param buf buffer used for intermediate storage. + @param fail used to build an exception if reading fails. +*) + +val read_exactly : + close_rec:bool -> size:int -> too_short:(int -> unit) -> t -> t +(** [read_exactly ~size bs] returns a new stream that reads exactly + [size] bytes from [bs], and then closes. + @param close_rec if true, closing the resulting stream also closes + [bs] + @param too_short is called if [bs] closes with still [n] bytes remaining +*) + +val output_chunked : ?buf:Tiny_httpd_buf.t -> out_channel -> t -> unit +(** Write the stream into the channel, using the chunked encoding. + @param buf optional buffer for chunking (since 0.14) *) + +val output_chunked' : + ?buf:Tiny_httpd_buf.t -> Tiny_httpd_io.Output.t -> t -> unit +(** Write the stream into the channel, using the chunked encoding. + @since 0.14 *) diff --git a/src/core/util.ml b/src/core/util.ml new file mode 100644 index 00000000..73617702 --- /dev/null +++ b/src/core/util.ml @@ -0,0 +1,121 @@ +let percent_encode ?(skip = fun _ -> false) s = + let buf = Buffer.create (String.length s) in + String.iter + (function + | c when skip c -> Buffer.add_char buf c + | ( ' ' | '!' | '"' | '#' | '$' | '%' | '&' | '\'' | '(' | ')' | '*' | '+' + | ',' | '/' | ':' | ';' | '=' | '?' | '@' | '[' | ']' | '~' ) as c -> + Printf.bprintf buf "%%%X" (Char.code c) + | c when Char.code c > 127 -> Printf.bprintf buf "%%%X" (Char.code c) + | c -> Buffer.add_char buf c) + s; + Buffer.contents buf + +let int_of_hex_nibble = function + | '0' .. '9' as c -> Char.code c - Char.code '0' + | 'a' .. 'f' as c -> 10 + Char.code c - Char.code 'a' + | 'A' .. 'F' as c -> 10 + Char.code c - Char.code 'A' + | _ -> invalid_arg "string: invalid hex" + +let percent_decode (s : string) : _ option = + let buf = Buffer.create (String.length s) in + let i = ref 0 in + try + while !i < String.length s do + match String.get s !i with + | '%' -> + if !i + 2 < String.length s then ( + (match + (int_of_hex_nibble (String.get s (!i + 1)) lsl 4) + + int_of_hex_nibble (String.get s (!i + 2)) + with + | n -> Buffer.add_char buf (Char.chr n) + | exception _ -> raise Exit); + i := !i + 3 + ) else + raise Exit (* truncated *) + | '+' -> + Buffer.add_char buf ' '; + incr i (* for query strings *) + | c -> + Buffer.add_char buf c; + incr i + done; + Some (Buffer.contents buf) + with Exit -> None + +exception Invalid_query + +let find_q_index_ s = String.index s '?' + +let get_non_query_path s = + match find_q_index_ s with + | i -> String.sub s 0 i + | exception Not_found -> s + +let get_query s : string = + match find_q_index_ s with + | i -> String.sub s (i + 1) (String.length s - i - 1) + | exception Not_found -> "" + +let split_query s = get_non_query_path s, get_query s + +let split_on_slash s : _ list = + let l = ref [] in + let i = ref 0 in + let n = String.length s in + while !i < n do + match String.index_from s !i '/' with + | exception Not_found -> + if !i < n then (* last component *) l := String.sub s !i (n - !i) :: !l; + i := n (* done *) + | j -> + if j > !i then l := String.sub s !i (j - !i) :: !l; + i := j + 1 + done; + List.rev !l + +let parse_query s : (_ list, string) result = + let pairs = ref [] in + let is_sep_ = function + | '&' | ';' -> true + | _ -> false + in + let i = ref 0 in + let j = ref 0 in + try + let percent_decode s = + match percent_decode s with + | Some x -> x + | None -> raise Invalid_query + in + let parse_pair () = + let eq = String.index_from s !i '=' in + let k = percent_decode @@ String.sub s !i (eq - !i) in + let v = percent_decode @@ String.sub s (eq + 1) (!j - eq - 1) in + pairs := (k, v) :: !pairs + in + while !i < String.length s do + while !j < String.length s && not (is_sep_ (String.get s !j)) do + incr j + done; + if !j < String.length s then ( + assert (is_sep_ (String.get s !j)); + parse_pair (); + i := !j + 1; + j := !i + ) else ( + parse_pair (); + i := String.length s (* done *) + ) + done; + Ok !pairs + with + | Invalid_argument _ | Not_found | Failure _ -> + Error (Printf.sprintf "error in parse_query for %S: i=%d,j=%d" s !i !j) + | Invalid_query -> Error ("invalid query string: " ^ s) + +let show_sockaddr = function + | Unix.ADDR_UNIX f -> f + | Unix.ADDR_INET (inet, port) -> + Printf.sprintf "%s:%d" (Unix.string_of_inet_addr inet) port diff --git a/src/core/util.mli b/src/core/util.mli new file mode 100644 index 00000000..ac996855 --- /dev/null +++ b/src/core/util.mli @@ -0,0 +1,40 @@ +(** {1 Some utils for writing web servers} + + @since 0.2 +*) + +val percent_encode : ?skip:(char -> bool) -> string -> string +(** Encode the string into a valid path following + https://tools.ietf.org/html/rfc3986#section-2.1 + @param skip if provided, allows to preserve some characters, e.g. '/' in a path. +*) + +val percent_decode : string -> string option +(** Inverse operation of {!percent_encode}. + Can fail since some strings are not valid percent encodings. *) + +val split_query : string -> string * string +(** Split a path between the path and the query + @since 0.5 *) + +val split_on_slash : string -> string list +(** Split a string on ['/'], remove the trailing ['/'] if any. + @since 0.6 *) + +val get_non_query_path : string -> string +(** get the part of the path that is not the query parameters. + @since 0.5 *) + +val get_query : string -> string +(** Obtain the query part of a path. + @since 0.4 *) + +val parse_query : string -> ((string * string) list, string) result +(** Parse a query as a list of ['&'] or [';'] separated [key=value] pairs. + The order might not be preserved. + @since 0.3 +*) + +val show_sockaddr : Unix.sockaddr -> string +(** Simple printer for socket addresses. + @since NEXT_RELEASE *)