type buf = Tiny_httpd_buf.t type byte_stream = Tiny_httpd_stream.t module Buf = Tiny_httpd_buf module Byte_stream = Tiny_httpd_stream module IO = Tiny_httpd_io module Pool = Tiny_httpd_pool module Log = Tiny_httpd_log exception Bad_req of int * string let bad_reqf c fmt = Printf.ksprintf (fun s -> raise (Bad_req (c, s))) fmt module Response_code = struct type t = int let ok = 200 let not_found = 404 let descr = function | 100 -> "Continue" | 200 -> "OK" | 201 -> "Created" | 202 -> "Accepted" | 204 -> "No content" | 300 -> "Multiple choices" | 301 -> "Moved permanently" | 302 -> "Found" | 304 -> "Not Modified" | 400 -> "Bad request" | 401 -> "Unauthorized" | 403 -> "Forbidden" | 404 -> "Not found" | 405 -> "Method not allowed" | 408 -> "Request timeout" | 409 -> "Conflict" | 410 -> "Gone" | 411 -> "Length required" | 413 -> "Payload too large" | 417 -> "Expectation failed" | 500 -> "Internal server error" | 501 -> "Not implemented" | 503 -> "Service unavailable" | n -> "Unknown response code " ^ string_of_int n (* TODO *) let[@inline] is_success n = n >= 200 && n < 400 end type resp_error = Response_code.t * string type 'a resp_result = ('a, resp_error) result let unwrap_resp_result = function | Ok x -> x | Error (c, s) -> raise (Bad_req (c, s)) module Meth = struct type t = [ `GET | `PUT | `POST | `HEAD | `DELETE | `OPTIONS ] let to_string = function | `GET -> "GET" | `PUT -> "PUT" | `HEAD -> "HEAD" | `POST -> "POST" | `DELETE -> "DELETE" | `OPTIONS -> "OPTIONS" let pp out s = Format.pp_print_string out (to_string s) let of_string = function | "GET" -> `GET | "PUT" -> `PUT | "POST" -> `POST | "HEAD" -> `HEAD | "DELETE" -> `DELETE | "OPTIONS" -> `OPTIONS | s -> bad_reqf 400 "unknown method %S" s end module Headers = struct type t = (string * string) list let empty = [] let contains name headers = let name' = String.lowercase_ascii name in List.exists (fun (n, _) -> name' = n) headers let get_exn ?(f = fun x -> x) x h = let x' = String.lowercase_ascii x in List.assoc x' h |> f let get ?(f = fun x -> x) x h = try Some (get_exn ~f x h) with Not_found -> None let remove x h = let x' = String.lowercase_ascii x in List.filter (fun (k, _) -> k <> x') h let set x y h = let x' = String.lowercase_ascii x in (x', y) :: List.filter (fun (k, _) -> k <> x') h let pp out l = let pp_pair out (k, v) = Format.fprintf out "@[%s: %s@]" k v in Format.fprintf out "@[%a@]" (Format.pp_print_list pp_pair) l (* token = 1*tchar tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA ; any VCHAR, except delimiters Reference: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 *) let is_tchar = function | '0' .. '9' | 'a' .. 'z' | 'A' .. 'Z' | '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~' -> true | _ -> false let for_all pred s = try String.iter (fun c -> if not (pred c) then raise Exit) s; true with Exit -> false let parse_ ~buf (bs : byte_stream) : t = let rec loop acc = let line = Byte_stream.read_line ~buf bs in Log.debug (fun k -> k "parsed header line %S" line); if line = "\r" then acc else ( let k, v = try let i = String.index line ':' in let k = String.sub line 0 i in if not (for_all is_tchar k) then invalid_arg (Printf.sprintf "Invalid header key: %S" k); let v = String.sub line (i + 1) (String.length line - i - 1) |> String.trim in k, v with _ -> bad_reqf 400 "invalid header line: %S" line in loop ((String.lowercase_ascii k, v) :: acc) ) in loop [] end module Request = struct type 'body t = { meth: Meth.t; host: string; client_addr: Unix.sockaddr; headers: Headers.t; http_version: int * int; path: string; path_components: string list; query: (string * string) list; body: 'body; start_time: float; } let headers self = self.headers let host self = self.host let client_addr self = self.client_addr let meth self = self.meth let path self = self.path let body self = self.body let start_time self = self.start_time let query self = self.query let get_header ?f self h = Headers.get ?f h self.headers let get_header_int self h = match get_header self h with | Some x -> (try Some (int_of_string x) with _ -> None) | None -> None let set_header k v self = { self with headers = Headers.set k v self.headers } let update_headers f self = { self with headers = f self.headers } let set_body b self = { self with body = b } (** Should we close the connection after this request? *) let close_after_req (self : _ t) : bool = match self.http_version with | 1, 1 -> get_header self "connection" = Some "close" | 1, 0 -> not (get_header self "connection" = Some "keep-alive") | _ -> false let pp_comp_ out comp = Format.fprintf out "[%s]" (String.concat ";" @@ List.map (Printf.sprintf "%S") comp) let pp_query out q = Format.fprintf out "[%s]" (String.concat ";" @@ List.map (fun (a, b) -> Printf.sprintf "%S,%S" a b) q) let pp_ out self : unit = Format.fprintf out "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ path=%S;@ body=?;@ \ path_components=%a;@ query=%a@]}" (Meth.to_string self.meth) self.host Headers.pp self.headers self.path pp_comp_ self.path_components pp_query self.query let pp out self : unit = Format.fprintf out "{@[meth=%s;@ host=%s;@ headers=[@[%a@]];@ path=%S;@ body=%S;@ \ path_components=%a;@ query=%a@]}" (Meth.to_string self.meth) self.host Headers.pp self.headers self.path self.body pp_comp_ self.path_components pp_query self.query (* decode a "chunked" stream into a normal stream *) let read_stream_chunked_ ?buf (bs : byte_stream) : byte_stream = Log.debug (fun k -> k "body: start reading chunked stream..."); Byte_stream.read_chunked ?buf ~fail:(fun s -> Bad_req (400, s)) bs let limit_body_size_ ~max_size (bs : byte_stream) : byte_stream = Log.debug (fun k -> k "limit size of body to max-size=%d" max_size); Byte_stream.limit_size_to ~max_size ~close_rec:false bs ~too_big:(fun size -> (* read too much *) bad_reqf 413 "body size was supposed to be %d, but at least %d bytes received" max_size size) let limit_body_size ~max_size (req : byte_stream t) : byte_stream t = { req with body = limit_body_size_ ~max_size req.body } (* read exactly [size] bytes from the stream *) let read_exactly ~size (bs : byte_stream) : byte_stream = Log.debug (fun k -> k "body: must read exactly %d bytes" size); Byte_stream.read_exactly bs ~close_rec:false ~size ~too_short:(fun size -> bad_reqf 400 "body is too short by %d bytes" size) (* parse request, but not body (yet) *) let parse_req_start ~client_addr ~get_time_s ~buf (bs : byte_stream) : unit t option resp_result = try let line = Byte_stream.read_line ~buf bs in let start_time = get_time_s () in let meth, path, version = try let meth, path, version = Scanf.sscanf line "%s %s HTTP/1.%d\r" (fun x y z -> x, y, z) in if version != 0 && version != 1 then raise Exit; meth, path, version with _ -> Log.error (fun k -> k "invalid request line: `%s`" line); raise (Bad_req (400, "Invalid request line")) in let meth = Meth.of_string meth in Log.debug (fun k -> k "got meth: %s, path %S" (Meth.to_string meth) path); let headers = Headers.parse_ ~buf bs in let host = match Headers.get "Host" headers with | None -> bad_reqf 400 "No 'Host' header in request" | Some h -> h in let path_components, query = Tiny_httpd_util.split_query path in let path_components = Tiny_httpd_util.split_on_slash path_components in let query = match Tiny_httpd_util.(parse_query query) with | Ok l -> l | Error e -> bad_reqf 400 "invalid query: %s" e in let req = { meth; query; host; client_addr; path; path_components; headers; http_version = 1, version; body = (); start_time; } in Ok (Some req) with | End_of_file | Sys_error _ | Unix.Unix_error _ -> Ok None | Bad_req (c, s) -> Error (c, s) | e -> Error (400, Printexc.to_string e) (* parse body, given the headers. @param tr_stream a transformation of the input stream. *) let parse_body_ ~tr_stream ~buf (req : byte_stream t) : byte_stream t resp_result = try let size = match Headers.get_exn "Content-Length" req.headers |> int_of_string with | n -> n (* body of fixed size *) | exception Not_found -> 0 | exception _ -> bad_reqf 400 "invalid content-length" in let body = match get_header ~f:String.trim req "Transfer-Encoding" with | None -> read_exactly ~size @@ tr_stream req.body | Some "chunked" -> let bs = read_stream_chunked_ ~buf @@ tr_stream req.body (* body sent by chunks *) in if size > 0 then limit_body_size_ ~max_size:size bs else bs | Some s -> bad_reqf 500 "cannot handle transfer encoding: %s" s in Ok { req with body } with | End_of_file -> Error (400, "unexpected end of file") | Bad_req (c, s) -> Error (c, s) | e -> Error (400, Printexc.to_string e) let read_body_full ?buf ?buf_size (self : byte_stream t) : string t = try let buf = match buf with | Some b -> b | None -> Buf.create ?size:buf_size () in let body = Byte_stream.read_all ~buf self.body in { self with body } with | Bad_req _ as e -> raise e | e -> bad_reqf 500 "failed to read body: %s" (Printexc.to_string e) module Internal_ = struct let parse_req_start ?(buf = Buf.create ()) ~client_addr ~get_time_s bs = parse_req_start ~client_addr ~get_time_s ~buf bs |> unwrap_resp_result let parse_body ?(buf = Buf.create ()) req bs : _ t = parse_body_ ~tr_stream:(fun s -> s) ~buf { req with body = bs } |> unwrap_resp_result end end module Response = struct type body = [ `String of string | `Stream of byte_stream | `Writer of IO.Writer.t | `Void ] type t = { code: Response_code.t; headers: Headers.t; body: body } let set_body body self = { self with body } let set_headers headers self = { self with headers } let update_headers f self = { self with headers = f self.headers } let set_header k v self = { self with headers = Headers.set k v self.headers } let set_code code self = { self with code } let make_raw ?(headers = []) ~code body : t = (* add content length to response *) let headers = Headers.set "Content-Length" (string_of_int (String.length body)) headers in { code; headers; body = `String body } let make_raw_stream ?(headers = []) ~code body : t = let headers = Headers.set "Transfer-Encoding" "chunked" headers in { code; headers; body = `Stream body } let make_raw_writer ?(headers = []) ~code body : t = let headers = Headers.set "Transfer-Encoding" "chunked" headers in { code; headers; body = `Writer body } let make_void_force_ ?(headers = []) ~code () : t = { code; headers; body = `Void } let make_void ?(headers = []) ~code () : t = let is_ok = code < 200 || code = 204 || code = 304 in if is_ok then make_void_force_ ~headers ~code () else make_raw ~headers ~code "" (* invalid to not have a body *) let make_string ?headers ?(code = 200) r = match r with | Ok body -> make_raw ?headers ~code body | Error (code, msg) -> make_raw ?headers ~code msg let make_stream ?headers ?(code = 200) r = match r with | Ok body -> make_raw_stream ?headers ~code body | Error (code, msg) -> make_raw ?headers ~code msg let make_writer ?headers ?(code = 200) r : t = match r with | Ok body -> make_raw_writer ?headers ~code body | Error (code, msg) -> make_raw ?headers ~code msg let make ?headers ?(code = 200) r : t = match r with | Ok (`String body) -> make_raw ?headers ~code body | Ok (`Stream body) -> make_raw_stream ?headers ~code body | Ok `Void -> make_void ?headers ~code () | Ok (`Writer f) -> make_raw_writer ?headers ~code f | Error (code, msg) -> make_raw ?headers ~code msg let fail ?headers ~code fmt = Printf.ksprintf (fun msg -> make_raw ?headers ~code msg) fmt let fail_raise ~code fmt = Printf.ksprintf (fun msg -> raise (Bad_req (code, msg))) fmt let pp out self : unit = let pp_body out = function | `String s -> Format.fprintf out "%S" s | `Stream _ -> Format.pp_print_string out "" | `Writer _ -> Format.pp_print_string out "" | `Void -> () in Format.fprintf out "{@[code=%d;@ headers=[@[%a@]];@ body=%a@]}" self.code Headers.pp self.headers pp_body self.body let output_ ~buf (oc : IO.Output.t) (self : t) : unit = (* double indirection: - print into [buffer] using [bprintf] - transfer to [buf_] so we can output from there *) let tmp_buffer = Buffer.create 32 in Buf.clear buf; (* write start of reply *) Printf.bprintf tmp_buffer "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); Buf.add_buffer buf tmp_buffer; Buffer.clear tmp_buffer; let body, is_chunked = match self.body with | `String s when String.length s > 1024 * 500 -> (* chunk-encode large bodies *) `Writer (IO.Writer.of_string s), true | `String _ as b -> b, false | `Stream _ as b -> b, true | `Writer _ as b -> b, true | `Void as b -> b, false in let headers = if is_chunked then self.headers |> Headers.set "transfer-encoding" "chunked" |> Headers.remove "content-length" else self.headers in let self = { self with headers; body } in Log.debug (fun k -> k "output response: %s" (Format.asprintf "%a" pp { self with body = `String "<...>" })); (* write headers, using [buf] to batch writes *) List.iter (fun (k, v) -> Printf.bprintf tmp_buffer "%s: %s\r\n" k v; Buf.add_buffer buf tmp_buffer; Buffer.clear tmp_buffer) headers; IO.Output.output_buf oc buf; IO.Output.output_string oc "\r\n"; Buf.clear buf; (match body with | `String "" | `Void -> () | `String s -> IO.Output.output_string oc s | `Writer w -> (* use buffer to chunk encode [w] *) let oc' = IO.Output.chunk_encoding ~buf ~close_rec:false oc in (try IO.Writer.write oc' w; IO.Output.close oc' with e -> IO.Output.close oc'; raise e) | `Stream str -> (try Byte_stream.output_chunked' ~buf oc str; Byte_stream.close str with e -> Byte_stream.close str; raise e)); IO.Output.flush oc end (* semaphore, for limiting concurrency. *) module Sem_ = struct type t = { mutable n: int; max: int; mutex: Mutex.t; cond: Condition.t } let create n = if n <= 0 then invalid_arg "Semaphore.create"; { n; max = n; mutex = Mutex.create (); cond = Condition.create () } let acquire m t = Mutex.lock t.mutex; while t.n < m do Condition.wait t.cond t.mutex done; assert (t.n >= m); t.n <- t.n - m; Condition.broadcast t.cond; Mutex.unlock t.mutex let release m t = Mutex.lock t.mutex; t.n <- t.n + m; Condition.broadcast t.cond; Mutex.unlock t.mutex let num_acquired t = t.max - t.n end module Route = struct type path = string list (* split on '/' *) type (_, _) comp = | Exact : string -> ('a, 'a) comp | Int : (int -> 'a, 'a) comp | String : (string -> 'a, 'a) comp | String_urlencoded : (string -> 'a, 'a) comp type (_, _) t = | Fire : ('b, 'b) t | Rest : { url_encoded: bool } -> (string -> 'b, 'b) t | Compose : ('a, 'b) comp * ('b, 'c) t -> ('a, 'c) t let return = Fire let rest_of_path = Rest { url_encoded = false } let rest_of_path_urlencoded = Rest { url_encoded = true } let ( @/ ) a b = Compose (a, b) let string = String let string_urlencoded = String_urlencoded let int = Int let exact (s : string) = Exact s let exact_path (s : string) tail = let rec fn = function | [] -> tail | "" :: ls -> fn ls | s :: ls -> exact s @/ fn ls in fn (String.split_on_char '/' s) let rec eval : type a b. path -> (a, b) t -> a -> b option = fun path route f -> match path, route with | [], Fire -> Some f | _, Fire -> None | _, Rest { url_encoded } -> let whole_path = String.concat "/" path in (match if url_encoded then ( match Tiny_httpd_util.percent_decode whole_path with | Some s -> s | None -> raise_notrace Exit ) else whole_path with | whole_path -> Some (f whole_path) | exception Exit -> None) | c1 :: path', Compose (comp, route') -> (match comp with | Int -> (match int_of_string c1 with | i -> eval path' route' (f i) | exception _ -> None) | String -> eval path' route' (f c1) | String_urlencoded -> (match Tiny_httpd_util.percent_decode c1 with | None -> None | Some s -> eval path' route' (f s)) | Exact s -> if s = c1 then eval path' route' f else None) | [], Compose (String, Fire) -> Some (f "") (* trailing *) | [], Compose (String_urlencoded, Fire) -> Some (f "") (* trailing *) | [], Compose _ -> None let bpf = Printf.bprintf let rec pp_ : type a b. Buffer.t -> (a, b) t -> unit = fun out -> function | Fire -> bpf out "/" | Rest { url_encoded } -> bpf out "" (if url_encoded then "_urlencoded" else "") | Compose (Exact s, tl) -> bpf out "%s/%a" s pp_ tl | Compose (Int, tl) -> bpf out "/%a" pp_ tl | Compose (String, tl) -> bpf out "/%a" pp_ tl | Compose (String_urlencoded, tl) -> bpf out "/%a" pp_ tl let to_string x = let b = Buffer.create 16 in pp_ b x; Buffer.contents b let pp out x = Format.pp_print_string out (to_string x) end module Middleware = struct type handler = byte_stream Request.t -> resp:(Response.t -> unit) -> unit type t = handler -> handler (** Apply a list of middlewares to [h] *) let apply_l (l : t list) (h : handler) : handler = List.fold_right (fun m h -> m h) l h let[@inline] nil : t = fun h -> h end (* a request handler. handles a single request. *) type cb_path_handler = IO.Output.t -> Middleware.handler module type SERVER_SENT_GENERATOR = sig val set_headers : Headers.t -> unit val send_event : ?event:string -> ?id:string -> ?retry:string -> data:string -> unit -> unit val close : unit -> unit end type server_sent_generator = (module SERVER_SENT_GENERATOR) (** Handler that upgrades to another protocol *) module type UPGRADE_HANDLER = sig type handshake_state (** Some specific state returned after handshake *) val name : string (** Name in the "upgrade" header *) val handshake : unit Request.t -> (Headers.t * handshake_state, string) result (** Perform the handshake and upgrade the connection. The returned code is [101] alongside these headers. *) val handle_connection : Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit (** Take control of the connection and take it from there *) end type upgrade_handler = (module UPGRADE_HANDLER) exception Upgrade of unit Request.t * upgrade_handler module type IO_BACKEND = sig val init_addr : unit -> string val init_port : unit -> int val get_time_s : unit -> float (** obtain the current timestamp in seconds. *) val tcp_server : unit -> Tiny_httpd_io.TCP_server.builder (** Server that can listen on a port and handle clients. *) end type handler_result = | Handle of cb_path_handler | Fail of resp_error | Upgrade of upgrade_handler let unwrap_handler_result req = function | Handle x -> x | Fail (c, s) -> raise (Bad_req (c, s)) | Upgrade up -> raise (Upgrade (req, up)) type t = { backend: (module IO_BACKEND); mutable tcp_server: IO.TCP_server.t option; buf_size: int; mutable handler: byte_stream Request.t -> Response.t; (** toplevel handler, if any *) mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) mutable middlewares_sorted: (int * Middleware.t) list lazy_t; (** sorted version of {!middlewares} *) mutable path_handlers: (unit Request.t -> handler_result option) list; (** path handlers *) buf_pool: Buf.t Pool.t; } let get_addr_ sock = match Unix.getsockname sock with | Unix.ADDR_INET (addr, port) -> addr, port | _ -> invalid_arg "httpd: address is not INET" let addr (self : t) = match self.tcp_server with | None -> let (module B) = self.backend in B.init_addr () | Some s -> fst @@ s.endpoint () let port (self : t) = match self.tcp_server with | None -> let (module B) = self.backend in B.init_port () | Some s -> snd @@ s.endpoint () let active_connections (self : t) = match self.tcp_server with | None -> 0 | Some s -> s.active_connections () let add_middleware ~stage self m = let stage = match stage with | `Encoding -> 0 | `Stage n when n < 1 -> invalid_arg "add_middleware: bad stage" | `Stage n -> n in self.middlewares <- (stage, m) :: self.middlewares; self.middlewares_sorted <- lazy (List.stable_sort (fun (s1, _) (s2, _) -> compare s1 s2) self.middlewares) let add_decode_request_cb self f = (* turn it into a middleware *) let m h req ~resp = (* see if [f] modifies the stream *) let req0 = { req with Request.body = () } in match f req0 with | None -> h req ~resp (* pass through *) | Some (req1, tr_stream) -> let req = { req1 with Request.body = tr_stream req.Request.body } in h req ~resp in add_middleware self ~stage:`Encoding m let add_encode_response_cb self f = let m h req ~resp = h req ~resp:(fun r -> let req0 = { req with Request.body = () } in (* now transform [r] if we want to *) match f req0 r with | None -> resp r | Some r' -> resp r') in add_middleware self ~stage:`Encoding m let set_top_handler self f = self.handler <- f (* route the given handler. @param tr_req wraps the actual concrete function returned by the route and makes it into a handler. *) let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth ~tr_req self (route : _ Route.t) f = let ph req : handler_result option = match meth with | Some m when m <> req.Request.meth -> None (* ignore *) | _ -> (match Route.eval req.Request.path_components route f with | Some handler -> (* we have a handler, do we accept the request based on its headers? *) (match accept req with | Ok () -> Some (Handle (fun oc -> Middleware.apply_l middlewares @@ fun req ~resp -> tr_req oc req ~resp handler)) | Error err -> Some (Fail err)) | None -> None (* path didn't match *)) in self.path_handlers <- ph :: self.path_handlers let add_route_handler (type a) ?accept ?middlewares ?meth self (route : (a, _) Route.t) (f : _) : unit = let tr_req _oc req ~resp f = let req = Pool.with_resource self.buf_pool @@ fun buf -> Request.read_body_full ~buf req in resp (f req) in add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f let add_route_handler_stream ?accept ?middlewares ?meth self route f = let tr_req _oc req ~resp f = resp (f req) in add_route_handler_ ?accept ?middlewares ?meth self route ~tr_req f let[@inline] _opt_iter ~f o = match o with | None -> () | Some x -> f x exception Exit_SSE let add_route_server_sent_handler ?accept self route f = let tr_req (oc : IO.Output.t) req ~resp f = let req = Pool.with_resource self.buf_pool @@ fun buf -> Request.read_body_full ~buf req in let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in (* send response once *) let resp_sent = ref false in let send_response_idempotent_ () = if not !resp_sent then ( resp_sent := true; (* send 200 response now *) let initial_resp = Response.make_void_force_ ~headers:!headers ~code:200 () in resp initial_resp ) in let[@inline] writef fmt = Printf.ksprintf (IO.Output.output_string oc) fmt in let send_event ?event ?id ?retry ~data () : unit = send_response_idempotent_ (); _opt_iter event ~f:(fun e -> writef "event: %s\n" e); _opt_iter id ~f:(fun e -> writef "id: %s\n" e); _opt_iter retry ~f:(fun e -> writef "retry: %s\n" e); let l = String.split_on_char '\n' data in List.iter (fun s -> writef "data: %s\n" s) l; IO.Output.output_string oc "\n"; (* finish group *) IO.Output.flush oc in let module SSG = struct let set_headers h = if not !resp_sent then ( headers := List.rev_append h !headers; send_response_idempotent_ () ) let send_event = send_event let close () = raise Exit_SSE end in (try f req (module SSG : SERVER_SENT_GENERATOR) with Exit_SSE -> IO.Output.close oc); Log.info (fun k -> k "closed SSE connection") in add_route_handler_ self ?accept ~meth:`GET route ~tr_req f let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = []) (self : t) route f : unit = let ph req : handler_result option = if req.Request.meth <> `GET then None else ( match accept req with | Ok () -> (match Route.eval req.Request.path_components route f with | Some up -> Some (Upgrade up) | None -> None (* path didn't match *)) | Error err -> Some (Fail err) ) in self.path_handlers <- ph :: self.path_handlers let get_max_connection_ ?(max_connections = 64) () : int = let max_connections = max 4 max_connections in max_connections let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = let handler _req = Response.fail ~code:404 "no top handler" in let self = { backend; tcp_server = None; handler; buf_size; path_handlers = []; middlewares = []; middlewares_sorted = lazy []; buf_pool = Pool.create ~clear:Buf.clear_and_zero ~mk_item:(fun () -> Buf.create ~size:buf_size ()) (); } in List.iter (fun (stage, m) -> add_middleware self ~stage m) middlewares; self let is_ipv6_str addr : bool = String.contains addr ':' module Unix_tcp_server_ = struct type t = { addr: string; port: int; max_connections: int; sem_max_connections: Sem_.t; (** semaphore to restrict the number of active concurrent connections *) mutable sock: Unix.file_descr option; (** Socket *) new_thread: (unit -> unit) -> unit; timeout: float; masksigpipe: bool; mutable running: bool; (* TODO: use an atomic? *) } let to_tcp_server (self : t) : IO.TCP_server.builder = { IO.TCP_server.serve = (fun ~after_init ~handle () : unit -> if self.masksigpipe then ignore (Unix.sigprocmask Unix.SIG_BLOCK [ Sys.sigpipe ] : _ list); let sock, should_bind = match self.sock with | Some s -> ( s, false (* Because we're getting a socket from the caller (e.g. systemd) *) ) | None -> ( Unix.socket (if is_ipv6_str self.addr then Unix.PF_INET6 else Unix.PF_INET) Unix.SOCK_STREAM 0, true (* Because we're creating the socket ourselves *) ) in Unix.clear_nonblock sock; Unix.setsockopt_optint sock Unix.SO_LINGER None; if should_bind then ( let inet_addr = Unix.inet_addr_of_string self.addr in Unix.setsockopt sock Unix.SO_REUSEADDR true; Unix.bind sock (Unix.ADDR_INET (inet_addr, self.port)); let n_listen = 2 * self.max_connections in Unix.listen sock n_listen ); self.sock <- Some sock; let tcp_server = { IO.TCP_server.stop = (fun () -> self.running <- false); running = (fun () -> self.running); active_connections = (fun () -> Sem_.num_acquired self.sem_max_connections - 1); endpoint = (fun () -> let addr, port = get_addr_ sock in Unix.string_of_inet_addr addr, port); } in after_init tcp_server; (* how to handle a single client *) let handle_client_unix_ (client_sock : Unix.file_descr) (client_addr : Unix.sockaddr) : unit = Log.info (fun k -> k "serving new client on %s" (Tiny_httpd_util.show_sockaddr client_addr)); Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); let oc = IO.Output.of_out_channel @@ Unix.out_channel_of_descr client_sock in let ic = IO.Input.of_unix_fd client_sock in handle.handle ~client_addr ic oc; Log.info (fun k -> k "done with client on %s, exiting" @@ Tiny_httpd_util.show_sockaddr client_addr); (try Unix.shutdown client_sock Unix.SHUTDOWN_ALL; Unix.close client_sock with e -> Log.error (fun k -> k "error when closing sock for client %s: %s" (Tiny_httpd_util.show_sockaddr client_addr) (Printexc.to_string e))); () in Unix.set_nonblock sock; while self.running do ignore (Unix.select [ sock ] [] [ sock ] 1.0 : _ * _ * _); match Unix.accept sock with | client_sock, client_addr -> (* limit concurrency *) Sem_.acquire 1 self.sem_max_connections; (* Block INT/HUP while cloning to avoid children handling them. When thread gets them, our Unix.accept raises neatly. *) ignore Unix.(sigprocmask SIG_BLOCK Sys.[ sigint; sighup ]); self.new_thread (fun () -> try Unix.setsockopt client_sock Unix.TCP_NODELAY true; handle_client_unix_ client_sock client_addr; Sem_.release 1 self.sem_max_connections with e -> let bt = Printexc.get_raw_backtrace () in (try Unix.close client_sock with _ -> ()); Sem_.release 1 self.sem_max_connections; Log.error (fun k -> k "@[Handler: uncaught exception for client %s:@ \ %s@ %s@]" (Tiny_httpd_util.show_sockaddr client_addr) (Printexc.to_string e) (Printexc.raw_backtrace_to_string bt))); ignore Unix.(sigprocmask SIG_UNBLOCK Sys.[ sigint; sighup ]) | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) -> () | exception e -> Log.error (fun k -> k "Unix.accept or Thread.create raised an exception: %s" (Printexc.to_string e)) done; (* Wait for all threads to be done: this only works if all threads are done. *) Unix.close sock; Sem_.acquire self.sem_max_connections.max self.sem_max_connections; ()); } end let create ?(masksigpipe = true) ?max_connections ?(timeout = 0.0) ?buf_size ?(get_time_s = Unix.gettimeofday) ?(new_thread = fun f -> ignore (Thread.create f () : Thread.t)) ?(addr = "127.0.0.1") ?(port = 8080) ?sock ?middlewares () : t = let max_connections = get_max_connection_ ?max_connections () in let server = { Unix_tcp_server_.addr; new_thread; running = true; port; sock; max_connections; sem_max_connections = Sem_.create max_connections; masksigpipe; timeout; } in let tcp_server_builder = Unix_tcp_server_.to_tcp_server server in let module B = struct let init_addr () = addr let init_port () = port let get_time_s = get_time_s let tcp_server () = tcp_server_builder end in let backend = (module B : IO_BACKEND) in create_from ?buf_size ?middlewares ~backend () let stop (self : t) = match self.tcp_server with | None -> () | Some s -> s.stop () let running (self : t) = match self.tcp_server with | None -> false | Some s -> s.running () let find_map f l = let rec aux f = function | [] -> None | x :: l' -> (match f x with | Some _ as res -> res | None -> aux f l') in aux f l let string_as_list_contains_ (s : string) (sub : string) : bool = let fragments = String.split_on_char ',' s in List.exists (fun fragment -> String.trim fragment = sub) fragments (* handle client on [ic] and [oc] *) let client_handle_for (self : t) ~client_addr ic oc : unit = Pool.with_resource self.buf_pool @@ fun buf -> Pool.with_resource self.buf_pool @@ fun buf_res -> let is = Byte_stream.of_input ~buf_size:self.buf_size ic in let (module B) = self.backend in (* how to log the response to this query *) let log_response (req : _ Request.t) (resp : Response.t) = if not Log.dummy then ( let msgf k = let elapsed = B.get_time_s () -. req.start_time in k ("response to=%s code=%d time=%.3fs path=%S" : _ format4) (Tiny_httpd_util.show_sockaddr client_addr) resp.code elapsed req.path in if Response_code.is_success resp.code then Log.info msgf else Log.error msgf ) in (* handle generic exception *) let handle_exn e = let resp = Response.fail ~code:500 "server error: %s" (Printexc.to_string e) in if not Log.dummy then Log.error (fun k -> k "response to %s code=%d" (Tiny_httpd_util.show_sockaddr client_addr) resp.code); Response.output_ ~buf:buf_res oc resp in let handle_bad_req req e = let resp = Response.fail ~code:500 "server error: %s" (Printexc.to_string e) in log_response req resp; Response.output_ ~buf:buf_res oc resp in let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit = Log.debug (fun k -> k "upgrade connection"); try (* check headers *) (match Request.get_header req "connection" with | Some str when string_as_list_contains_ str "Upgrade" -> () | _ -> bad_reqf 426 "connection header must contain 'Upgrade'"); (match Request.get_header req "upgrade" with | Some u when u = UP.name -> () | Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u | None -> bad_reqf 426 "expected 'connection: upgrade' header"); (* ok, this is the upgrade we expected *) match UP.handshake req with | Error msg -> (* fail the upgrade *) Log.error (fun k -> k "upgrade failed: %s" msg); let resp = Response.make_raw ~code:429 "upgrade required" in log_response req resp; Response.output_ ~buf:buf_res oc resp | Ok (headers, handshake_st) -> (* send the upgrade reply *) let headers = [ "connection", "upgrade"; "upgrade", UP.name ] @ headers in let resp = Response.make_string ~code:101 ~headers (Ok "") in log_response req resp; Response.output_ ~buf:buf_res oc resp; (* now, give the whole connection over to the upgraded connection. Make sure to give the leftovers from [is] as well, if any. There might not be any because the first message doesn't normally come directly in the same packet as the handshake, but still. *) let ic = if is.len > 0 then ( Log.debug (fun k -> k "LEFTOVERS! %d B" is.len); IO.Input.append (IO.Input.of_slice is.bs is.off is.len) ic ) else ic in UP.handle_connection client_addr handshake_st ic oc with e -> handle_bad_req req e in let continue = ref true in let handle_one_req () = match Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is with | Ok None -> continue := false (* client is done *) | Error (c, s) -> (* connection error, close *) let res = Response.make_raw ~code:c s in (try Response.output_ ~buf:buf_res oc res with Sys_error _ -> ()); continue := false | Ok (Some req) -> Log.debug (fun k -> k "parsed request: %s" (Format.asprintf "@[%a@]" Request.pp_ req)); if Request.close_after_req req then continue := false; (try (* is there a handler for this path? *) let base_handler = match find_map (fun ph -> ph req) self.path_handlers with | Some f -> unwrap_handler_result req f | None -> fun _oc req ~resp -> resp (self.handler req) in (* handle expect/continue *) (match Request.get_header ~f:String.trim req "Expect" with | Some "100-continue" -> Log.debug (fun k -> k "send back: 100 CONTINUE"); Response.output_ ~buf:buf_res oc (Response.make_raw ~code:100 "") | Some s -> bad_reqf 417 "unknown expectation %s" s | None -> ()); (* apply middlewares *) let handler oc = List.fold_right (fun (_, m) h -> m h) (Lazy.force self.middlewares_sorted) (base_handler oc) in (* now actually read request's body into a stream *) let req = Request.parse_body_ ~tr_stream:(fun s -> s) ~buf { req with body = is } |> unwrap_resp_result in (* how to reply *) let resp r = try if Headers.get "connection" r.Response.headers = Some "close" then continue := false; log_response req r; Response.output_ ~buf:buf_res oc r with Sys_error _ -> continue := false in (* call handler *) try handler oc req ~resp with Sys_error _ -> continue := false with | Sys_error _ -> (* connection broken somehow *) Log.debug (fun k -> k "connection broken"); continue := false | Bad_req (code, s) -> continue := false; let resp = Response.make_raw ~code s in log_response req resp; Response.output_ ~buf:buf_res oc resp | Upgrade _ as e -> raise e | e -> handle_bad_req req e) in try while !continue && running self do Log.debug (fun k -> k "read next request"); handle_one_req () done with | Upgrade (req, up) -> (* upgrades take over the whole connection, we won't process any further request *) handle_upgrade req up | e -> handle_exn e let client_handler (self : t) : IO.TCP_server.conn_handler = { IO.TCP_server.handle = client_handle_for self } let is_ipv6 (self : t) = let (module B) = self.backend in is_ipv6_str (B.init_addr ()) let run_exn ?(after_init = ignore) (self : t) : unit = let (module B) = self.backend in let server = B.tcp_server () in server.serve ~after_init:(fun tcp_server -> self.tcp_server <- Some tcp_server; after_init ()) ~handle:(client_handler self) () let run ?after_init self : _ result = try Ok (run_exn ?after_init self) with e -> Error e