error handling, and bugfix (idempotent closing of Unix.fd)

This commit is contained in:
Simon Cruanes 2024-02-22 18:23:18 -05:00
parent d56ffb3a08
commit 225c21b4cc
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
6 changed files with 123 additions and 23 deletions

View file

@ -98,7 +98,8 @@ let vfs_of_dir (top : string) : vfs =
match Unix.stat fpath with
| { st_kind = Unix.S_REG; _ } ->
let ic = Unix.(openfile fpath [ O_RDONLY ] 0) in
Tiny_httpd_stream.of_fd_close_noerr ic
let closed = ref false in
Tiny_httpd_stream.of_fd_close_noerr ~closed ic
| _ -> failwith (Printf.sprintf "not a regular file: %S" f)
let create f =

View file

@ -34,7 +34,7 @@ module Input = struct
close_in ic);
}
let of_unix_fd ?(close_noerr = false) (fd : Unix.file_descr) : t =
let of_unix_fd ?(close_noerr = false) ~closed (fd : Unix.file_descr) : t =
let eof = ref false in
{
input =
@ -48,6 +48,14 @@ module Input = struct
| n_ ->
n := n_;
continue := false
| exception
Unix.Unix_error
( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN
| Unix.ECONNRESET | Unix.EPIPE ),
_,
_ ) ->
eof := true;
continue := false
| exception
Unix.Unix_error
((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) ->
@ -59,11 +67,14 @@ module Input = struct
!n);
close =
(fun () ->
eof := true;
if close_noerr then (
try Unix.close fd with _ -> ()
) else
Unix.close fd);
if not !closed then (
closed := true;
eof := true;
if close_noerr then (
try Unix.close fd with _ -> ()
) else
Unix.close fd
));
}
let of_slice (i_bs : bytes) (i_off : int) (i_len : int) : t =
@ -134,6 +145,70 @@ module Output = struct
This can be a [Buffer.t], an [out_channel], a [Unix.file_descr], etc. *)
let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Buf.t)
(fd : Unix.file_descr) : t =
Buf.clear buf;
let buf = Buf.bytes_slice buf in
let off = ref 0 in
let flush () =
if !off > 0 then (
let i = ref 0 in
while !i < !off do
(* Printf.eprintf "write %d bytes\n%!" (!off - !i); *)
match Unix.write fd buf !i (!off - !i) with
| 0 -> failwith "write failed"
| n -> i := !i + n
| exception
Unix.Unix_error
( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN
| Unix.ECONNRESET | Unix.EPIPE ),
_,
_ ) ->
failwith "write failed"
| exception
Unix.Unix_error
((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) ->
ignore (Unix.select [] [ fd ] [] 1.)
done;
off := 0
)
in
let[@inline] flush_if_full_ () = if !off = Bytes.length buf then flush () in
let output_char c =
flush_if_full_ ();
Bytes.set buf !off c;
incr off;
flush_if_full_ ()
in
let output bs i len =
(* Printf.eprintf "output %d bytes (buffered)\n%!" len; *)
let i = ref i in
let len = ref len in
while !len > 0 do
flush_if_full_ ();
let n = min !len (Bytes.length buf - !off) in
Bytes.blit bs !i buf !off n;
i := !i + n;
len := !len - n;
off := !off + n
done;
flush_if_full_ ()
in
let close () =
if not !closed then (
closed := true;
flush ();
if close_noerr then (
try Unix.close fd with _ -> ()
) else
Unix.close fd
)
in
{ output; output_char; flush; close }
(** [of_out_channel oc] wraps the channel into a {!Output.t}.
@param close_noerr if true, then closing the result uses [close_out_noerr]
instead of [close_out] to close [oc] *)

View file

