improve docs for moonpool_lwt; fix race condition

This commit is contained in:
Simon Cruanes 2024-02-17 12:37:03 -05:00
parent 283a1cb118
commit 8bfe76b3e0
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
5 changed files with 100 additions and 64 deletions

View file

@ -3,8 +3,6 @@
This module is an implementation detail of Moonpool and should This module is an implementation detail of Moonpool and should
not be used outside of it, except by experts to implement {!Runner}. *) not be used outside of it, except by experts to implement {!Runner}. *)
open Types_
type suspension = unit Exn_bt.result -> unit type suspension = unit Exn_bt.result -> unit
(** A suspended computation *) (** A suspended computation *)

View file

@ -1,12 +1,6 @@
open Base open Base
let rec read fd buf i len : int = let await_readable fd : unit =
if len = 0 then
0
else (
match Unix.read fd buf i len with
| exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
(* wait for FD to be ready *)
Moonpool.Private.Suspend_.suspend Moonpool.Private.Suspend_.suspend
{ {
handle = handle =
@ -17,18 +11,20 @@ let rec read fd buf i len : int =
fun cancel -> fun cancel ->
resume sus @@ Ok (); resume sus @@ Ok ();
Lwt_engine.stop_event cancel )); Lwt_engine.stop_event cancel ));
}; }
let rec read fd buf i len : int =
if len = 0 then
0
else (
match Unix.read fd buf i len with
| exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
await_readable fd;
read fd buf i len read fd buf i len
| n -> n | n -> n
) )
let rec write_once fd buf i len : int = let await_writable fd =
if len = 0 then
0
else (
match Unix.write fd buf i len with
| exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
(* wait for FD to be ready *)
Moonpool.Private.Suspend_.suspend Moonpool.Private.Suspend_.suspend
{ {
handle = handle =
@ -39,7 +35,15 @@ let rec write_once fd buf i len : int =
fun cancel -> fun cancel ->
resume sus @@ Ok (); resume sus @@ Ok ();
Lwt_engine.stop_event cancel )); Lwt_engine.stop_event cancel ));
}; }
let rec write_once fd buf i len : int =
if len = 0 then
0
else (
match Unix.write fd buf i len with
| exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
await_writable fd;
write_once fd buf i len write_once fd buf i len
| n -> n | n -> n
) )

View file

