This commit is contained in:
Simon Cruanes 2024-03-28 15:56:46 +00:00 committed by GitHub
commit f5f4e3c2d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 434 additions and 238 deletions

View file

@ -22,6 +22,7 @@
base-threads base-threads
result result
hmap hmap
base-unix
(iostream (>= 0.2)) (iostream (>= 0.2))
(ocaml (>= 4.08)) (ocaml (>= 4.08))
(odoc :with-doc) (odoc :with-doc)

View file

@ -28,7 +28,7 @@ let handle_ws _client_addr ic oc =
let buf = Bytes.create 32 in let buf = Bytes.create 32 in
let continue = ref true in let continue = ref true in
while !continue do while !continue do
let n = IO.Input.input ic buf 0 (Bytes.length buf) in let n = IO.Input_with_timeout.input ic buf 0 (Bytes.length buf) in
Log.debug (fun k -> Log.debug (fun k ->
k "echo %d bytes from websocket: %S" n (Bytes.sub_string buf 0 n)); k "echo %d bytes from websocket: %S" n (Bytes.sub_string buf 0 n));

View file

@ -10,6 +10,7 @@ module Meth = Tiny_httpd_core.Meth
module Pool = Tiny_httpd_core.Pool module Pool = Tiny_httpd_core.Pool
module Log = Tiny_httpd_core.Log module Log = Tiny_httpd_core.Log
module Server = Tiny_httpd_core.Server module Server = Tiny_httpd_core.Server
module Time = Time
module Util = Tiny_httpd_core.Util module Util = Tiny_httpd_core.Util
include Server include Server
module Dir = Tiny_httpd_unix.Dir module Dir = Tiny_httpd_unix.Dir

View file