@ -7,6 +7,10 @@ let debug k = Log.debug (fun fmt -> k (fun x -> fmt ?header:None ?tags:None x))
let error k = Log.err (fun fmt -> k (fun x -> fmt ?header:None ?tags:None x))
let setup ~debug () =
let mutex = Mutex.create () in
Logs.set_reporter_mutex
~lock:(fun () -> Mutex.lock mutex)
~unlock:(fun () -> Mutex.unlock mutex);
Logs.set_reporter @@ Logs.format_reporter ();
Logs.set_level ~all:true
(Some

View file

@ -491,8 +491,12 @@ module Response = struct
Byte_stream.close str
| exception e ->
let bt = Printexc.get_raw_backtrace () in
IO.Output.flush oc;
Log.error (fun k ->
k "t[%d]: outputing stream failed with %s"
(Thread.id @@ Thread.self ())
(Printexc.to_string e));
Byte_stream.close str;
IO.Output.flush oc;
Printexc.raise_with_backtrace e bt));
IO.Output.flush oc
end
@ -904,6 +908,7 @@ module Unix_tcp_server_ = struct
type t = {
addr: string;
port: int;
buf_pool: Buf.t Pool.t;
max_connections: int;
sem_max_connections: Sem_.t;
(** semaphore to restrict the number of active concurrent connections *)
@ -971,20 +976,24 @@ module Unix_tcp_server_ = struct
let handle_client_unix_ (client_sock : Unix.file_descr)
(client_addr : Unix.sockaddr) : unit =
Log.info (fun k ->
k "serving new client on %s"
k "t[%d]: serving new client on %s"
(Thread.id @@ Thread.self ())
(Tiny_httpd_util.show_sockaddr client_addr));
(*
if self.masksigpipe then
ignore (Unix.sigprocmask Unix.SIG_BLOCK [ Sys.sigpipe ] : _ list);
*)
Unix.set_nonblock client_sock;
Unix.setsockopt client_sock Unix.TCP_NODELAY true;
Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout);
Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout);
Pool.with_resource self.buf_pool @@ fun buf ->
let closed = ref false in
let oc =
IO.Output.of_out_channel ~close_noerr:true
@@ Unix.out_channel_of_descr client_sock
IO.Output.of_unix_fd ~close_noerr:true ~closed ~buf client_sock
in
let ic =
IO.Input.of_unix_fd ~close_noerr:true ~closed client_sock
in
let ic = IO.Input.of_unix_fd ~close_noerr:true client_sock in
handle.handle ~client_addr ic oc
in
@ -1046,6 +1055,10 @@ let create ?(masksigpipe = true) ?max_connections ?(timeout = 0.0) ?buf_size
{
Unix_tcp_server_.addr;
new_thread;
buf_pool =
Pool.create ~clear:Buf.clear_and_zero
~mk_item:(fun () -> Buf.create ?size:buf_size ())
();
running = true;
port;
sock;

View file

@ -67,12 +67,15 @@ let of_chan_ ?buf_size ic ~close_noerr : t =
let of_chan ?buf_size ic = of_chan_ ?buf_size ic ~close_noerr:false
let of_chan_close_noerr ?buf_size ic = of_chan_ ?buf_size ic ~close_noerr:true
let of_fd_ ?buf_size ~close_noerr ic : t =
let inc = IO.Input.of_unix_fd ~close_noerr ic in
let of_fd_ ?buf_size ~close_noerr ~closed ic : t =
let inc = IO.Input.of_unix_fd ~close_noerr ~closed ic in
of_input ?buf_size inc
let of_fd ?buf_size fd : t = of_fd_ ?buf_size ~close_noerr:false fd
let of_fd_close_noerr ?buf_size fd : t = of_fd_ ?buf_size ~close_noerr:true fd
let of_fd ?buf_size ~closed fd : t =
of_fd_ ?buf_size ~closed ~close_noerr:false fd
let of_fd_close_noerr ?buf_size ~closed fd : t =
of_fd_ ?buf_size ~closed ~close_noerr:true fd
let iter f (self : t) : unit =
let continue = ref true in
@ -120,7 +123,7 @@ let of_string s : t = of_bytes (Bytes.unsafe_of_string s)
let with_file ?buf_size file f =
let ic = Unix.(openfile file [ O_RDONLY ] 0) in
try
let x = f (of_fd ?buf_size ic) in
let x = f (of_fd ?buf_size ~closed:(ref false) ic) in
Unix.close ic;
x
with e ->
@ -304,8 +307,12 @@ let read_chunked ?(buf = Buf.create ()) ~fail (bs : t) : t =
let output_chunked' ?buf (oc : IO.Output.t) (self : t) : unit =
let oc' = IO.Output.chunk_encoding ?buf oc ~close_rec:false in
to_chan' oc' self;
IO.Output.close oc'
match to_chan' oc' self with
| () -> IO.Output.close oc'
| exception e ->
let bt = Printexc.get_raw_backtrace () in
IO.Output.close oc';
Printexc.raise_with_backtrace e bt
(* print a stream as a series of chunks *)
let output_chunked ?buf (oc : out_channel) (self : t) : unit =

View file

@ -74,10 +74,10 @@ val of_chan : ?buf_size:int -> in_channel -> t
val of_chan_close_noerr : ?buf_size:int -> in_channel -> t
(** Same as {!of_chan} but the [close] method will never fail. *)
val of_fd : ?buf_size:int -> Unix.file_descr -> t
val of_fd : ?buf_size:int -> closed:bool ref -> Unix.file_descr -> t
(** Make a buffered stream from the given file descriptor. *)
val of_fd_close_noerr : ?buf_size:int -> Unix.file_descr -> t
val of_fd_close_noerr : ?buf_size:int -> closed:bool ref -> Unix.file_descr -> t
(** Same as {!of_fd} but the [close] method will never fail. *)
val of_bytes : ?i:int -> ?len:int -> bytes -> t