From 1532d925402f47a50a2aacb0597ff3b3d3ca2df6 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 24 Jun 2024 10:26:51 -0400 Subject: [PATCH] wip: async-io --- src/unix/async_io.ml | 590 +++++++++++------------------------------- src/unix/async_io.mli | 74 +++++- src/unix/dune | 1 + src/unix/sockaddr.ml | 16 ++ 4 files changed, 240 insertions(+), 441 deletions(-) create mode 100644 src/unix/sockaddr.ml diff --git a/src/unix/async_io.ml b/src/unix/async_io.ml index 31cbdad7..5dd4766e 100644 --- a/src/unix/async_io.ml +++ b/src/unix/async_io.ml @@ -1,5 +1,5 @@ open Common_ -module Slice = Iostream.Slice +module Slice = Iostream_types.Slice let rec read (fd : Fd.t) buf i len : int = if len = 0 || Fd.closed fd then @@ -57,7 +57,7 @@ let write fd buf i len : unit = done module Reader = struct - include Iostream.In + include Iostream_types.In let of_fd ?(close_noerr = false) (fd : Fd.t) : t = object @@ -71,11 +71,11 @@ module Reader = struct end let of_slice (slice : Slice.t) : t = - (of_bytes ~off:slice.off ~len:slice.len slice.bytes :> t) + (Iostream.In.of_bytes ~off:slice.off ~len:slice.len slice.bytes :> t) end module Buf_reader = struct - include Iostream.In_buf + include Iostream_types.In_buf let of_fd ?(close_noerr = false) ~(buf : bytes) (fd : Fd.t) : t = let eof = ref false in @@ -97,453 +97,171 @@ module Buf_reader = struct end end -(* -(** Output channel (byte sink) *) -module Output = struct - include Iostream.Out_buf +module Writer = struct + include Iostream_types.Out - class of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) - (fd : Unix.file_descr) : - t = + let of_fd ?(close_noerr = false) (fd : Fd.t) : t = object - inherit t_from_output ~bytes:buf.bytes () + method output buf i len = write fd buf i len - method private output_underlying bs i len0 = - let i = ref i in - let len = ref len0 in - while !len > 0 do - match Unix.write fd bs !i !len with - | 0 -> failwith "write failed" - | n -> - i := !i + n; - len := !len - n - | exception - 
Unix.Unix_error - ( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN - | Unix.ECONNRESET | Unix.EPIPE ), - _, - _ ) -> - failwith "write failed" - | exception - Unix.Unix_error - ((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) -> - ignore (Unix.select [] [ fd ] [] 1.) - done + method close () = + if close_noerr then + Fd.close_noerr fd + else + Fd.close fd + end +end + +module Buf_writer = struct + include Iostream_types.Out_buf + + let of_fd ?(close_noerr = false) ~(buf : bytes) (fd : Fd.t) : t = + object + inherit Iostream.Out_buf.t_from_output ~bytes:buf () + method private output_underlying bs i len = write fd bs i len method private close_underlying () = - if not !closed then ( - closed := true; - if close_noerr then ( - try Unix.close fd with _ -> () - ) else - Unix.close fd - ) - end - - let output_buf (self : t) (buf : Buf.t) : unit = - let b = Buf.bytes_slice buf in - output self b 0 (Buf.size buf) - - (** [chunk_encoding oc] makes a new channel that outputs its content into [oc] - in chunk encoding form. - @param close_rec if true, closing the result will also close [oc] - @param buf a buffer used to accumulate data into chunks. - Chunks are emitted when [buf]'s size gets over a certain threshold, - or when [flush] is called. - *) - let chunk_encoding ?(buf = Buf.create ()) ~close_rec (oc : #t) : t = - (* write content of [buf] as a chunk if it's big enough. - If [force=true] then write content of [buf] if it's simply non empty. 
*) - let write_buf ~force () = - let n = Buf.size buf in - if (force && n > 0) || n >= 4_096 then ( - output_string oc (Printf.sprintf "%x\r\n" n); - output oc (Buf.bytes_slice buf) 0 n; - output_string oc "\r\n"; - Buf.clear buf - ) - in - - object - method flush () = - write_buf ~force:true (); - flush oc - - method close () = - write_buf ~force:true (); - (* write an empty chunk to close the stream *) - output_string oc "0\r\n"; - (* write another crlf after the stream (see #56) *) - output_string oc "\r\n"; - flush oc; - if close_rec then close oc - - method output b i n = - Buf.add_bytes buf b i n; - write_buf ~force:false () - - method output_char c = - Buf.add_char buf c; - write_buf ~force:false () + if close_noerr then + Fd.close_noerr fd + else + Fd.close fd end end -(** Input channel (byte source) *) -module Input = struct - include Iostream.In_buf +module Buf_pool = struct + type t = { with_buf: 'a. int -> (bytes -> 'a) -> 'a } [@@unboxed] - let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) - (fd : Unix.file_descr) : t = - let eof = ref false in - object - inherit Iostream.In_buf.t_from_refill ~bytes:buf.bytes () - - method private refill (slice : Slice.t) = - if not !eof then ( - slice.off <- 0; - let continue = ref true in - while !continue do - match Unix.read fd slice.bytes 0 (Bytes.length slice.bytes) with - | n -> - slice.len <- n; - continue := false - | exception - Unix.Unix_error - ( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN - | Unix.ECONNRESET | Unix.EPIPE ), - _, - _ ) -> - eof := true; - continue := false - | exception - Unix.Unix_error - ((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) -> - ignore (Unix.select [ fd ] [] [] 1.) - done; - (* Printf.eprintf "read returned %d B\n%!" 
!n; *) - if slice.len = 0 then eof := true - ) - - method close () = - if not !closed then ( - closed := true; - eof := true; - if close_noerr then ( - try Unix.close fd with _ -> () - ) else - Unix.close fd - ) - end - - let of_slice (slice : Slice.t) : t = - object - inherit Iostream.In_buf.t_from_refill ~bytes:slice.bytes () - - method private refill (slice : Slice.t) = - slice.off <- 0; - slice.len <- 0 - - method close () = () - end - - (** Read into the given slice. - @return the number of bytes read, [0] means end of input. *) - let[@inline] input (self : t) buf i len = self#input buf i len - - (** Close the channel. *) - let[@inline] close self : unit = self#close () - - (** Read exactly [len] bytes. - @raise End_of_file if the input did not contain enough data. *) - let really_input (self : t) buf i len : unit = - let i = ref i in - let len = ref len in - while !len > 0 do - let n = input self buf !i !len in - if n = 0 then raise End_of_file; - i := !i + n; - len := !len - n - done - - let iter_slice (f : Slice.t -> unit) (self : #t) : unit = - let continue = ref true in - while !continue do - let slice = self#fill_buf () in - if slice.len = 0 then ( - continue := false; - close self - ) else ( - f slice; - Slice.consume slice slice.len - ) - done - - let iter f self = - iter_slice (fun (slice : Slice.t) -> f slice.bytes slice.off slice.len) self - - let to_chan oc (self : #t) = - iter_slice - (fun (slice : Slice.t) -> - Stdlib.output oc slice.bytes slice.off slice.len) - self - - let to_chan' (oc : #Iostream.Out.t) (self : #t) : unit = - iter_slice - (fun (slice : Slice.t) -> - Iostream.Out.output oc slice.bytes slice.off slice.len) - self - - let read_all_using ~buf (self : #t) : string = - Buf.clear buf; - let continue = ref true in - while !continue do - let slice = fill_buf self in - if slice.len = 0 then - continue := false - else ( - assert (slice.len > 0); - Buf.add_bytes buf slice.bytes slice.off slice.len; - Slice.consume slice slice.len - ) - 
done; - Buf.contents_and_clear buf - - (** Read [n] bytes from the input into [bytes]. *) - let read_exactly_ ~too_short (self : #t) (bytes : bytes) (n : int) : unit = - assert (Bytes.length bytes >= n); - let offset = ref 0 in - while !offset < n do - let slice = self#fill_buf () in - let n_read = min slice.len (n - !offset) in - Bytes.blit slice.bytes slice.off bytes !offset n_read; - offset := !offset + n_read; - Slice.consume slice n_read; - if n_read = 0 then too_short () - done - - (** read a line into the buffer, after clearing it. *) - let read_line_into (self : t) ~buf : unit = - Buf.clear buf; - let continue = ref true in - while !continue do - let slice = self#fill_buf () in - if slice.len = 0 then ( - continue := false; - if Buf.size buf = 0 then raise End_of_file - ); - let j = ref slice.off in - let limit = slice.off + slice.len in - while !j < limit && Bytes.get slice.bytes !j <> '\n' do - incr j - done; - if !j < limit then ( - assert (Bytes.get slice.bytes !j = '\n'); - (* line without '\n' *) - Buf.add_bytes buf slice.bytes slice.off (!j - slice.off); - (* consume line + '\n' *) - Slice.consume slice (!j - slice.off + 1); - continue := false - ) else ( - Buf.add_bytes buf slice.bytes slice.off slice.len; - Slice.consume slice slice.len - ) - done - - let read_line_using ~buf (self : #t) : string = - read_line_into self ~buf; - Buf.contents_and_clear buf - - let read_line_using_opt ~buf (self : #t) : string option = - match read_line_into self ~buf with - | () -> Some (Buf.contents_and_clear buf) - | exception End_of_file -> None - - (* helper for making a new input stream that either contains at most [size] - bytes, or contains exactly [size] bytes. 
*) - let reading_exactly_ ~skip_on_close ~close_rec ~size ~bytes (arg : t) : t = - let remaining_size = ref size in - - object - inherit t_from_refill ~bytes () - - method close () = - if !remaining_size > 0 && skip_on_close then skip arg !remaining_size; - if close_rec then close arg - - method private refill (slice : Slice.t) = - slice.off <- 0; - slice.len <- 0; - if !remaining_size > 0 then ( - let sub = fill_buf arg in - let n = - min !remaining_size (min sub.len (Bytes.length slice.bytes)) - in - Bytes.blit sub.bytes sub.off slice.bytes 0 n; - Slice.consume sub n; - remaining_size := !remaining_size - n; - slice.len <- n - ) - end - - (** new stream with maximum size [max_size]. - @param close_rec if true, closing this will also close the input stream *) - let limit_size_to ~close_rec ~max_size ~bytes (arg : t) : t = - reading_exactly_ ~size:max_size ~skip_on_close:false ~bytes ~close_rec arg - - (** New stream that consumes exactly [size] bytes from the input. - If fewer bytes are read before [close] is called, we read and discard - the remaining quota of bytes before [close] returns. 
- @param close_rec if true, closing this will also close the input stream *) - let reading_exactly ~close_rec ~size ~bytes (arg : t) : t = - reading_exactly_ ~size ~close_rec ~skip_on_close:true ~bytes arg - - let read_chunked ~(bytes : bytes) ~fail (ic : #t) : t = - let first = ref true in - - (* small buffer to read the chunk sizes *) - let line_buf = Buf.create ~size:32 () in - let read_next_chunk_len () : int = - if !first then - first := false - else ( - let line = read_line_using ~buf:line_buf ic in - if String.trim line <> "" then - raise (fail "expected crlf between chunks") - ); - let line = read_line_using ~buf:line_buf ic in - (* parse chunk length, ignore extensions *) - let chunk_size = - if String.trim line = "" then - 0 - else ( - try - let off = ref 0 in - let n = Parse_.pos_hex line off in - n - with _ -> - raise (fail (spf "cannot read chunk size from line %S" line)) - ) - in - chunk_size - in - let eof = ref false in - let chunk_size = ref 0 in - - object - inherit t_from_refill ~bytes () - - method private refill (slice : Slice.t) : unit = - if !chunk_size = 0 && not !eof then ( - chunk_size := read_next_chunk_len (); - if !chunk_size = 0 then ( - (* stream is finished, consume trailing \r\n *) - eof := true; - let line = read_line_using ~buf:line_buf ic in - if String.trim line <> "" then - raise - (fail (spf "expected \\r\\n to follow last chunk, got %S" line)) - ) - ); - slice.off <- 0; - slice.len <- 0; - if !chunk_size > 0 then ( - (* read the whole chunk, or [Bytes.length bytes] of it *) - let to_read = min !chunk_size (Bytes.length slice.bytes) in - read_exactly_ - ~too_short:(fun () -> raise (fail "chunk is too short")) - ic slice.bytes to_read; - slice.len <- to_read; - chunk_size := !chunk_size - to_read - ) - - method close () = eof := true (* do not close underlying stream *) - end - - (** Output a stream using chunked encoding *) - let output_chunked' ?buf (oc : #Iostream.Out_buf.t) (self : #t) : unit = - let oc' = 
Output.chunk_encoding ?buf oc ~close_rec:false in - match to_chan' oc' self with - | () -> Output.close oc' - | exception e -> - let bt = Printexc.get_raw_backtrace () in - Output.close oc'; - Printexc.raise_with_backtrace e bt - - (** print a stream as a series of chunks *) - let output_chunked ?buf (oc : out_channel) (self : #t) : unit = - output_chunked' ?buf (Output.of_out_channel oc) self -end - -(** A writer abstraction. *) -module Writer = struct - type t = { write: Output.t -> unit } [@@unboxed] - (** Writer. - - A writer is a push-based stream of bytes. - Give it an output channel and it will write the bytes in it. - - This is useful for responses: an http endpoint can return a writer - as its response's body; the writer is given access to the connection - to the client and can write into it as if it were a regular - [out_channel], including controlling calls to [flush]. - Tiny_httpd will convert these writes into valid HTTP chunks. - @since 0.14 - *) - - let[@inline] make ~write () : t = { write } - - (** Write into the channel. *) - let[@inline] write (oc : #Output.t) (self : t) : unit = - self.write (oc :> Output.t) - - (** Empty writer, will output 0 bytes. *) - let empty : t = { write = ignore } - - (** A writer that just emits the bytes from the given string. *) - let[@inline] of_string (str : string) : t = - let write oc = Iostream.Out.output_string oc str in - { write } - - let[@inline] of_input (ic : #Input.t) : t = - { write = (fun oc -> Input.to_chan' oc ic) } + let dummy : t = { with_buf = (fun size f -> f (Bytes.create size)) } end (** A TCP server abstraction. *) module TCP_server = struct - type conn_handler = { - handle: client_addr:Unix.sockaddr -> Input.t -> Output.t -> unit; - (** Handle client connection *) - } + type conn_handler = + client_addr:Sockaddr.t -> Buf_reader.t -> Buf_writer.t -> unit - type t = { - endpoint: unit -> string * int; - (** Endpoint we listen on. This can only be called from within [serve]. 
*) - active_connections: unit -> int; - (** Number of connections currently active *) - running: unit -> bool; (** Is the server currently running? *) - stop: unit -> unit; - (** Ask the server to stop. This might not take effect immediately, - and is idempotent. After this [server.running()] must return [false]. *) - } - (** A running TCP server. + class type t = object + method endpoint : unit -> Sockaddr.t + method active_connections : unit -> int + method running : unit -> bool + method run : unit -> unit + method stop : unit -> unit + end - This contains some functions that provide information about the running - server, including whether it's active (as opposed to stopped), a function - to stop it, and statistics about the number of connections. *) + let run (self : #t) = self#run () - type builder = { - serve: after_init:(t -> unit) -> handle:conn_handler -> unit -> unit; - (** Blocking call to listen for incoming connections and handle them. - Uses the connection handler [handle] to handle individual client - connections in individual threads/fibers/tasks. - @param after_init is called once with the server after the server - has started. *) - } - (** A TCP server builder implementation. + type state = + | Created + | Running + | Stopped - Calling [builder.serve ~after_init ~handle ()] starts a new TCP server on - an unspecified endpoint - (most likely coming from the function returning this builder) - and returns the running server. 
*) + let rec accept_ (sock : Unix.file_descr) = + match Unix.accept sock with + | csock, addr -> csock, addr + | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) -> + Ev_loop.wait_readable sock Cancel_handle.dummy ignore; + accept_ sock + + class base_server ?(listen = 32) ?(buf_pool = Buf_pool.dummy) ~runner + ~(handle : conn_handler) (addr : Sockaddr.t) : + t = + let n_active_ = A.make 0 in + let st = A.make Created in + + object + method endpoint () = addr + method active_connections () = A.get n_active_ + method running () = A.get st = Running + method stop () = if A.exchange st Stopped = Running then ( (* TODO *) ) + + method run () = + (* set to Running *) + let can_start = + let rec loop () = + match A.get st with + | Created -> A.compare_and_set st Created Running || loop () + | _ -> false + in + loop () + in + + if can_start then ( + let sock = + try + let sock = + Unix.socket (Sockaddr.domain addr) Unix.SOCK_STREAM 0 + in + Unix.setsockopt sock Unix.TCP_NODELAY true; + Unix.set_nonblock sock; + Unix.bind sock addr; + Unix.listen sock listen; + sock + with e -> + A.set st Stopped; + raise e + in + while A.get st = Running do + let client_sock, client_addr = accept_ sock in + let client_fd = Fd.create client_sock in + + (* start a fiber to handle the client *) + let _ : _ Fiber.t = + Fiber.spawn_top ~on:runner (fun () -> + A.incr n_active_; + + let@ () = + Fun.protect ~finally:(fun () -> + A.decr n_active_; + Fd.close_noerr client_fd) + in + + let@ buf_in = buf_pool.with_buf 4096 in + let@ buf_out = buf_pool.with_buf 4096 in + + let ic = Buf_reader.of_fd client_fd ~buf:buf_in in + let oc = Buf_writer.of_fd client_fd ~buf:buf_out in + handle ~client_addr ic oc) + in + () + done + ) + end + + let create ?(after_init = ignore) ?listen ?buf_pool ~runner + ~(handle : conn_handler) (addr : Sockaddr.t) : t = + let self = new base_server ?listen ?buf_pool ~runner ~handle addr in + after_init self; + self +end + +module TCP_client = struct + (** 
connect asynchronously *) + let rec connect_ sock addr = + match Unix.connect sock addr with + | () -> () + | exception + Unix.Unix_error + ((Unix.EWOULDBLOCK | Unix.EINPROGRESS | Unix.EAGAIN), _, _) -> + Ev_loop.wait_writable sock Cancel_handle.dummy ignore; + connect_ sock addr + + let with_connect' addr (f : Fd.t -> 'a) : 'a = + let sock = Unix.socket (Sockaddr.domain addr) Unix.SOCK_STREAM 0 in + Unix.set_nonblock sock; + Unix.setsockopt sock Unix.TCP_NODELAY true; + + connect_ sock addr; + let sock = Fd.create sock in + + let finally () = Fd.close_noerr sock in + let@ () = Fun.protect ~finally in + f sock + + let with_connect ?(buf_pool = Buf_pool.dummy) addr (f : _ -> _ -> 'a) : 'a = + with_connect' addr (fun sock -> + let@ buf_in = buf_pool.with_buf 4096 in + let@ buf_out = buf_pool.with_buf 4096 in + + let ic = Buf_reader.of_fd ~buf:buf_in sock in + let oc = Buf_writer.of_fd ~buf:buf_out sock in + f ic oc) end -*) diff --git a/src/unix/async_io.mli b/src/unix/async_io.mli index 2fe04d52..5860eee4 100644 --- a/src/unix/async_io.mli +++ b/src/unix/async_io.mli @@ -1,4 +1,4 @@ -module Slice = Iostream.Slice +module Slice = Iostream_types.Slice val read : Fd.t -> bytes -> int -> int -> int (** Non blocking read *) @@ -10,7 +10,7 @@ val write : Fd.t -> bytes -> int -> int -> unit module Reader : sig include module type of struct - include Iostream.In + include Iostream_types.In end val of_fd : ?close_noerr:bool -> Fd.t -> t @@ -19,7 +19,7 @@ end module Buf_reader : sig include module type of struct - include Iostream.In_buf + include Iostream_types.In_buf end val of_fd : ?close_noerr:bool -> buf:bytes -> Fd.t -> t @@ -27,7 +27,7 @@ end module Writer : sig include module type of struct - include Iostream.Out + include Iostream_types.Out end val of_fd : ?close_noerr:bool -> Fd.t -> t @@ -35,8 +35,72 @@ end module Buf_writer : sig include module type of struct - include Iostream.Out_buf + include Iostream_types.Out_buf end val of_fd : ?close_noerr:bool -> buf:bytes 
-> Fd.t -> t end + +module Buf_pool : sig + type t = { with_buf: 'a. int -> (bytes -> 'a) -> 'a } [@@unboxed] + + val dummy : t + (** Just allocate on the fly, no pooling *) +end + +module TCP_client : sig + val with_connect' : Sockaddr.t -> (Fd.t -> 'a) -> 'a + + val with_connect : + ?buf_pool:Buf_pool.t -> + Sockaddr.t -> + (Buf_reader.t -> Buf_writer.t -> 'a) -> + 'a +end + +module TCP_server : sig + type conn_handler = + client_addr:Sockaddr.t -> Buf_reader.t -> Buf_writer.t -> unit + (** Handle client connection *) + + (** A running TCP server. + + This contains some functions that provide information about the running + server, including whether it's active (as opposed to stopped), a function + to stop it, and statistics about the number of connections. *) + class type t = object + method endpoint : unit -> Sockaddr.t + (** Endpoint we listen on. This can only be called from within [run]. *) + + method active_connections : unit -> int + (** Number of connections currently active *) + + method running : unit -> bool + (** Is the server currently running? *) + + method run : unit -> unit + + method stop : unit -> unit + (** Ask the server to stop. This might not take effect immediately, + and is idempotent. After this [server#running ()] must return [false]. 
*) + end + + class base_server : + ?listen:int -> + ?buf_pool:Buf_pool.t -> + runner:Moonpool.Runner.t -> + handle:conn_handler -> + Sockaddr.t -> + t + + val create : + ?after_init:(t -> unit) -> + ?listen:int -> + ?buf_pool:Buf_pool.t -> + runner:Moonpool.Runner.t -> + handle:conn_handler -> + Sockaddr.t -> + t + + val run : #t -> unit +end diff --git a/src/unix/dune b/src/unix/dune index bd609139..eaf20b56 100644 --- a/src/unix/dune +++ b/src/unix/dune @@ -8,6 +8,7 @@ moonpool.fib unix iostream + (re_export iostream.types) (select time.ml from diff --git a/src/unix/sockaddr.ml b/src/unix/sockaddr.ml new file mode 100644 index 00000000..6089cefd --- /dev/null +++ b/src/unix/sockaddr.ml @@ -0,0 +1,16 @@ +type t = Unix.sockaddr + +let show = function + | Unix.ADDR_UNIX s -> s + | Unix.ADDR_INET (addr, port) -> + Printf.sprintf "%s:%d" (Unix.string_of_inet_addr addr) port + +let pp out (self : t) = Format.pp_print_string out (show self) + +let domain = function + | Unix.ADDR_UNIX _ -> Unix.PF_UNIX + | Unix.ADDR_INET (a, _) -> + if Unix.is_inet6_addr a then + Unix.PF_INET6 + else + Unix.PF_INET