@ -85,6 +85,10 @@ module Buf = Buf
module IO = Tiny_httpd_core.IO module IO = Tiny_httpd_core.IO
(** {2 Time} *)
module Time = Time
(** {2 Logging *) (** {2 Logging *)
module Log = Tiny_httpd_core.Log module Log = Tiny_httpd_core.Log

View file

@ -1,8 +1,9 @@
module W = IO.Writer module W = IO.Writer
let decode_deflate_stream_ ~buf_size (ic : IO.Input.t) : IO.Input.t = let decode_deflate_stream_ ~buf_size (ic : #IO.Input_with_timeout.t) :
IO.Input_with_timeout.t =
Log.debug (fun k -> k "wrap stream with deflate.decode"); Log.debug (fun k -> k "wrap stream with deflate.decode");
Iostream_camlzip.decompress_in_buf ~buf_size ic Iostream_camlzip.decompress_in_buf_with_timeout ~now_s:Time.now_s ~buf_size ic
let encode_deflate_writer_ ~buf_size (w : W.t) : W.t = let encode_deflate_writer_ ~buf_size (w : W.t) : W.t =
Log.debug (fun k -> k "wrap writer with deflate.encode"); Log.debug (fun k -> k "wrap writer with deflate.encode");
@ -27,8 +28,8 @@ let has_deflate s =
try Scanf.sscanf s "deflate, %s" (fun _ -> true) with _ -> false try Scanf.sscanf s "deflate, %s" (fun _ -> true) with _ -> false
(* decompress [req]'s body if needed *) (* decompress [req]'s body if needed *)
let decompress_req_stream_ ~buf_size (req : IO.Input.t Request.t) : _ Request.t let decompress_req_stream_ ~buf_size (req : #IO.Input_with_timeout.t Request.t)
= : _ Request.t =
match Request.get_header ~f:String.trim req "Transfer-Encoding" with match Request.get_header ~f:String.trim req "Transfer-Encoding" with
(* TODO (* TODO
| Some "gzip" -> | Some "gzip" ->

View file

@ -11,6 +11,7 @@
open Common_ open Common_
module Buf = Buf module Buf = Buf
module Slice = Iostream.Slice module Slice = Iostream.Slice
module A = Atomic_
(** Output channel (byte sink) *) (** Output channel (byte sink) *)
module Output = struct module Output = struct
@ -44,13 +45,11 @@ module Output = struct
done done
method private close_underlying () = method private close_underlying () =
if not !closed then ( if not (A.exchange closed true) then
closed := true;
if close_noerr then ( if close_noerr then (
try Unix.close fd with _ -> () try Unix.close fd with _ -> ()
) else ) else
Unix.close fd Unix.close fd
)
end end
let output_buf (self : t) (buf : Buf.t) : unit = let output_buf (self : t) (buf : Buf.t) : unit =
@ -108,38 +107,28 @@ module Input = struct
let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t)
(fd : Unix.file_descr) : t = (fd : Unix.file_descr) : t =
let eof = ref false in let eof = ref false in
let input buf i len : int =
let n = ref 0 in
if not !eof then (
n := Unix.read fd buf i len;
if !n = 0 then eof := true
);
!n
in
object object
inherit Iostream.In_buf.t_from_refill ~bytes:buf.bytes () inherit Iostream.In_buf.t_from_refill ~bytes:buf.bytes ()
method private refill (slice : Slice.t) = method private refill (slice : Slice.t) =
if not !eof then ( if not !eof then (
slice.off <- 0; slice.off <- 0;
let continue = ref true in slice.len <- input slice.bytes 0 (Bytes.length slice.bytes);
while !continue do
match Unix.read fd slice.bytes 0 (Bytes.length slice.bytes) with
| n ->
slice.len <- n;
continue := false
| exception
Unix.Unix_error
( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN
| Unix.ECONNRESET | Unix.EPIPE ),
_,
_ ) ->
eof := true;
continue := false
| exception
Unix.Unix_error
((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) ->
ignore (Unix.select [ fd ] [] [] 1.)
done;
(* Printf.eprintf "read returned %d B\n%!" !n; *) (* Printf.eprintf "read returned %d B\n%!" !n; *)
if slice.len = 0 then eof := true if slice.len = 0 then eof := true
) )
method close () = method close () =
if not !closed then ( if not (A.exchange closed true) then (
closed := true;
eof := true; eof := true;
if close_noerr then ( if close_noerr then (
try Unix.close fd with _ -> () try Unix.close fd with _ -> ()
@ -148,6 +137,8 @@ module Input = struct
) )
end end
let[@inline] of_string s : t = (of_string s :> t)
let of_slice (slice : Slice.t) : t = let of_slice (slice : Slice.t) : t =
object object
inherit Iostream.In_buf.t_from_refill ~bytes:slice.bytes () inherit Iostream.In_buf.t_from_refill ~bytes:slice.bytes ()
@ -168,7 +159,7 @@ module Input = struct
(** Read exactly [len] bytes. (** Read exactly [len] bytes.
@raise End_of_file if the input did not contain enough data. *) @raise End_of_file if the input did not contain enough data. *)
let really_input (self : t) buf i len : unit = let really_input (self : #t) buf i len : unit =
let i = ref i in let i = ref i in
let len = ref len in let len = ref len in
while !len > 0 do while !len > 0 do
@ -178,31 +169,6 @@ module Input = struct
len := !len - n len := !len - n
done done
let append (i1 : #t) (i2 : #t) : t =
let use_i1 = ref true in
let rec input_rec (slice : Slice.t) =
if !use_i1 then (
slice.len <- input i1 slice.bytes 0 (Bytes.length slice.bytes);
if slice.len = 0 then (
use_i1 := false;
input_rec slice
)
) else
slice.len <- input i1 slice.bytes 0 (Bytes.length slice.bytes)
in
object
inherit Iostream.In_buf.t_from_refill ()
method private refill (slice : Slice.t) =
slice.off <- 0;
input_rec slice
method close () =
close i1;
close i2
end
let iter_slice (f : Slice.t -> unit) (self : #t) : unit = let iter_slice (f : Slice.t -> unit) (self : #t) : unit =
let continue = ref true in let continue = ref true in
while !continue do while !continue do
@ -231,11 +197,131 @@ module Input = struct
Iostream.Out.output oc slice.bytes slice.off slice.len) Iostream.Out.output oc slice.bytes slice.off slice.len)
self self
let read_all_using ~buf (self : #t) : string = (** Output a stream using chunked encoding *)
let output_chunked' ?buf (oc : #Iostream.Out_buf.t) (self : #t) : unit =
let oc' = Output.chunk_encoding ?buf oc ~close_rec:false in
match to_chan' oc' self with
| () -> Output.close oc'
| exception e ->
let bt = Printexc.get_raw_backtrace () in
Output.close oc';
Printexc.raise_with_backtrace e bt
(** print a stream as a series of chunks *)
let output_chunked ?buf (oc : out_channel) (self : #t) : unit =
output_chunked' ?buf (Output.of_out_channel oc) self
end
(** Input channel (byte source) with read-with-timeout *)
module Input_with_timeout = struct
include Iostream.In_buf
class type t = Iostream.In_buf.t_with_timeout
exception Timeout = Iostream.Timeout
(** Exception for timeouts *)
exception Timeout_partial_read of int
(** Exception for timeouts with a partial read *)
(** fill buffer, but stop at the deadline *)
let fill_buf_with_deadline (self : #t) ~(deadline : float) : Slice.t =
let timeout = deadline -. Time.now_s () in
if timeout <= 0. then raise Timeout;
fill_buf_with_timeout self timeout
(** fill buffer, but stop at the deadline if provided *)
let fill_buf_with_deadline_opt (self : #t) ~(deadline : float option) :
Slice.t =
match deadline with
| None -> fill_buf self
| Some d -> fill_buf_with_deadline self ~deadline:d
let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t)
(fd : Unix.file_descr) : t =
let eof = ref false in
let input_with_timeout t buf i len : int =
let deadline = Time.now_s () +. t in
let n = ref 0 in
while
(not (Atomic.get closed))
&& (not !eof)
&&
try
n := Unix.read fd buf i len;
false
with
| Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
(* sleep *)
true
| Unix.Unix_error ((Unix.ECONNRESET | Unix.ESHUTDOWN | Unix.EPIPE), _, _)
->
(* exit *)
false
do
let now = Time.now_s () in
if now >= deadline then raise Timeout;
ignore (Unix.select [ fd ] [] [] (deadline -. now) : _ * _ * _)
done;
!n
in
object
inherit Iostream.In_buf.t_with_timeout_from_refill ~bytes:buf.bytes ()
method private refill_with_timeout t (slice : Slice.t) =
if not !eof then (
slice.off <- 0;
slice.len <-
input_with_timeout t slice.bytes 0 (Bytes.length slice.bytes);
(* Printf.eprintf "read returned %d B\n%!" !n; *)
if slice.len = 0 then eof := true
)
method close () =
if not (A.exchange closed true) then (
eof := true;
if close_noerr then (
try Unix.close fd with _ -> ()
) else
Unix.close fd
)
end
let of_slice (slice : Slice.t) : t =
object
inherit Iostream.In_buf.t_with_timeout_from_refill ~bytes:slice.bytes ()
method private refill_with_timeout _t (slice : Slice.t) =
slice.off <- 0;
slice.len <- 0
method close () = ()
end
(** Read into the given slice.
@return the number of bytes read, [0] means end of input. *)
let[@inline] input (self : t) buf i len = self#input buf i len
(** Close the channel. *)
let[@inline] close self : unit = self#close ()
let iter_slice = Input.iter_slice
let iter = Input.iter
let to_chan = Input.to_chan
let to_chan' = Input.to_chan'
(** Read the whole body
@param deadline a deadline before which the operation must complete
@raise Timeout if deadline expires (leftovers are in [buf] *)
let read_all_using ~buf ~(deadline : float) (self : #t) : string =
Buf.clear buf; Buf.clear buf;
let continue = ref true in let continue = ref true in
while !continue do while !continue do
let slice = fill_buf self in let timeout = deadline -. Time.now_s () in
if timeout <= 0. then raise Timeout;
let slice = fill_buf_with_timeout self timeout in
if slice.len = 0 then if slice.len = 0 then
continue := false continue := false
else ( else (
@ -246,12 +332,17 @@ module Input = struct
done; done;
Buf.contents_and_clear buf Buf.contents_and_clear buf
(** Read [n] bytes from the input into [bytes]. *) (** Read [n] bytes from the input into [bytes].
let read_exactly_ ~too_short (self : #t) (bytes : bytes) (n : int) : unit = @raise Timeout_partial_read if timeout occurs before it's done *)
assert (Bytes.length bytes >= n); let read_exactly_ ?(off = 0) ~too_short ~(deadline : float) (self : #t)
let offset = ref 0 in (bytes : bytes) (n : int) : unit =
assert (Bytes.length bytes >= off + n);
let offset = ref off in
while !offset < n do while !offset < n do
let slice = self#fill_buf () in let slice =
try fill_buf_with_deadline self ~deadline
with Timeout -> raise (Timeout_partial_read (!offset - off))
in
let n_read = min slice.len (n - !offset) in let n_read = min slice.len (n - !offset) in
Bytes.blit slice.bytes slice.off bytes !offset n_read; Bytes.blit slice.bytes slice.off bytes !offset n_read;
offset := !offset + n_read; offset := !offset + n_read;
@ -259,12 +350,16 @@ module Input = struct
if n_read = 0 then too_short () if n_read = 0 then too_short ()
done done
let[@inline] really_input (self : #t) ~deadline buf i len =
read_exactly_ ~off:i ~deadline self buf len ~too_short:(fun () ->
raise End_of_file)
(** read a line into the buffer, after clearing it. *) (** read a line into the buffer, after clearing it. *)
let read_line_into (self : t) ~buf : unit = let read_line_into (self : #t) ~(deadline : float) ~buf : unit =
Buf.clear buf; Buf.clear buf;
let continue = ref true in let continue = ref true in
while !continue do while !continue do
let slice = self#fill_buf () in let slice = fill_buf_with_deadline self ~deadline in
if slice.len = 0 then ( if slice.len = 0 then (
continue := false; continue := false;
if Buf.size buf = 0 then raise End_of_file if Buf.size buf = 0 then raise End_of_file
@ -286,32 +381,32 @@ module Input = struct
) )
done done
let read_line_using ~buf (self : #t) : string = let read_line_using ~buf ~deadline (self : #t) : string =
read_line_into self ~buf; read_line_into self ~deadline ~buf;
Buf.contents_and_clear buf Buf.contents_and_clear buf
let read_line_using_opt ~buf (self : #t) : string option = let read_line_using_opt ~buf ~deadline (self : #t) : string option =
match read_line_into self ~buf with match read_line_into self ~buf ~deadline with
| () -> Some (Buf.contents_and_clear buf) | () -> Some (Buf.contents_and_clear buf)
| exception End_of_file -> None | exception End_of_file -> None
(* helper for making a new input stream that either contains at most [size] (* helper for making a new input stream that either contains at most [size]
bytes, or contains exactly [size] bytes. *) bytes, or contains exactly [size] bytes. *)
let reading_exactly_ ~skip_on_close ~close_rec ~size ~bytes (arg : t) : t = let reading_exactly_ ~skip_on_close ~close_rec ~size ~bytes (arg : #t) : t =
let remaining_size = ref size in let remaining_size = ref size in
object object
inherit t_from_refill ~bytes () inherit t_with_timeout_from_refill ~bytes ()
method close () = method close () =
if !remaining_size > 0 && skip_on_close then skip arg !remaining_size; if !remaining_size > 0 && skip_on_close then skip arg !remaining_size;
if close_rec then close arg if close_rec then close arg
method private refill (slice : Slice.t) = method private refill_with_timeout t (slice : Slice.t) =
slice.off <- 0; slice.off <- 0;
slice.len <- 0; slice.len <- 0;
if !remaining_size > 0 then ( if !remaining_size > 0 then (
let sub = fill_buf arg in let sub = fill_buf_with_timeout arg t in
let n = let n =
min !remaining_size (min sub.len (Bytes.length slice.bytes)) min !remaining_size (min sub.len (Bytes.length slice.bytes))
in in
@ -324,7 +419,7 @@ module Input = struct
(** new stream with maximum size [max_size]. (** new stream with maximum size [max_size].
@param close_rec if true, closing this will also close the input stream *) @param close_rec if true, closing this will also close the input stream *)
let limit_size_to ~close_rec ~max_size ~bytes (arg : t) : t = let limit_size_to ~close_rec ~max_size ~bytes (arg : #t) : t =
reading_exactly_ ~size:max_size ~skip_on_close:false ~bytes ~close_rec arg reading_exactly_ ~size:max_size ~skip_on_close:false ~bytes ~close_rec arg
(** New stream that consumes exactly [size] bytes from the input. (** New stream that consumes exactly [size] bytes from the input.
@ -339,15 +434,15 @@ module Input = struct
(* small buffer to read the chunk sizes *) (* small buffer to read the chunk sizes *)
let line_buf = Buf.create ~size:32 () in let line_buf = Buf.create ~size:32 () in
let read_next_chunk_len () : int = let read_next_chunk_len ~deadline () : int =
if !first then if !first then
first := false first := false
else ( else (
let line = read_line_using ~buf:line_buf ic in let line = read_line_using ~buf:line_buf ~deadline ic in
if String.trim line <> "" then if String.trim line <> "" then
raise (fail "expected crlf between chunks") raise (fail "expected crlf between chunks")
); );
let line = read_line_using ~buf:line_buf ic in let line = read_line_using ~buf:line_buf ~deadline ic in
(* parse chunk length, ignore extensions *) (* parse chunk length, ignore extensions *)
let chunk_size = let chunk_size =
if String.trim line = "" then if String.trim line = "" then
@ -367,11 +462,12 @@ module Input = struct
let chunk_size = ref 0 in let chunk_size = ref 0 in
object object
inherit t_from_refill ~bytes () inherit t_with_timeout_from_refill ~bytes ()
method private refill (slice : Slice.t) : unit = method private refill_with_timeout t (slice : Slice.t) : unit =
let deadline = Time.now_s () +. t in
if !chunk_size = 0 && not !eof then ( if !chunk_size = 0 && not !eof then (
chunk_size := read_next_chunk_len (); chunk_size := read_next_chunk_len ~deadline ();
if !chunk_size = 0 then eof := true (* stream is finished *) if !chunk_size = 0 then eof := true (* stream is finished *)
); );
slice.off <- 0; slice.off <- 0;
@ -379,7 +475,7 @@ module Input = struct
if !chunk_size > 0 then ( if !chunk_size > 0 then (
(* read the whole chunk, or [Bytes.length bytes] of it *) (* read the whole chunk, or [Bytes.length bytes] of it *)
let to_read = min !chunk_size (Bytes.length slice.bytes) in let to_read = min !chunk_size (Bytes.length slice.bytes) in
read_exactly_ read_exactly_ ~deadline
~too_short:(fun () -> raise (fail "chunk is too short")) ~too_short:(fun () -> raise (fail "chunk is too short"))
ic slice.bytes to_read; ic slice.bytes to_read;
slice.len <- to_read; slice.len <- to_read;
@ -389,19 +485,8 @@ module Input = struct
method close () = eof := true (* do not close underlying stream *) method close () = eof := true (* do not close underlying stream *)
end end
(** Output a stream using chunked encoding *) let output_chunked = Input.output_chunked
let output_chunked' ?buf (oc : #Iostream.Out_buf.t) (self : #t) : unit = let output_chunked' = Input.output_chunked'
let oc' = Output.chunk_encoding ?buf oc ~close_rec:false in
match to_chan' oc' self with
| () -> Output.close oc'
| exception e ->
let bt = Printexc.get_raw_backtrace () in
Output.close oc';
Printexc.raise_with_backtrace e bt
(** print a stream as a series of chunks *)
let output_chunked ?buf (oc : out_channel) (self : #t) : unit =
output_chunked' ?buf (Output.of_out_channel oc) self
end end
(** A writer abstraction. *) (** A writer abstraction. *)
@ -441,7 +526,8 @@ end
(** A TCP server abstraction. *) (** A TCP server abstraction. *)
module TCP_server = struct module TCP_server = struct
type conn_handler = { type conn_handler = {
handle: client_addr:Unix.sockaddr -> Input.t -> Output.t -> unit; handle:
client_addr:Unix.sockaddr -> Input_with_timeout.t -> Output.t -> unit;
(** Handle client connection *) (** Handle client connection *)
} }

View file

@ -4,6 +4,9 @@
(public_name tiny_httpd.core) (public_name tiny_httpd.core)
(private_modules parse_ common_) (private_modules parse_ common_)
(libraries threads seq hmap iostream (libraries threads seq hmap iostream
(select time.ml from
(mtime mtime.clock.os -> time.mtime.ml)
(unix -> time.default.ml))
(select log.ml from (select log.ml from
(logs -> log.logs.ml) (logs -> log.logs.ml)
(-> log.default.ml)))) (-> log.default.ml))))

View file

@ -46,9 +46,9 @@ let for_all pred s =
true true
with Exit -> false with Exit -> false
let parse_ ~(buf : Buf.t) (bs : IO.Input.t) : t = let parse_ ~(buf : Buf.t) ~deadline (bs : #IO.Input_with_timeout.t) : t =
let rec loop acc = let rec loop acc =
match IO.Input.read_line_using_opt ~buf bs with match IO.Input_with_timeout.read_line_using_opt ~buf ~deadline bs with
| None -> raise End_of_file | None -> raise End_of_file
| Some "\r" -> acc | Some "\r" -> acc
| Some line -> | Some line ->

View file

@ -32,4 +32,4 @@ val contains : string -> t -> bool
val pp : Format.formatter -> t -> unit val pp : Format.formatter -> t -> unit
(** Pretty print the headers. *) (** Pretty print the headers. *)
val parse_ : buf:Buf.t -> IO.Input.t -> t val parse_ : buf:Buf.t -> deadline:float -> #IO.Input_with_timeout.t -> t

View file

@ -88,29 +88,33 @@ let pp out self : unit =
pp_with ~pp_body () out self pp_with ~pp_body () out self
(* decode a "chunked" stream into a normal stream *) (* decode a "chunked" stream into a normal stream *)
let read_stream_chunked_ ~bytes (bs : #IO.Input.t) : IO.Input.t = let read_stream_chunked_ ~bytes (bs : #IO.Input_with_timeout.t) :
IO.Input_with_timeout.t =
Log.debug (fun k -> k "body: start reading chunked stream..."); Log.debug (fun k -> k "body: start reading chunked stream...");
IO.Input.read_chunked ~bytes ~fail:(fun s -> Bad_req (400, s)) bs IO.Input_with_timeout.read_chunked ~bytes ~fail:(fun s -> Bad_req (400, s)) bs
let limit_body_size_ ~max_size ~bytes (bs : #IO.Input.t) : IO.Input.t = let limit_body_size_ ~max_size ~bytes (bs : #IO.Input_with_timeout.t) :
IO.Input_with_timeout.t =
Log.debug (fun k -> k "limit size of body to max-size=%d" max_size); Log.debug (fun k -> k "limit size of body to max-size=%d" max_size);
IO.Input.limit_size_to ~max_size ~close_rec:false ~bytes bs IO.Input_with_timeout.limit_size_to ~max_size ~close_rec:false ~bytes bs
let limit_body_size ~max_size ~bytes (req : IO.Input.t t) : IO.Input.t t = let limit_body_size ~max_size ~bytes (req : #IO.Input_with_timeout.t t) :
IO.Input_with_timeout.t t =
{ req with body = limit_body_size_ ~max_size ~bytes req.body } { req with body = limit_body_size_ ~max_size ~bytes req.body }
(** read exactly [size] bytes from the stream *) (** read exactly [size] bytes from the stream *)
let read_exactly ~size ~bytes (bs : #IO.Input.t) : IO.Input.t = let read_exactly ~size ~bytes (bs : #IO.Input_with_timeout.t) :
IO.Input_with_timeout.t =
Log.debug (fun k -> k "body: must read exactly %d bytes" size); Log.debug (fun k -> k "body: must read exactly %d bytes" size);
IO.Input.reading_exactly bs ~close_rec:false ~bytes ~size IO.Input_with_timeout.reading_exactly bs ~close_rec:false ~bytes ~size
(* parse request, but not body (yet) *) (* parse request, but not body (yet) *)
let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) : let parse_req_start ~client_addr ~(deadline : float) ~buf
unit t option resp_result = (bs : #IO.Input_with_timeout.t) : unit t option resp_result =
try try
let line = IO.Input.read_line_using ~buf bs in let line = IO.Input_with_timeout.read_line_using ~buf ~deadline bs in
Log.debug (fun k -> k "parse request line: %s" line); Log.debug (fun k -> k "parse request line: %s" line);
let start_time = get_time_s () in let start_time = Time.now_s () in
let meth, path, version = let meth, path, version =
try try
let off = ref 0 in let off = ref 0 in
@ -134,7 +138,7 @@ let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) :
in in
let meth = Meth.of_string meth in let meth = Meth.of_string meth in
Log.debug (fun k -> k "got meth: %s, path %S" (Meth.to_string meth) path); Log.debug (fun k -> k "got meth: %s, path %S" (Meth.to_string meth) path);
let headers = Headers.parse_ ~buf bs in let headers = Headers.parse_ ~buf ~deadline bs in
let host = let host =
match Headers.get "Host" headers with match Headers.get "Host" headers with
| None -> bad_reqf 400 "No 'Host' header in request" | None -> bad_reqf 400 "No 'Host' header in request"
@ -170,8 +174,8 @@ let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) :
(* parse body, given the headers. (* parse body, given the headers.
@param tr_stream a transformation of the input stream. *) @param tr_stream a transformation of the input stream. *)
let parse_body_ ~tr_stream ~bytes (req : IO.Input.t t) : let parse_body_ ~tr_stream ~bytes (req : #IO.Input_with_timeout.t t) :
IO.Input.t t resp_result = IO.Input_with_timeout.t t resp_result =
try try
let size, has_size = let size, has_size =
match Headers.get_exn "Content-Length" req.headers |> int_of_string with match Headers.get_exn "Content-Length" req.headers |> int_of_string with
@ -186,7 +190,7 @@ let parse_body_ ~tr_stream ~bytes (req : IO.Input.t t) :
bad_reqf 400 "specifying both transfer-encoding and content-length" bad_reqf 400 "specifying both transfer-encoding and content-length"
| Some "chunked" -> | Some "chunked" ->
(* body sent by chunks *) (* body sent by chunks *)
let bs : IO.Input.t = let bs : IO.Input_with_timeout.t =
read_stream_chunked_ ~bytes @@ tr_stream req.body read_stream_chunked_ ~bytes @@ tr_stream req.body
in in
if size > 0 then ( if size > 0 then (
@ -203,14 +207,15 @@ let parse_body_ ~tr_stream ~bytes (req : IO.Input.t t) :
| Bad_req (c, s) -> Error (c, s) | Bad_req (c, s) -> Error (c, s)
| e -> Error (400, Printexc.to_string e) | e -> Error (400, Printexc.to_string e)
let read_body_full ?bytes ?buf_size (self : IO.Input.t t) : string t = let read_body_full ?bytes ?buf_size ~deadline
(self : #IO.Input_with_timeout.t t) : string t =
try try
let buf = let buf =
match bytes with match bytes with
| Some b -> Buf.of_bytes b | Some b -> Buf.of_bytes b
| None -> Buf.create ?size:buf_size () | None -> Buf.create ?size:buf_size ()
in in
let body = IO.Input.read_all_using ~buf self.body in let body = IO.Input_with_timeout.read_all_using ~buf ~deadline self.body in
{ self with body } { self with body }
with with
| Bad_req _ as e -> raise e | Bad_req _ as e -> raise e
@ -220,11 +225,13 @@ module Private_ = struct
let close_after_req = close_after_req let close_after_req = close_after_req
let parse_req_start = parse_req_start let parse_req_start = parse_req_start
let parse_req_start_exn ?(buf = Buf.create ()) ~client_addr ~get_time_s bs = let parse_req_start_exn ?(buf = Buf.create ()) ~client_addr ~deadline bs =
parse_req_start ~client_addr ~get_time_s ~buf bs |> unwrap_resp_result parse_req_start ~client_addr ~deadline ~buf bs |> unwrap_resp_result
let parse_body ?(bytes = Bytes.create 4096) req bs : _ t = let parse_body ?(bytes = Bytes.create 4096) req bs : _ t =
parse_body_ ~tr_stream:(fun s -> s) ~bytes { req with body = bs } parse_body_
~tr_stream:(fun s -> (s :> IO.Input_with_timeout.t))
~bytes { req with body = bs }
|> unwrap_resp_result |> unwrap_resp_result
let[@inline] set_body body self = { self with body } let[@inline] set_body body self = { self with body }

View file

@ -129,17 +129,26 @@ val start_time : _ t -> float
@since 0.11 *) @since 0.11 *)
val limit_body_size : val limit_body_size :
max_size:int -> bytes:bytes -> IO.Input.t t -> IO.Input.t t max_size:int ->
bytes:bytes ->
#IO.Input_with_timeout.t t ->
IO.Input_with_timeout.t t
(** Limit the body size to [max_size] bytes, or return (** Limit the body size to [max_size] bytes, or return
a [413] error. a [413] error.
@since 0.3 @since 0.3
*) *)
val read_body_full : ?bytes:bytes -> ?buf_size:int -> IO.Input.t t -> string t val read_body_full :
?bytes:bytes ->
?buf_size:int ->
deadline:float ->
#IO.Input_with_timeout.t t ->
string t
(** Read the whole body into a string. Potentially blocking. (** Read the whole body into a string. Potentially blocking.
@param buf_size initial size of underlying buffer (since 0.11) @param buf_size initial size of underlying buffer (since 0.11)
@param bytes the initial buffer (since 0.14) @param bytes the initial buffer (since 0.14)
@param deadline time after which this should fail with [Timeout] (since NEXT_RELEASE)
*) *)
(**/**) (**/**)
@ -148,20 +157,26 @@ val read_body_full : ?bytes:bytes -> ?buf_size:int -> IO.Input.t t -> string t
module Private_ : sig module Private_ : sig
val parse_req_start : val parse_req_start :
client_addr:Unix.sockaddr -> client_addr:Unix.sockaddr ->
get_time_s:(unit -> float) -> deadline:float ->
buf:Buf.t -> buf:Buf.t ->
IO.Input.t -> IO.Input_with_timeout.t ->
unit t option resp_result unit t option resp_result
val parse_req_start_exn : val parse_req_start_exn :
?buf:Buf.t -> ?buf:Buf.t ->
client_addr:Unix.sockaddr -> client_addr:Unix.sockaddr ->
get_time_s:(unit -> float) -> deadline:float ->
IO.Input.t -> #IO.Input_with_timeout.t ->
unit t option unit t option
val close_after_req : _ t -> bool val close_after_req : _ t -> bool
val parse_body : ?bytes:bytes -> unit t -> IO.Input.t -> IO.Input.t t
val parse_body :
?bytes:bytes ->
unit t ->
#IO.Input_with_timeout.t ->
IO.Input_with_timeout.t t
val set_body : 'a -> _ t -> 'a t val set_body : 'a -> _ t -> 'a t
end end

View file

@ -3,7 +3,9 @@ open Common_
type resp_error = Response_code.t * string type resp_error = Response_code.t * string
module Middleware = struct module Middleware = struct
type handler = IO.Input.t Request.t -> resp:(Response.t -> unit) -> unit type handler =
IO.Input_with_timeout.t Request.t -> resp:(Response.t -> unit) -> unit
type t = handler -> handler type t = handler -> handler
(** Apply a list of middlewares to [h] *) (** Apply a list of middlewares to [h] *)
@ -40,7 +42,11 @@ module type UPGRADE_HANDLER = sig
code is [101] alongside these headers. *) code is [101] alongside these headers. *)
val handle_connection : val handle_connection :
Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit Unix.sockaddr ->
handshake_state ->
IO.Input_with_timeout.t ->
IO.Output.t ->
unit
(** Take control of the connection and take it from there *) (** Take control of the connection and take it from there *)
end end
@ -52,9 +58,6 @@ module type IO_BACKEND = sig
val init_addr : unit -> string val init_addr : unit -> string
val init_port : unit -> int val init_port : unit -> int
val get_time_s : unit -> float
(** obtain the current timestamp in seconds. *)
val tcp_server : unit -> IO.TCP_server.builder val tcp_server : unit -> IO.TCP_server.builder
(** Server that can listen on a port and handle clients. *) (** Server that can listen on a port and handle clients. *)
end end
@ -72,13 +75,14 @@ let unwrap_handler_result req = function
type t = { type t = {
backend: (module IO_BACKEND); backend: (module IO_BACKEND);
mutable tcp_server: IO.TCP_server.t option; mutable tcp_server: IO.TCP_server.t option;
mutable handler: IO.Input.t Request.t -> Response.t; mutable handler: IO.Input_with_timeout.t Request.t -> Response.t;
(** toplevel handler, if any *) (** toplevel handler, if any *)
mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) mutable middlewares: (int * Middleware.t) list; (** Global middlewares *)
mutable middlewares_sorted: (int * Middleware.t) list lazy_t; mutable middlewares_sorted: (int * Middleware.t) list lazy_t;
(** sorted version of {!middlewares} *) (** sorted version of {!middlewares} *)
mutable path_handlers: (unit Request.t -> handler_result option) list; mutable path_handlers: (unit Request.t -> handler_result option) list;
(** path handlers *) (** path handlers *)
request_timeout_s: float; (** Timeout for parsing requests *)
bytes_pool: bytes Pool.t; bytes_pool: bytes Pool.t;
} }
@ -169,7 +173,8 @@ let add_route_handler (type a) ?accept ?middlewares ?meth self
let tr_req _oc req ~resp f = let tr_req _oc req ~resp f =
let req = let req =
Pool.with_resource self.bytes_pool @@ fun bytes -> Pool.with_resource self.bytes_pool @@ fun bytes ->
Request.read_body_full ~bytes req let deadline = Time.now_s () +. self.request_timeout_s in
Request.read_body_full ~bytes ~deadline req
in in
resp (f req) resp (f req)
in in
@ -190,7 +195,8 @@ let add_route_server_sent_handler ?accept self route f =
let tr_req (oc : IO.Output.t) req ~resp f = let tr_req (oc : IO.Output.t) req ~resp f =
let req = let req =
Pool.with_resource self.bytes_pool @@ fun bytes -> Pool.with_resource self.bytes_pool @@ fun bytes ->
Request.read_body_full ~bytes req let deadline = Time.now_s () +. self.request_timeout_s in
Request.read_body_full ~bytes ~deadline req
in in
let headers = let headers =
ref Headers.(empty |> set "content-type" "text/event-stream") ref Headers.(empty |> set "content-type" "text/event-stream")
@ -257,7 +263,11 @@ let add_upgrade_handler ?(accept = fun _ -> Ok ()) (self : t) route f : unit =
let clear_bytes_ bs = Bytes.fill bs 0 (Bytes.length bs) '\x00' let clear_bytes_ bs = Bytes.fill bs 0 (Bytes.length bs) '\x00'
let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = (* client has at most 10s to send the request, unless it's a streaming request *)
let default_req_timeout_s_ = 30.
let create_from ?(buf_size = 16 * 1_024) ?(middlewares = [])
?(request_timeout_s = default_req_timeout_s_) ~backend () : t =
let handler _req = Response.fail ~code:404 "no top handler" in let handler _req = Response.fail ~code:404 "no top handler" in
let self = let self =
{ {
@ -267,6 +277,7 @@ let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t =
path_handlers = []; path_handlers = [];
middlewares = []; middlewares = [];
middlewares_sorted = lazy []; middlewares_sorted = lazy [];
request_timeout_s;
bytes_pool = bytes_pool =
Pool.create ~clear:clear_bytes_ Pool.create ~clear:clear_bytes_
~mk_item:(fun () -> Bytes.create buf_size) ~mk_item:(fun () -> Bytes.create buf_size)
@ -304,13 +315,11 @@ let string_as_list_contains_ (s : string) (sub : string) : bool =
let client_handle_for (self : t) ~client_addr ic oc : unit = let client_handle_for (self : t) ~client_addr ic oc : unit =
Pool.with_resource self.bytes_pool @@ fun bytes_req -> Pool.with_resource self.bytes_pool @@ fun bytes_req ->
Pool.with_resource self.bytes_pool @@ fun bytes_res -> Pool.with_resource self.bytes_pool @@ fun bytes_res ->
let (module B) = self.backend in
(* how to log the response to this query *) (* how to log the response to this query *)
let log_response (req : _ Request.t) (resp : Response.t) = let log_response (req : _ Request.t) (resp : Response.t) =
if not Log.dummy then ( if not Log.dummy then (
let msgf k = let msgf k =
let elapsed = B.get_time_s () -. req.start_time in let elapsed = Time.now_s () -. req.start_time in
k k
("response to=%s code=%d time=%.3fs meth=%s path=%S" : _ format4) ("response to=%s code=%d time=%.3fs meth=%s path=%S" : _ format4)
(Util.show_sockaddr client_addr) (Util.show_sockaddr client_addr)
@ -387,10 +396,10 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
let continue = ref true in let continue = ref true in
let handle_one_req () = let handle_one_req () =
let deadline = Time.now_s () +. self.request_timeout_s in
match match
let buf = Buf.of_bytes bytes_req in let buf = Buf.of_bytes bytes_req in
Request.Private_.parse_req_start ~client_addr ~get_time_s:B.get_time_s Request.Private_.parse_req_start ~client_addr ~deadline ~buf ic
~buf ic
with with
| Ok None -> continue := false (* client is done *) | Ok None -> continue := false (* client is done *)
| Error (c, s) -> | Error (c, s) ->

View file

@ -17,7 +17,8 @@
*) *)
module Middleware : sig module Middleware : sig
type handler = IO.Input.t Request.t -> resp:(Response.t -> unit) -> unit type handler =
IO.Input_with_timeout.t Request.t -> resp:(Response.t -> unit) -> unit
(** Handlers are functions returning a response to a request. (** Handlers are functions returning a response to a request.
The response can be delayed, hence the use of a continuation The response can be delayed, hence the use of a continuation
as the [resp] parameter. *) as the [resp] parameter. *)
@ -52,9 +53,6 @@ module type IO_BACKEND = sig
val init_port : unit -> int val init_port : unit -> int
(** Initial port *) (** Initial port *)
val get_time_s : unit -> float
(** Obtain the current timestamp in seconds. *)
val tcp_server : unit -> IO.TCP_server.builder val tcp_server : unit -> IO.TCP_server.builder
(** TCP server builder, to create servers that can listen (** TCP server builder, to create servers that can listen
on a port and handle clients. *) on a port and handle clients. *)
@ -63,6 +61,7 @@ end
val create_from : val create_from :
?buf_size:int -> ?buf_size:int ->
?middlewares:([ `Encoding | `Stage of int ] * Middleware.t) list -> ?middlewares:([ `Encoding | `Stage of int ] * Middleware.t) list ->
?request_timeout_s:float ->
backend:(module IO_BACKEND) -> backend:(module IO_BACKEND) ->
unit -> unit ->
t t
@ -74,6 +73,7 @@ val create_from :
@param buf_size size for buffers (since 0.11) @param buf_size size for buffers (since 0.11)
@param middlewares see {!add_middleware} for more details. @param middlewares see {!add_middleware} for more details.
@param request_timeout_s default timeout for requests (headers+body) (since NEXT_RELEASE)
@since 0.14 @since 0.14
*) *)
@ -95,7 +95,8 @@ val active_connections : t -> int
val add_decode_request_cb : val add_decode_request_cb :
t -> t ->
(unit Request.t -> (unit Request.t * (IO.Input.t -> IO.Input.t)) option) -> (unit Request.t ->
(unit Request.t * (IO.Input_with_timeout.t -> IO.Input_with_timeout.t)) option) ->
unit unit
[@@deprecated "use add_middleware"] [@@deprecated "use add_middleware"]
(** Add a callback for every request. (** Add a callback for every request.
@ -130,7 +131,8 @@ val add_middleware :
(** {2 Request handlers} *) (** {2 Request handlers} *)
val set_top_handler : t -> (IO.Input.t Request.t -> Response.t) -> unit val set_top_handler :
t -> (IO.Input_with_timeout.t Request.t -> Response.t) -> unit
(** Setup a handler called by default. (** Setup a handler called by default.
This handler is called with any request not accepted by any handler This handler is called with any request not accepted by any handler
@ -174,7 +176,7 @@ val add_route_handler_stream :
?middlewares:Middleware.t list -> ?middlewares:Middleware.t list ->
?meth:Meth.t -> ?meth:Meth.t ->
t -> t ->
('a, IO.Input.t Request.t -> Response.t) Route.t -> ('a, IO.Input_with_timeout.t Request.t -> Response.t) Route.t ->
'a -> 'a ->
unit unit
(** Similar to {!add_route_handler}, but where the body of the request (** Similar to {!add_route_handler}, but where the body of the request
@ -257,7 +259,11 @@ module type UPGRADE_HANDLER = sig
The connection is closed without further ado. *) The connection is closed without further ado. *)
val handle_connection : val handle_connection :
Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit Unix.sockaddr ->
handshake_state ->
IO.Input_with_timeout.t ->
IO.Output.t ->
unit
(** Take control of the connection and take it from ther.e *) (** Take control of the connection and take it from ther.e *)
end end

View file

@ -1,3 +1,5 @@
let now_s = Unix.gettimeofday
let[@inline] now_us () = let[@inline] now_us () =
let t = Unix.gettimeofday () in let t = Unix.gettimeofday () in
t *. 1e6 |> ceil t *. 1e6 |> ceil

10
src/core/time.mli Normal file
View file

@ -0,0 +1,10 @@
(** Basic time measurement.
This provides a basic clock, monotonic if [mtime] is installed,
or based on [Unix.gettimeofday] otherwise *)
val now_us : unit -> float
(** Current time in microseconds. The precision should be at least below the millisecond. *)
val now_s : unit -> float
(** Current time in seconds. The precision should be at least below the millisecond. *)

7
src/core/time.mtime.ml Normal file
View file

@ -0,0 +1,7 @@
let[@inline] now_s () =
let t = Mtime_clock.now_ns () in
Int64.(div t 1_000_000_000L |> to_float)
let[@inline] now_us () =
let t = Mtime_clock.now_ns () in
Int64.(div t 1000L |> to_float)

View file

@ -1,3 +1,4 @@
module A = Tiny_httpd_core.Atomic_ module A = Tiny_httpd_core.Atomic_
module Time = Tiny_httpd_core.Time
let spf = Printf.sprintf let spf = Printf.sprintf

View file

@ -4,10 +4,7 @@
(name tiny_httpd_prometheus) (name tiny_httpd_prometheus)
(public_name tiny_httpd.prometheus) (public_name tiny_httpd.prometheus)
(synopsis "Metrics using prometheus") (synopsis "Metrics using prometheus")
(private_modules common_p_ time_) (private_modules common_p_)
(flags :standard -open Tiny_httpd_core) (flags :standard -open Tiny_httpd_core)
(libraries (libraries
tiny_httpd.core unix tiny_httpd.core unix))
(select time_.ml from
(mtime mtime.clock.os -> time_.mtime.ml)
(-> time_.default.ml))))

View file

@ -1 +0,0 @@
val now_us : unit -> float

View file

@ -1,3 +0,0 @@
let[@inline] now_us () =
let t = Mtime_clock.now_ns () in
Int64.(div t 1000L |> to_float)

View file

@ -189,12 +189,12 @@ let http_middleware (reg : Registry.t) : Server.Middleware.t =
fun h : Server.Middleware.handler -> fun h : Server.Middleware.handler ->
fun req ~resp : unit -> fun req ~resp : unit ->
let start = Time_.now_us () in let start = Time.now_us () in
Counter.incr c_req; Counter.incr c_req;
h req ~resp:(fun (response : Response.t) -> h req ~resp:(fun (response : Response.t) ->
let code = response.code in let code = response.code in
let elapsed_us = Time_.now_us () -. start in let elapsed_us = Time.now_us () -. start in
let elapsed_s = elapsed_us /. 1e6 in let elapsed_s = elapsed_us /. 1e6 in
Histogram.add h_latency elapsed_s; Histogram.add h_latency elapsed_s;

View file

@ -93,12 +93,12 @@ let vfs_of_dir (top : string) : vfs =
let contains f = Sys.file_exists (top // f) let contains f = Sys.file_exists (top // f)
let list_dir f = Sys.readdir (top // f) let list_dir f = Sys.readdir (top // f)
let read_file_content f = let read_file_content f : IO.Input.t =
let fpath = top // f in let fpath = top // f in
match Unix.stat fpath with match Unix.stat fpath with
| { st_kind = Unix.S_REG; _ } -> | { st_kind = Unix.S_REG; _ } ->
let ic = Unix.(openfile fpath [ O_RDONLY ] 0) in let ic = Unix.(openfile fpath [ O_RDONLY ] 0) in
let closed = ref false in let closed = Atomic_.make false in
let buf = IO.Slice.create 4096 in let buf = IO.Slice.create 4096 in
IO.Input.of_unix_fd ~buf ~close_noerr:true ~closed ic IO.Input.of_unix_fd ~buf ~close_noerr:true ~closed ic
| _ -> failwith (Printf.sprintf "not a regular file: %S" f) | _ -> failwith (Printf.sprintf "not a regular file: %S" f)

View file

@ -92,15 +92,15 @@ module Unix_tcp_server_ = struct
Pool.with_resource self.slice_pool @@ fun ic_buf -> Pool.with_resource self.slice_pool @@ fun ic_buf ->
Pool.with_resource self.slice_pool @@ fun oc_buf -> Pool.with_resource self.slice_pool @@ fun oc_buf ->
let closed = ref false in let closed = Atomic_.make false in
let oc = let oc =
new IO.Output.of_unix_fd new IO.Output.of_unix_fd
~close_noerr:true ~closed ~buf:oc_buf client_sock ~close_noerr:true ~closed ~buf:oc_buf client_sock
in in
let ic = let ic =
IO.Input.of_unix_fd ~close_noerr:true ~closed ~buf:ic_buf IO.Input_with_timeout.of_unix_fd ~close_noerr:true ~closed
client_sock ~buf:ic_buf client_sock
in in
handle.handle ~client_addr ic oc handle.handle ~client_addr ic oc
in in

View file

@ -1,6 +1,6 @@
open Common_ws_ open Common_ws_
type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit type handler = Unix.sockaddr -> IO.Input_with_timeout.t -> IO.Output.t -> unit
module Frame_type = struct module Frame_type = struct
type t = int type t = int
@ -196,7 +196,7 @@ module Reader = struct
| Close | Close
type t = { type t = {
ic: IO.Input.t; ic: IO.Input_with_timeout.t;
writer: Writer.t; (** Writer, to send "pong" *) writer: Writer.t; (** Writer, to send "pong" *)
header_buf: bytes; (** small buffer to read frame headers *) header_buf: bytes; (** small buffer to read frame headers *)
small_buf: bytes; (** Used for control frames *) small_buf: bytes; (** Used for control frames *)
@ -220,52 +220,65 @@ module Reader = struct
let max_fragment_size = 1 lsl 30 let max_fragment_size = 1 lsl 30
(** Read next frame header into [self.header] *) (** Read next frame header into [self.header] *)
let read_frame_header (self : t) : unit = let read_frame_header (self : t) ~deadline : unit =
(* read header *) try
IO.Input.really_input self.ic self.header_buf 0 2; (* read header *)
IO.Input_with_timeout.really_input self.ic ~deadline self.header_buf 0 2;
let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in
let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in
self.header.fin <- b0 land 1 == 1; self.header.fin <- b0 land 1 == 1;
let ext = (b0 lsr 4) land 0b0111 in let ext = (b0 lsr 4) land 0b0111 in
if ext <> 0 then ( if ext <> 0 then (
Log.error (fun k -> k "websocket: unknown extension %d, closing" ext); Log.error (fun k -> k "websocket: unknown extension %d, closing" ext);
raise Close_connection
);
self.header.ty <- b0 land 0b0000_1111;
self.header.mask <- b1 land 0b1000_0000 != 0;
let payload_len : int =
let len = b1 land 0b0111_1111 in
if len = 126 then (
IO.Input_with_timeout.really_input self.ic ~deadline self.header_buf 0
2;
Bytes.get_int16_be self.header_buf 0
) else if len = 127 then (
IO.Input_with_timeout.really_input self.ic ~deadline self.header_buf 0
8;
let len64 = Bytes.get_int64_be self.header_buf 0 in
if compare len64 (Int64.of_int max_fragment_size) > 0 then (
Log.error (fun k ->
k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64
max_fragment_size);
raise Close_connection
);
Int64.to_int len64
) else
len
in
self.header.payload_len <- payload_len;
if self.header.mask then
IO.Input_with_timeout.really_input self.ic ~deadline
self.header.mask_key 0 4;
(*Log.debug (fun k ->
k "websocket: read frame header type=%s payload_len=%d mask=%b"
(Frame_type.show self.header.ty)
self.header.payload_len self.header.mask);*)
()
with
| IO.Input_with_timeout.Timeout_partial_read _
| IO.Input_with_timeout.Timeout
->
(* NOTE: this is not optimal but it's the easiest solution, for now,
to the problem of a partially read frame header with
a timeout in the middle (we would have to save *)
Log.error (fun k -> k "websocket: timeout while reading frame header");
raise Close_connection raise Close_connection
);
self.header.ty <- b0 land 0b0000_1111;
self.header.mask <- b1 land 0b1000_0000 != 0;
let payload_len : int =
let len = b1 land 0b0111_1111 in
if len = 126 then (
IO.Input.really_input self.ic self.header_buf 0 2;
Bytes.get_int16_be self.header_buf 0
) else if len = 127 then (
IO.Input.really_input self.ic self.header_buf 0 8;
let len64 = Bytes.get_int64_be self.header_buf 0 in
if compare len64 (Int64.of_int max_fragment_size) > 0 then (
Log.error (fun k ->
k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64
max_fragment_size);
raise Close_connection
);
Int64.to_int len64
) else
len
in
self.header.payload_len <- payload_len;
if self.header.mask then
IO.Input.really_input self.ic self.header.mask_key 0 4;
(*Log.debug (fun k ->
k "websocket: read frame header type=%s payload_len=%d mask=%b"
(Frame_type.show self.header.ty)
self.header.payload_len self.header.mask);*)
()
external apply_masking_ : bytes -> bytes -> int -> int -> unit external apply_masking_ : bytes -> bytes -> int -> int -> unit
= "tiny_httpd_ws_apply_masking" = "tiny_httpd_ws_apply_masking"
@ -276,30 +289,45 @@ module Reader = struct
assert (off >= 0 && off + len <= Bytes.length buf); assert (off >= 0 && off + len <= Bytes.length buf);
apply_masking_ mask_key buf off len apply_masking_ mask_key buf off len
let read_body_to_string (self : t) : string = let read_body_to_string (self : t) ~deadline : string =
let len = self.header.payload_len in let len = self.header.payload_len in
let buf = Bytes.create len in let buf = Bytes.create len in
IO.Input.really_input self.ic buf 0 len; (try IO.Input_with_timeout.really_input self.ic ~deadline buf 0 len
with
| IO.Input_with_timeout.Timeout_partial_read _
| IO.Input_with_timeout.Timeout
->
raise Close_connection);
if self.header.mask then if self.header.mask then
apply_masking ~mask_key:self.header.mask_key buf 0 len; apply_masking ~mask_key:self.header.mask_key buf 0 len;
Bytes.unsafe_to_string buf Bytes.unsafe_to_string buf
(** Skip bytes of the body *) (** Skip bytes of the body *)
let skip_body (self : t) : unit = let skip_body (self : t) ~deadline : unit =
let len = ref self.header.payload_len in let len = ref self.header.payload_len in
while !len > 0 do while !len > 0 do
let n = min !len (Bytes.length self.small_buf) in let n = min !len (Bytes.length self.small_buf) in
IO.Input.really_input self.ic self.small_buf 0 n; (try
IO.Input_with_timeout.really_input self.ic ~deadline self.small_buf 0 n
with
| IO.Input_with_timeout.Timeout_partial_read _
| IO.Input_with_timeout.Timeout
->
raise Close_connection);
len := !len - n len := !len - n
done done
(** State machine that reads [len] bytes into [buf] *) (** State machine that reads [len] bytes into [buf] *)
let rec read_rec (self : t) buf i len : int = let rec read_rec (self : t) ~deadline buf i len : int =
match self.state with match self.state with
| Close -> 0 | Close -> 0
| Reading_frame r -> | Reading_frame r ->
let len = min len r.remaining_bytes in let len = min len r.remaining_bytes in
let n = IO.Input.input self.ic buf i len in let timeout = Time.now_s () -. deadline in
if timeout <= 0. then raise IO.Input_with_timeout.Timeout;
let n =
IO.Input_with_timeout.input_with_timeout self.ic timeout buf i len
in
(* update state *) (* update state *)
r.remaining_bytes <- r.remaining_bytes - n; r.remaining_bytes <- r.remaining_bytes - n;
@ -313,7 +341,7 @@ module Reader = struct
); );
n n
| Begin -> | Begin ->
read_frame_header self; read_frame_header self ~deadline;
(*Log.debug (fun k -> (*Log.debug (fun k ->
k "websocket: read frame of type=%s payload_len=%d" k "websocket: read frame of type=%s payload_len=%d"
(Frame_type.show self.header.ty) (Frame_type.show self.header.ty)
@ -330,19 +358,19 @@ module Reader = struct
(Frame_type.show self.last_ty)); (Frame_type.show self.last_ty));
raise Close_connection raise Close_connection
); );
read_rec self buf i len read_rec self ~deadline buf i len
| 1 -> | 1 ->
self.state <- self.state <-
Reading_frame { remaining_bytes = self.header.payload_len }; Reading_frame { remaining_bytes = self.header.payload_len };
read_rec self buf i len read_rec self ~deadline buf i len
| 2 -> | 2 ->
self.state <- self.state <-
Reading_frame { remaining_bytes = self.header.payload_len }; Reading_frame { remaining_bytes = self.header.payload_len };
read_rec self buf i len read_rec self ~deadline buf i len
| 8 -> | 8 ->
(* close frame *) (* close frame *)
self.state <- Close; self.state <- Close;
let body = read_body_to_string self in let body = read_body_to_string self ~deadline in
if String.length body >= 2 then ( if String.length body >= 2 then (
let errcode = Bytes.get_int16_be (Bytes.unsafe_of_string body) 0 in let errcode = Bytes.get_int16_be (Bytes.unsafe_of_string body) 0 in
Log.info (fun k -> Log.info (fun k ->
@ -352,19 +380,19 @@ module Reader = struct
0 0
| 9 -> | 9 ->
(* pong, just ignore *) (* pong, just ignore *)
skip_body self; skip_body self ~deadline;
Writer.send_pong self.writer; Writer.send_pong self.writer;
read_rec self buf i len read_rec self ~deadline buf i len
| 10 -> | 10 ->
(* pong, just ignore *) (* pong, just ignore *)
skip_body self; skip_body self ~deadline;
read_rec self buf i len read_rec self ~deadline buf i len
| ty -> | ty ->
Log.error (fun k -> k "unknown frame type: %xd" ty); Log.error (fun k -> k "unknown frame type: %xd" ty);
raise Close_connection) raise Close_connection)
let read self buf i len = let read self ~deadline buf i len =
try read_rec self buf i len try read_rec self ~deadline buf i len
with Close_connection -> with Close_connection ->
self.state <- Close; self.state <- Close;
0 0
@ -376,16 +404,26 @@ module Reader = struct
) )
end end
let upgrade ic oc : _ * _ = (* 30 min *)
let default_timeout_s = 60. *. 30.
let upgrade ?(timeout_s = default_timeout_s) ic oc : _ * _ =
let writer = Writer.create ~oc () in let writer = Writer.create ~oc () in
let reader = Reader.create ~ic ~writer () in let reader = Reader.create ~ic ~writer () in
let ws_ic : IO.Input.t = let ws_ic : IO.Input_with_timeout.t =
object object (self)
inherit IO.Input.t_from_refill ~bytes:(Bytes.create 4_096) () inherit
IO.Input_with_timeout.t_with_timeout_from_refill
~bytes:(Bytes.create 4_096) () as super
method private refill (slice : IO.Slice.t) = method private refill_with_timeout t (slice : IO.Slice.t) =
let deadline = Time.now_s () +. t in
slice.off <- 0; slice.off <- 0;
slice.len <- Reader.read reader slice.bytes 0 (Bytes.length slice.bytes) slice.len <-
Reader.read reader ~deadline slice.bytes 0 (Bytes.length slice.bytes)
method! fill_buf () =
IO.Input_with_timeout.fill_buf_with_timeout self timeout_s
method close () = Reader.close reader method close () = Reader.close reader
end end
@ -404,6 +442,7 @@ let upgrade ic oc : _ * _ =
module Make_upgrade_handler (X : sig module Make_upgrade_handler (X : sig
val accept_ws_protocol : string -> bool val accept_ws_protocol : string -> bool
val handler : handler val handler : handler
val timeout_s : float
end) : Server.UPGRADE_HANDLER = struct end) : Server.UPGRADE_HANDLER = struct
type handshake_state = unit type handshake_state = unit
@ -446,7 +485,7 @@ end) : Server.UPGRADE_HANDLER = struct
try Ok (handshake_ req) with Bad_req s -> Error s try Ok (handshake_ req) with Bad_req s -> Error s
let handle_connection addr () ic oc = let handle_connection addr () ic oc =
let ws_ic, ws_oc = upgrade ic oc in let ws_ic, ws_oc = upgrade ~timeout_s:X.timeout_s ic oc in
try X.handler addr ws_ic ws_oc try X.handler addr ws_ic ws_oc
with Close_connection -> with Close_connection ->
Log.debug (fun k -> k "websocket: requested to close the connection"); Log.debug (fun k -> k "websocket: requested to close the connection");
@ -454,10 +493,12 @@ end) : Server.UPGRADE_HANDLER = struct
end end
let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true) let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true)
(server : Server.t) route (f : handler) : unit = ?(timeout_s = default_timeout_s) (server : Server.t) route (f : handler) :
unit =
let module M = Make_upgrade_handler (struct let module M = Make_upgrade_handler (struct
let handler = f let handler = f
let accept_ws_protocol = accept_ws_protocol let accept_ws_protocol = accept_ws_protocol
let timeout_s = timeout_s
end) in end) in
let up : Server.upgrade_handler = (module M) in let up : Server.upgrade_handler = (module M) in
Server.add_upgrade_handler ?accept server route up Server.add_upgrade_handler ?accept server route up

View file

@ -4,15 +4,20 @@
for a websocket server. It has no additional dependencies. for a websocket server. It has no additional dependencies.
*) *)
type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit type handler = Unix.sockaddr -> IO.Input_with_timeout.t -> IO.Output.t -> unit
(** Websocket handler *) (** Websocket handler *)
val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t val upgrade :
?timeout_s:float ->
IO.Input_with_timeout.t ->
IO.Output.t ->
IO.Input_with_timeout.t * IO.Output.t
(** Upgrade a byte stream to the websocket framing protocol. *) (** Upgrade a byte stream to the websocket framing protocol. *)
val add_route_handler : val add_route_handler :
?accept:(unit Request.t -> (unit, int * string) result) -> ?accept:(unit Request.t -> (unit, int * string) result) ->
?accept_ws_protocol:(string -> bool) -> ?accept_ws_protocol:(string -> bool) ->
?timeout_s:float ->
Server.t -> Server.t ->
(Server.upgrade_handler, Server.upgrade_handler) Route.t -> (Server.upgrade_handler, Server.upgrade_handler) Route.t ->
handler -> handler ->

View file

@ -9,12 +9,13 @@ let () =
\r\n\ \r\n\
salutationsSOMEJUNK" salutationsSOMEJUNK"
in in
let str = IO.Input.of_string q in let str = IO.Input_with_timeout.of_string q in
let client_addr = Unix.(ADDR_INET (inet_addr_loopback, 1024)) in let client_addr = Unix.(ADDR_INET (inet_addr_loopback, 1024)) in
let deadline = Time.now_s () +. 10. in
let r = let r =
Request.Private_.parse_req_start_exn ~client_addr ~buf:(Buf.create ()) Request.Private_.parse_req_start_exn ~client_addr ~buf:(Buf.create ())
~get_time_s:(fun _ -> 0.) ~deadline str
str
in in
match r with match r with
| None -> failwith "should parse" | None -> failwith "should parse"
@ -23,6 +24,8 @@ let () =
assert_eq (Some "coucou") (Headers.get "host" req.headers); assert_eq (Some "coucou") (Headers.get "host" req.headers);
assert_eq (Some "11") (Headers.get "content-length" req.headers); assert_eq (Some "11") (Headers.get "content-length" req.headers);
assert_eq "hello" req.path; assert_eq "hello" req.path;
let req = Request.Private_.parse_body req str |> Request.read_body_full in let req =
Request.Private_.parse_body req str |> Request.read_body_full ~deadline
in
assert_eq ~to_string:(fun s -> s) "salutations" req.body; assert_eq ~to_string:(fun s -> s) "salutations" req.body;
() ()

View file

@ -16,6 +16,7 @@ depends: [
"base-threads" "base-threads"
"result" "result"
"hmap" "hmap"
"base-unix"
"iostream" {>= "0.2"} "iostream" {>= "0.2"}
"ocaml" {>= "4.08"} "ocaml" {>= "4.08"}
"odoc" {with-doc} "odoc" {with-doc}