@ -1,4 +1,10 @@
(** Lwt_engine-based event loop for Moonpool *) (** Lwt_engine-based event loop for Moonpool.
In what follows, we mean by "lwt thread" the thread
running [Lwt_main.run] (so, the thread where the Lwt event
loop and all Lwt callbacks execute).
@since NEXT_RELEASE *)
module Fiber = Moonpool_fib.Fiber module Fiber = Moonpool_fib.Fiber
module FLS = Moonpool_fib.Fls module FLS = Moonpool_fib.Fls
@ -7,26 +13,31 @@ module FLS = Moonpool_fib.Fls
val fut_of_lwt : 'a Lwt.t -> 'a Moonpool.Fut.t val fut_of_lwt : 'a Lwt.t -> 'a Moonpool.Fut.t
(** [fut_of_lwt lwt_fut] makes a thread-safe moonpool future that (** [fut_of_lwt lwt_fut] makes a thread-safe moonpool future that
completes when [lwt_fut] does *) completes when [lwt_fut] does. This must be run from within
the Lwt thread. *)
val lwt_of_fut : 'a Moonpool.Fut.t -> 'a Lwt.t val lwt_of_fut : 'a Moonpool.Fut.t -> 'a Lwt.t
(** [lwt_of_fut fut] makes a lwt future that completes when (** [lwt_of_fut fut] makes a lwt future that completes when
[fut] does. The result should be used only from inside the [fut] does. This must be called from the Lwt thread, and the result
thread running [Lwt_main.run]. *) must always be used only from inside the Lwt thread. *)
(** {2 Helpers on the moonpool side} *) (** {2 Helpers on the moonpool side} *)
val await_lwt : 'a Lwt.t -> 'a val await_lwt : 'a Lwt.t -> 'a
(** [await_lwt fut] awaits a Lwt future from inside a task running on (** [await_lwt fut] awaits a Lwt future from inside a task running on
a moonpool runner. This must be run from within moonpool. *) a moonpool runner. This must be run from within a Moonpool runner
so that the await-ing effect is handled. *)
val run_in_lwt : (unit -> 'a Lwt.t) -> 'a Moonpool.Fut.t val run_in_lwt : (unit -> 'a Lwt.t) -> 'a Moonpool.Fut.t
(** [run_in_lwt f] runs [f()] from within the Lwt thread (** [run_in_lwt f] runs [f()] from within the Lwt thread
and returns a thread-safe future. *) and returns a thread-safe future. This can be run from anywhere. *)
val run_in_lwt_and_await : (unit -> 'a Lwt.t) -> 'a val run_in_lwt_and_await : (unit -> 'a Lwt.t) -> 'a
(** [run_in_lwt_and_await f] runs [f] in the Lwt thread, and (** [run_in_lwt_and_await f] runs [f] in the Lwt thread, and
awaits its result. Must be run from inside a moonpool runner. *) awaits its result. Must be run from inside a moonpool runner
so that the await-in effect is handled.
This is similar to [Moonpool.await @@ run_in_lwt f]. *)
val get_runner : unit -> Moonpool.Runner.t val get_runner : unit -> Moonpool.Runner.t
(** Returns the runner from within which this is called. (** Returns the runner from within which this is called.
@ -41,22 +52,54 @@ val get_runner : unit -> Moonpool.Runner.t
and rely on a [Lwt_engine] event loop being active (meaning, and rely on a [Lwt_engine] event loop being active (meaning,
[Lwt_main.run] is currently running in some thread). [Lwt_main.run] is currently running in some thread).
Calling these functions must be done from a moonpool runner and Calling these functions must be done from a moonpool runner.
will suspend the current task/fut/fiber if the FD is not ready. A function like [read] will first try to perform the IO action
directly (here, call {!Unix.read}); if the action fails because
the FD is not ready, then [await_readable] is called:
it suspends the fiber and subscribes it to Lwt to be awakened
when the FD becomes ready.
*) *)
module IO : sig module IO : sig
val read : Unix.file_descr -> bytes -> int -> int -> int val read : Unix.file_descr -> bytes -> int -> int -> int
(** Read from the file descriptor *)
val await_readable : Unix.file_descr -> unit
(** Suspend the fiber until the FD is readable *)
val write_once : Unix.file_descr -> bytes -> int -> int -> int val write_once : Unix.file_descr -> bytes -> int -> int -> int
(** Perform one write into the file descriptor *)
val await_writable : Unix.file_descr -> unit
(** Suspend the fiber until the FD is writable *)
val write : Unix.file_descr -> bytes -> int -> int -> unit val write : Unix.file_descr -> bytes -> int -> int -> unit
(** Loop around {!write_once} to write the entire slice. *)
val sleep_s : float -> unit val sleep_s : float -> unit
(** Suspend the fiber for [n] seconds. *)
end end
module IO_in = IO_in module IO_in = IO_in
(** Input channel *)
module IO_out = IO_out module IO_out = IO_out
(** Output channel *)
module TCP_server : sig module TCP_server : sig
type t = Lwt_io.server type t = Lwt_io.server
val establish_lwt :
?backlog:(* ?server_fd:Unix.file_descr -> *)
int ->
?no_close:bool ->
runner:Moonpool.Runner.t ->
Unix.sockaddr ->
(Unix.sockaddr -> Lwt_io.input_channel -> Lwt_io.output_channel -> unit) ->
t
(** [establish ~runner addr handler] runs a TCP server in the Lwt
thread. When a client connects, a moonpool fiber is started on [runner]
to handle it. *)
val establish : val establish :
?backlog:(* ?server_fd:Unix.file_descr -> *) ?backlog:(* ?server_fd:Unix.file_descr -> *)
int -> int ->
@ -65,26 +108,21 @@ module TCP_server : sig
Unix.sockaddr -> Unix.sockaddr ->
(Unix.sockaddr -> IO_in.t -> IO_out.t -> unit) -> (Unix.sockaddr -> IO_in.t -> IO_out.t -> unit) ->
t t
(** Like {!establish_lwt} but uses {!IO} to directly handle
val establish' : reads and writes on client sockets. *)
?backlog:(* ?server_fd:Unix.file_descr -> *)
int ->
?no_close:bool ->
runner:Moonpool.Runner.t ->
Unix.sockaddr ->
(Unix.sockaddr -> Lwt_io.input_channel -> Lwt_io.output_channel -> unit) ->
t
val shutdown : t -> unit val shutdown : t -> unit
(** Shutdown the server *)
end end
module TCP_client : sig module TCP_client : sig
val connect : Unix.sockaddr -> Unix.file_descr val connect : Unix.sockaddr -> Unix.file_descr
val with_connect : Unix.sockaddr -> (IO_in.t -> IO_out.t -> 'a) -> 'a val with_connect : Unix.sockaddr -> (IO_in.t -> IO_out.t -> 'a) -> 'a
(** Open a connection. *) (** Open a connection, and use {!IO} to read and write from
the socket in a non blocking way. *)
val with_connect' : val with_connect_lwt :
Unix.sockaddr -> (Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) -> 'a Unix.sockaddr -> (Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) -> 'a
(** Open a connection. *) (** Open a connection. *)
end end

View file

@ -14,17 +14,7 @@ let connect addr : Unix.file_descr =
with with
| Unix.Unix_error ((Unix.EWOULDBLOCK | Unix.EINPROGRESS | Unix.EAGAIN), _, _) | Unix.Unix_error ((Unix.EWOULDBLOCK | Unix.EINPROGRESS | Unix.EAGAIN), _, _)
-> ->
Moonpool.Private.Suspend_.suspend IO.await_writable sock;
{
handle =
(fun ~run:_ ~resume sus ->
Perform_action_in_lwt.schedule
@@ Action.Wait_writable
( sock,
fun ev ->
resume sus @@ Ok ();
Lwt_engine.stop_event ev ));
};
true true
do do
() ()
@ -41,16 +31,22 @@ let with_connect addr (f : IO_in.t -> IO_out.t -> 'a) : 'a =
let@ () = Fun.protect ~finally in let@ () = Fun.protect ~finally in
f ic oc f ic oc
let with_connect' addr (f : Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) let with_connect_lwt addr
: 'a = (f : Lwt_io.input_channel -> Lwt_io.output_channel -> 'a) : 'a =
let sock = connect addr in let sock = connect addr in
let ic = Lwt_io.of_unix_fd ~mode:Lwt_io.input sock in let ic =
let oc = Lwt_io.of_unix_fd ~mode:Lwt_io.output sock in run_in_lwt_and_await (fun () ->
Lwt.return @@ Lwt_io.of_unix_fd ~mode:Lwt_io.input sock)
in
let oc =
run_in_lwt_and_await (fun () ->
Lwt.return @@ Lwt_io.of_unix_fd ~mode:Lwt_io.output sock)
in
let finally () = let finally () =
(try Lwt_io.close ic |> await_lwt with _ -> ()); (try run_in_lwt_and_await (fun () -> Lwt_io.close ic) with _ -> ());
(try Lwt_io.close oc |> await_lwt with _ -> ()); (try run_in_lwt_and_await (fun () -> Lwt_io.close oc) with _ -> ());
try Unix.close sock with _ -> () try Unix.close sock with _ -> ()
in in
let@ () = Fun.protect ~finally in let@ () = Fun.protect ~finally in

View file

@ -3,7 +3,7 @@ open Base
type t = Lwt_io.server type t = Lwt_io.server
let establish' ?backlog ?no_close ~runner addr handler : t = let establish_lwt ?backlog ?no_close ~runner addr handler : t =
let server = let server =
Lwt_io.establish_server_with_client_socket ?backlog ?no_close addr Lwt_io.establish_server_with_client_socket ?backlog ?no_close addr
(fun client_addr client_sock -> (fun client_addr client_sock ->