From 8bfe76b3e0bb2c34ef56035588f4b8ad7e8141ec Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 17 Feb 2024 12:37:03 -0500 Subject: [PATCH] improve docs for moonpool_lwt; fix race condition --- src/core/suspend_.mli | 2 -- src/lwt/IO.ml | 52 ++++++++++++++------------- src/lwt/moonpool_lwt.mli | 78 +++++++++++++++++++++++++++++----------- src/lwt/tcp_client.ml | 30 +++++++--------- src/lwt/tcp_server.ml | 2 +- 5 files changed, 100 insertions(+), 64 deletions(-) diff --git a/src/core/suspend_.mli b/src/core/suspend_.mli index bd8a9a9d..45d4bc97 100644 --- a/src/core/suspend_.mli +++ b/src/core/suspend_.mli @@ -3,8 +3,6 @@ This module is an implementation detail of Moonpool and should not be used outside of it, except by experts to implement {!Runner}. *) -open Types_ - type suspension = unit Exn_bt.result -> unit (** A suspended computation *) diff --git a/src/lwt/IO.ml b/src/lwt/IO.ml index 09ae6d07..4a8acc69 100644 --- a/src/lwt/IO.ml +++ b/src/lwt/IO.ml @@ -1,45 +1,49 @@ open Base +let await_readable fd : unit = + Moonpool.Private.Suspend_.suspend + { + handle = + (fun ~run:_ ~resume sus -> + Perform_action_in_lwt.schedule + @@ Action.Wait_readable + ( fd, + fun cancel -> + resume sus @@ Ok (); + Lwt_engine.stop_event cancel )); + } + let rec read fd buf i len : int = if len = 0 then 0 else ( match Unix.read fd buf i len with | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) -> - (* wait for FD to be ready *) - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_readable - ( fd, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - }; + await_readable fd; read fd buf i len | n -> n ) +let await_writable fd = + Moonpool.Private.Suspend_.suspend + { + handle = + (fun ~run:_ ~resume sus -> + Perform_action_in_lwt.schedule + @@ Action.Wait_writable + ( fd, + fun cancel -> + resume sus @@ Ok (); + Lwt_engine.stop_event cancel )); + } + let rec write_once fd buf i len : int = if len = 0 then 0 else ( match Unix.write fd buf i len with | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) -> - (* wait for FD to be ready *) - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_writable - ( fd, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - }; + await_writable fd; write_once fd buf i len | n -> n ) diff --git a/src/lwt/moonpool_lwt.mli b/src/lwt/moonpool_lwt.mli index cff4dbd1..ac218e0c 100644 --- a/src/lwt/moonpool_lwt.mli +++ b/src/lwt/moonpool_lwt.mli @@ -1,4 +1,10 @@ -(** Lwt_engine-based event loop for Moonpool *) +(** Lwt_engine-based event loop for Moonpool. + + In what follows, we mean by "lwt thread" the thread + running [Lwt_main.run] (so, the thread where the Lwt event + loop and all Lwt callbacks execute). + + @since NEXT_RELEASE *) module Fiber = Moonpool_fib.Fiber module FLS = Moonpool_fib.Fls @@ -7,26 +13,31 @@ module FLS = Moonpool_fib.Fls val fut_of_lwt : 'a Lwt.t -> 'a Moonpool.Fut.t (** [fut_of_lwt lwt_fut] makes a thread-safe moonpool future that - completes when [lwt_fut] does *) + completes when [lwt_fut] does. This must be run from within + the Lwt thread. *) val lwt_of_fut : 'a Moonpool.Fut.t -> 'a Lwt.t (** [lwt_of_fut fut] makes a lwt future that completes when - [fut] does. The result should be used only from inside the - thread running [Lwt_main.run]. *) + [fut] does. This must be called from the Lwt thread, and the result + must always be used only from inside the Lwt thread. *) (** {2 Helpers on the moonpool side} *) val await_lwt : 'a Lwt.t -> 'a (** [await_lwt fut] awaits a Lwt future from inside a task running on - a moonpool runner. This must be run from within moonpool. *) + a moonpool runner. This must be run from within a Moonpool runner + so that the await-ing effect is handled. *) val run_in_lwt : (unit -> 'a Lwt.t) -> 'a Moonpool.Fut.t (** [run_in_lwt f] runs [f()] from within the Lwt thread - and returns a thread-safe future. *) + and returns a thread-safe future. This can be run from anywhere. *) val run_in_lwt_and_await : (unit -> 'a Lwt.t) -> 'a (** [run_in_lwt_and_await f] runs [f] in the Lwt thread, and - awaits its result. Must be run from inside a moonpool runner. *) + awaits its result. Must be run from inside a moonpool runner + so that the await-in effect is handled. + + This is similar to [Moonpool.await @@ run_in_lwt f]. *) val get_runner : unit -> Moonpool.Runner.t (** Returns the runner from within which this is called. @@ -41,22 +52,54 @@ val get_runner : unit -> Moonpool.Runner.t and rely on a [Lwt_engine] event loop being active (meaning, [Lwt_main.run] is currently running in some thread). - Calling these functions must be done from a moonpool runner and - will suspend the current task/fut/fiber if the FD is not ready. + Calling these functions must be done from a moonpool runner. + A function like [read] will first try to perform the IO action + directly (here, call {!Unix.read}); if the action fails because + the FD is not ready, then [await_readable] is called: + it suspends the fiber and subscribes it to Lwt to be awakened + when the FD becomes ready. *) module IO : sig val read : Unix.file_descr -> bytes -> int -> int -> int + (** Read from the file descriptor *) + + val await_readable : Unix.file_descr -> unit + (** Suspend the fiber until the FD is readable *) + val write_once : Unix.file_descr -> bytes -> int -> int -> int + (** Perform one write into the file descriptor *) + + val await_writable : Unix.file_descr -> unit + (** Suspend the fiber until the FD is writable *) + val write : Unix.file_descr -> bytes -> int -> int -> unit + (** Loop around {!write_once} to write the entire slice. *) + val sleep_s : float -> unit + (** Suspend the fiber for [n] seconds. *) end module IO_in = IO_in +(** Input channel *) + module IO_out = IO_out +(** Output channel *) module TCP_server : sig type t = Lwt_io.server + val establish_lwt : + ?backlog:(* ?server_fd:Unix.file_descr -> *) + int -> + ?no_close:bool -> + runner:Moonpool.Runner.t -> + Unix.sockaddr -> + (Unix.sockaddr -> Lwt_io.input_channel -> Lwt_io.output_channel -> unit) -> + t + (** [establish ~runner addr handler] runs a TCP server in the Lwt + thread. When a client connects, a moonpool fiber is started on [runner] + to handle it. *) + val establish : ?backlog:(* ?server_fd:Unix.file_descr -> *) int -> @@ -65,26 +108,21 @@ module TCP_server : sig Unix.sockaddr -> (Unix.sockaddr -> IO_in.t -> IO_out.t -> unit) -> t - - val establish' : - ?backlog:(* ?server_fd:Unix.file_descr -> *) - int -> - ?no_close:bool -> - runner:Moonpool.Runner.t -> - Unix.sockaddr -> - (Unix.sockaddr -> Lwt_io.input_channel -> Lwt_io.output_channel -> unit) -> - t + (** Like {!establish_lwt} but uses {!IO} to directly handle + reads and writes on client sockets. *) val shutdown : t -> unit + (** Shutdown the server *) end module TCP_client : sig val connect : Unix.sockaddr -> Unix.file_descr val with_connect : Unix.sockaddr -> (IO_in.t -> IO_out.t -> 'a) -> 'a - (** Open a connection. *) + (** Open a connection, and use {!IO} to read and write from + the socket in a non blocking way. *) - val with_connect' : + val with_connect_lwt : Unix.sockaddr -> (Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) -> 'a (** Open a connection. *) end diff --git a/src/lwt/tcp_client.ml b/src/lwt/tcp_client.ml index c7db3880..8aec16f2 100644 --- a/src/lwt/tcp_client.ml +++ b/src/lwt/tcp_client.ml @@ -14,17 +14,7 @@ let connect addr : Unix.file_descr = with | Unix.Unix_error ((Unix.EWOULDBLOCK | Unix.EINPROGRESS | Unix.EAGAIN), _, _) -> - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_writable - ( sock, - fun ev -> - resume sus @@ Ok (); - Lwt_engine.stop_event ev )); - }; + IO.await_writable sock; true do () @@ -41,16 +31,22 @@ let with_connect addr (f : IO_in.t -> IO_out.t -> 'a) : 'a = let@ () = Fun.protect ~finally in f ic oc -let with_connect' addr (f : Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) - : 'a = +let with_connect_lwt addr + (f : Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) : 'a = let sock = connect addr in - let ic = Lwt_io.of_unix_fd ~mode:Lwt_io.input sock in - let oc = Lwt_io.of_unix_fd ~mode:Lwt_io.output sock in + let ic = + run_in_lwt_and_await (fun () -> + Lwt.return @@ Lwt_io.of_unix_fd ~mode:Lwt_io.input sock) + in + let oc = + run_in_lwt_and_await (fun () -> + Lwt.return @@ Lwt_io.of_unix_fd ~mode:Lwt_io.output sock) + in let finally () = - (try Lwt_io.close ic |> await_lwt with _ -> ()); - (try Lwt_io.close oc |> await_lwt with _ -> ()); + (try run_in_lwt_and_await (fun () -> Lwt_io.close ic) with _ -> ()); + (try run_in_lwt_and_await (fun () -> Lwt_io.close oc) with _ -> ()); try Unix.close sock with _ -> () in let@ () = Fun.protect ~finally in diff --git a/src/lwt/tcp_server.ml b/src/lwt/tcp_server.ml index 2b6605b0..22fa9253 100644 --- a/src/lwt/tcp_server.ml +++ b/src/lwt/tcp_server.ml @@ -3,7 +3,7 @@ open Base type t = Lwt_io.server -let establish' ?backlog ?no_close ~runner addr handler : t = +let establish_lwt ?backlog ?no_close ~runner addr handler : t = let server = Lwt_io.establish_server_with_client_socket ?backlog ?no_close addr (fun client_addr client_sock ->