feat lwt: make most functions work on any thread, not just the main

This commit is contained in:
Simon Cruanes 2025-09-04 14:46:35 -04:00
parent 786d75d680
commit 122b3a6b06
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 122 additions and 69 deletions

View file

@ -43,7 +43,7 @@ module Scheduler_state = struct
has_notified = Atomic.make false;
}
let add_action_from_another_thread_ (self : st) f : unit =
let[@inline never] add_action_from_another_thread_ (self : st) f : unit =
Mutex.lock st.mutex;
Queue.push f self.actions_from_other_threads;
Mutex.unlock st.mutex;
@ -56,12 +56,17 @@ module Ops = struct
let around_task _ = default_around_task_
let schedule (self : st) t =
if Thread.id (Thread.self ()) = self.thread then
Queue.push t self.tasks
let[@inline] on_lwt_thread_ (self : st) : bool =
Thread.id (Thread.self ()) = self.thread
let[@inline] run_on_lwt_thread_ (self : st) (f : unit -> unit) : unit =
if on_lwt_thread_ self then
f ()
else
Scheduler_state.add_action_from_another_thread_ self (fun () ->
Queue.push t self.tasks)
Scheduler_state.add_action_from_another_thread_ self f
let schedule (self : st) t =
run_on_lwt_thread_ self (fun () -> Queue.push t self.tasks)
let get_next_task (self : st) =
if self.closed then raise WL.No_more_tasks;
@ -74,8 +79,10 @@ module Ops = struct
Moonpool.Runner.For_runner_implementors.create
~size:(fun () -> 1)
~num_tasks:(fun () ->
(* FIXME: thread safety. use an atomic?? *)
Queue.length self.tasks)
Mutex.lock self.mutex;
let n = Queue.length self.tasks in
Mutex.unlock self.mutex;
n)
~run_async:(fun ~fiber f -> schedule self @@ WL.T_start { fiber; f })
~shutdown:(fun ~wait:_ () -> self.closed <- true)
()
@ -113,20 +120,41 @@ open struct
()
end
let await_lwt (fut : _ Lwt.t) =
(** Resolve [prom] with the result of [lwt_fut] *)
let transfer_lwt_to_fut (lwt_fut : 'a Lwt.t) (prom : 'a Fut.promise) : unit =
Lwt.on_any lwt_fut
(fun x -> M.Fut.fulfill prom (Ok x))
(fun exn ->
let bt = Printexc.get_callstack 10 in
M.Fut.fulfill prom (Error (Exn_bt.make exn bt)))
let[@inline] register_trigger_on_lwt_termination (lwt_fut : _ Lwt.t)
(tr : M.Trigger.t) : unit =
Lwt.on_termination lwt_fut (fun _ -> M.Trigger.signal tr)
let[@inline] await_lwt_terminated (fut : _ Lwt.t) =
match Lwt.state fut with
| Return x -> x
| Fail exn -> raise exn
| Sleep ->
(* suspend fiber, wake it up when [fut] resolves *)
let trigger = M.Trigger.create () in
Lwt.on_termination fut (fun _ -> M.Trigger.signal trigger);
M.Trigger.await trigger |> Option.iter Exn_bt.raise;
| Sleep -> assert false
(match Lwt.state fut with
let await_lwt (fut : _ Lwt.t) =
if Ops.on_lwt_thread_ Scheduler_state.st then (
match Lwt.state fut with
| Return x -> x
| Fail exn -> raise exn
| Sleep -> assert false)
| Sleep ->
let tr = M.Trigger.create () in
register_trigger_on_lwt_termination fut tr;
M.Trigger.await_exn tr;
await_lwt_terminated fut
) else (
let tr = M.Trigger.create () in
Scheduler_state.add_action_from_another_thread_ Scheduler_state.st
(fun () -> register_trigger_on_lwt_termination fut tr);
M.Trigger.await_exn tr;
await_lwt_terminated fut
)
let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t =
let lwt_fut, lwt_prom = Lwt.wait () in
@ -140,62 +168,83 @@ let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t =
in
M.Fut.on_result fut (fun res ->
if Thread.id (Thread.self ()) = Scheduler_state.st.thread then
(* can safely wakeup from the lwt thread *)
wakeup_using_res res
else
Scheduler_state.add_action_from_another_thread_ Scheduler_state.st
(fun () -> wakeup_using_res res));
Ops.run_on_lwt_thread_ Scheduler_state.st (fun () ->
(* can safely wakeup from the lwt thread *)
wakeup_using_res res));
lwt_fut
let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t =
match Lwt.poll lwt_fut with
| Some x -> M.Fut.return x
| None ->
if Ops.on_lwt_thread_ Scheduler_state.st then (
match Lwt.state lwt_fut with
| Return x -> M.Fut.return x
| _ ->
let fut, prom = M.Fut.make () in
transfer_lwt_to_fut lwt_fut prom;
fut
) else (
let fut, prom = M.Fut.make () in
Lwt.on_any lwt_fut
(fun x -> M.Fut.fulfill prom (Ok x))
(fun exn ->
let bt = Printexc.get_callstack 10 in
M.Fut.fulfill prom (Error (Exn_bt.make exn bt)));
Scheduler_state.add_action_from_another_thread_ Scheduler_state.st
(fun () -> transfer_lwt_to_fut lwt_fut prom);
fut
let run_in_hook () =
(* execute actions sent from other threads; first transfer them
all atomically to a local queue to reduce contention *)
let local_acts = Queue.create () in
Mutex.lock Scheduler_state.st.mutex;
Queue.transfer Scheduler_state.st.actions_from_other_threads local_acts;
Atomic.set Scheduler_state.st.has_notified false;
Mutex.unlock Scheduler_state.st.mutex;
Queue.iter (fun f -> f ()) local_acts;
(* run tasks *)
FG.run ~max_tasks:1000 ();
if not (Queue.is_empty Scheduler_state.st.tasks) then
ignore (Lwt.pause () : unit Lwt.t);
()
let is_setup_ = Atomic.make false
let setup () =
if not (Atomic.exchange is_setup_ true) then (
(* only one thread does this *)
FG.setup ~block_signals:false ();
Scheduler_state.st.enter_hook <-
Some (Lwt_main.Enter_iter_hooks.add_last run_in_hook);
Scheduler_state.st.leave_hook <-
Some (Lwt_main.Leave_iter_hooks.add_last run_in_hook);
(* notification used to wake lwt up *)
Scheduler_state.st.notification <-
Lwt_unix.make_notification ~once:false run_in_hook
)
let run_in_lwt_and_await (f : unit -> 'a Lwt.t) : 'a =
if Ops.on_lwt_thread_ Scheduler_state.st then (
let fut = f () in
await_lwt fut
) else (
let fut, prom = Fut.make () in
Scheduler_state.add_action_from_another_thread_ Scheduler_state.st
(fun () ->
let lwt_fut = f () in
transfer_lwt_to_fut lwt_fut prom);
Fut.await fut
)
module Setup_lwt_hooks = struct
let run_in_hook () =
(* execute actions sent from other threads; first transfer them
all atomically to a local queue to reduce contention *)
let local_acts = Queue.create () in
Mutex.lock Scheduler_state.st.mutex;
Queue.transfer Scheduler_state.st.actions_from_other_threads local_acts;
Atomic.set Scheduler_state.st.has_notified false;
Mutex.unlock Scheduler_state.st.mutex;
Queue.iter (fun f -> f ()) local_acts;
(* run tasks *)
FG.run ~max_tasks:1000 ();
if not (Queue.is_empty Scheduler_state.st.tasks) then
ignore (Lwt.pause () : unit Lwt.t);
()
let is_setup_ = Atomic.make false
let[@inline] is_setup () : bool = Atomic.get is_setup_
let setup () =
if not (Atomic.exchange is_setup_ true) then (
(* only one thread does this *)
FG.setup ~block_signals:false ();
Scheduler_state.st.thread <- Thread.self () |> Thread.id;
Scheduler_state.st.enter_hook <-
Some (Lwt_main.Enter_iter_hooks.add_last run_in_hook);
Scheduler_state.st.leave_hook <-
Some (Lwt_main.Leave_iter_hooks.add_last run_in_hook);
(* notification used to wake lwt up *)
Scheduler_state.st.notification <-
Lwt_unix.make_notification ~once:false run_in_hook
) else if not (Ops.on_lwt_thread_ Scheduler_state.st) then
(* sanity check failed *)
failwith "moonpool-lwt.setup: called again on a different thread"
end
let spawn_lwt f : _ Lwt.t =
setup ();
if not (Setup_lwt_hooks.is_setup ()) then
failwith "spawn_lwt: scheduler was not setup";
let lwt_fut, lwt_prom = Lwt.wait () in
M.Runner.run_async Scheduler_state.st.as_runner (fun () ->
try
@ -205,11 +254,11 @@ let spawn_lwt f : _ Lwt.t =
lwt_fut
let lwt_main (f : _ -> 'a) : 'a =
setup ();
Scheduler_state.st.thread <- Thread.self () |> Thread.id;
Setup_lwt_hooks.setup ();
let fut = spawn_lwt (fun () -> f Scheduler_state.st.as_runner) in
Lwt_main.run fut
let lwt_main_runner () =
if not (Atomic.get is_setup_) then failwith "lwt_main_runner: not setup yet";
let[@inline] lwt_main_runner () =
if not (Setup_lwt_hooks.is_setup ()) then
failwith "lwt_main_runner: scheduler was not setup";
Scheduler_state.st.as_runner

View file

@ -23,13 +23,16 @@ val lwt_of_fut : 'a Moonpool.Fut.t -> 'a Lwt.t
(** {2 Helpers on the moonpool side} *)
val spawn_lwt : (unit -> 'a) -> 'a Lwt.t
(** This spawns a task that runs in the Lwt scheduler *)
(** This spawns a task that runs in the Lwt scheduler.
@raise Failure if {!lwt_main} was not called. *)
val await_lwt : 'a Lwt.t -> 'a
(** [await_lwt fut] awaits a Lwt future from inside a task running on a moonpool
runner. This must be run from within a Moonpool runner so that the await-ing
effect is handled. *)
val run_in_lwt_and_await : (unit -> 'a Lwt.t) -> 'a
(** {2 Wrappers around Lwt_main} *)
val on_uncaught_exn : (Moonpool.Exn_bt.t -> unit) ref
@ -39,4 +42,5 @@ val lwt_main : (Moonpool.Runner.t -> 'a) -> 'a
val lwt_main_runner : unit -> Moonpool.Runner.t
(** The runner from {!lwt_main}. The runner is only going to work if {!lwt_main}
is currently running in some thread. *)
is currently running in some thread.
@raise Failure if {!lwt_main} was not called. *)