From e3be2aceaa49d5535d9a0def6a73fb0b37cbd566 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 4 Sep 2025 16:03:06 -0400 Subject: [PATCH] feat lwt: make sure we can setup/cleanup multiple times --- src/lwt/moonpool_lwt.ml | 198 ++++++++++++++++++++++++--------------- src/lwt/moonpool_lwt.mli | 2 + 2 files changed, 123 insertions(+), 77 deletions(-) diff --git a/src/lwt/moonpool_lwt.ml b/src/lwt/moonpool_lwt.ml index 9895ad30..9fa62582 100644 --- a/src/lwt/moonpool_lwt.ml +++ b/src/lwt/moonpool_lwt.ml @@ -20,7 +20,7 @@ module Scheduler_state = struct (** Other threads ask us to run closures in the lwt thread *) mutex: Mutex.t; mutable thread: int; - mutable closed: bool; + closed: bool Atomic.t; mutable as_runner: Moonpool.Runner.t; mutable enter_hook: Lwt_main.Enter_iter_hooks.hook option; mutable leave_hook: Lwt_main.Leave_iter_hooks.hook option; @@ -29,13 +29,16 @@ module Scheduler_state = struct has_notified: bool Atomic.t; } - let st : st = + (** Main state *) + let cur_st : st option Atomic.t = Atomic.make None + + let create_new () : st = { tasks = Queue.create (); actions_from_other_threads = Queue.create (); mutex = Mutex.create (); - thread = -1; - closed = false; + thread = Thread.id (Thread.self ()); + closed = Atomic.make false; as_runner = Moonpool.Runner.dummy; enter_hook = None; leave_hook = None; @@ -44,17 +47,11 @@ module Scheduler_state = struct } let[@inline never] add_action_from_another_thread_ (self : st) f : unit = - Mutex.lock st.mutex; + Mutex.lock self.mutex; Queue.push f self.actions_from_other_threads; - Mutex.unlock st.mutex; if not (Atomic.exchange self.has_notified true) then - Lwt_unix.send_notification self.notification -end - -module Ops = struct - type st = Scheduler_state.st - - let around_task _ = default_around_task_ + Lwt_unix.send_notification self.notification; + Mutex.unlock self.mutex let[@inline] on_lwt_thread_ (self : st) : bool = Thread.id (Thread.self ()) = self.thread @@ -63,17 +60,41 @@ module Ops = struct if on_lwt_thread_ self then f () else - Scheduler_state.add_action_from_another_thread_ self f + add_action_from_another_thread_ self f + + let cleanup (st : st) = + match Atomic.get cur_st with + | Some st' -> + if st != st' then + failwith + "moonpool-lwt: cleanup failed (state is not the currently active \ + one!)"; + if not (on_lwt_thread_ st) then + failwith "moonpool-lwt: cleanup from the wrong thread"; + Option.iter Lwt_main.Enter_iter_hooks.remove st.enter_hook; + Option.iter Lwt_main.Leave_iter_hooks.remove st.leave_hook; + + Atomic.set cur_st None + | _ -> () +end + +module Ops = struct + type st = Scheduler_state.st + + let around_task _ = default_around_task_ let schedule (self : st) t = - run_on_lwt_thread_ self (fun () -> Queue.push t self.tasks) + if Atomic.get self.closed then + failwith "moonpool-lwt.schedule: scheduler is closed"; + Scheduler_state.run_on_lwt_thread_ self (fun () -> Queue.push t self.tasks) let get_next_task (self : st) = - if self.closed then raise WL.No_more_tasks; + if Atomic.get self.closed then raise WL.No_more_tasks; try Queue.pop self.tasks with Queue.Empty -> raise WL.No_more_tasks let on_exn _ ebt = !on_uncaught_exn ebt let runner (self : st) = self.as_runner + let cleanup = Scheduler_state.cleanup let as_runner (self : st) : Moonpool.Runner.t = Moonpool.Runner.For_runner_implementors.create @@ -84,19 +105,13 @@ module Ops = struct Mutex.unlock self.mutex; n) ~run_async:(fun ~fiber f -> schedule self @@ WL.T_start { fiber; f }) - ~shutdown:(fun ~wait:_ () -> self.closed <- true) + ~shutdown:(fun ~wait:_ () -> Atomic.set self.closed true) () let before_start (self : st) : unit = self.as_runner <- as_runner self; () - let cleanup (self : st) = - self.closed <- true; - Option.iter Lwt_main.Enter_iter_hooks.remove self.enter_hook; - Option.iter Lwt_main.Leave_iter_hooks.remove self.leave_hook; - () - let ops : st WL.ops = { schedule; @@ -107,17 +122,12 @@ module Ops = struct before_start; cleanup; } -end -open struct - module FG = - WL.Fine_grained - (struct - include Scheduler_state - - let ops = Ops.ops - end) - () + let setup st = + if Atomic.compare_and_set Scheduler_state.cur_st None (Some st) then + before_start st + else + failwith "moonpool-lwt: setup failed (state already in place)" end (** Resolve [prom] with the result of [lwt_fut] *) @@ -138,9 +148,28 @@ let[@inline] await_lwt_terminated (fut : _ Lwt.t) = | Fail exn -> raise exn | Sleep -> assert false +module Main_state = struct + let[@inline] get_st () : Scheduler_state.st = + match Atomic.get Scheduler_state.cur_st with + | Some st -> + if Atomic.get st.closed then failwith "moonpool-lwt: scheduler is closed"; + st + | None -> failwith "moonpool-lwt: scheduler is not setup" + + let[@inline] run_on_lwt_thread f = + Scheduler_state.run_on_lwt_thread_ (get_st ()) f + + let[@inline] on_lwt_thread () : bool = + Scheduler_state.on_lwt_thread_ (get_st ()) + + let[@inline] add_action_from_another_thread f : unit = + Scheduler_state.add_action_from_another_thread_ (get_st ()) f +end + let await_lwt (fut : _ Lwt.t) = - if Ops.on_lwt_thread_ Scheduler_state.st then ( - match Lwt.state fut with + if Scheduler_state.on_lwt_thread_ (Main_state.get_st ()) then ( + (* can directly access the future *) + match Lwt.state fut with | Return x -> x | Fail exn -> raise exn | Sleep -> @@ -150,14 +179,14 @@ let await_lwt (fut : _ Lwt.t) = await_lwt_terminated fut ) else ( let tr = M.Trigger.create () in - Scheduler_state.add_action_from_another_thread_ Scheduler_state.st - (fun () -> register_trigger_on_lwt_termination fut tr); + Main_state.add_action_from_another_thread (fun () -> + register_trigger_on_lwt_termination fut tr); M.Trigger.await_exn tr; await_lwt_terminated fut ) let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t = - if not (Ops.on_lwt_thread_ Scheduler_state.st) then + if not (Main_state.on_lwt_thread ()) then failwith "lwt_of_fut: not on the lwt thread"; let lwt_fut, lwt_prom = Lwt.wait () in @@ -170,14 +199,14 @@ let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t = in M.Fut.on_result fut (fun res -> - Ops.run_on_lwt_thread_ Scheduler_state.st (fun () -> + Main_state.run_on_lwt_thread (fun () -> (* can safely wakeup from the lwt thread *) wakeup_using_res res)); lwt_fut let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t = - if Ops.on_lwt_thread_ Scheduler_state.st then ( + if Main_state.on_lwt_thread () then ( match Lwt.state lwt_fut with | Return x -> M.Fut.return x | _ -> @@ -186,69 +215,82 @@ let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t = fut ) else ( let fut, prom = M.Fut.make () in - Scheduler_state.add_action_from_another_thread_ Scheduler_state.st - (fun () -> transfer_lwt_to_fut lwt_fut prom); + Main_state.add_action_from_another_thread (fun () -> + transfer_lwt_to_fut lwt_fut prom); fut ) let run_in_lwt_and_await (f : unit -> 'a Lwt.t) : 'a = - if Ops.on_lwt_thread_ Scheduler_state.st then ( + if Main_state.on_lwt_thread () then ( let fut = f () in await_lwt fut ) else ( let fut, prom = Fut.make () in - Scheduler_state.add_action_from_another_thread_ Scheduler_state.st - (fun () -> + Main_state.add_action_from_another_thread (fun () -> let lwt_fut = f () in transfer_lwt_to_fut lwt_fut prom); Fut.await fut ) -module Setup_lwt_hooks = struct +module Setup_lwt_hooks (ARG : sig + val st : Scheduler_state.st +end) = +struct + open ARG + + module FG = + WL.Fine_grained + (struct + include Scheduler_state + + let st = st + let ops = Ops.ops + end) + () + let run_in_hook () = (* execute actions sent from other threads; first transfer them all atomically to a local queue to reduce contention *) let local_acts = Queue.create () in - Mutex.lock Scheduler_state.st.mutex; - Queue.transfer Scheduler_state.st.actions_from_other_threads local_acts; - Atomic.set Scheduler_state.st.has_notified false; - Mutex.unlock Scheduler_state.st.mutex; + Mutex.lock st.mutex; + Queue.transfer st.actions_from_other_threads local_acts; + Atomic.set st.has_notified false; + Mutex.unlock st.mutex; Queue.iter (fun f -> f ()) local_acts; (* run tasks *) FG.run ~max_tasks:1000 (); - if not (Queue.is_empty Scheduler_state.st.tasks) then - ignore (Lwt.pause () : unit Lwt.t); + if not (Queue.is_empty st.tasks) then ignore (Lwt.pause () : unit Lwt.t); () - let is_setup_ = Atomic.make false - let[@inline] is_setup () : bool = Atomic.get is_setup_ - let setup () = - if not (Atomic.exchange is_setup_ true) then ( - (* only one thread does this *) - FG.setup ~block_signals:false (); + (* only one thread does this *) + FG.setup ~block_signals:false (); - Scheduler_state.st.thread <- Thread.self () |> Thread.id; - Scheduler_state.st.enter_hook <- - Some (Lwt_main.Enter_iter_hooks.add_last run_in_hook); - Scheduler_state.st.leave_hook <- - Some (Lwt_main.Leave_iter_hooks.add_last run_in_hook); - (* notification used to wake lwt up *) - Scheduler_state.st.notification <- - Lwt_unix.make_notification ~once:false run_in_hook - ) else if not (Ops.on_lwt_thread_ Scheduler_state.st) then - (* sanity check failed *) - failwith "moonpool-lwt.setup: called again on a different thread" + st.thread <- Thread.self () |> Thread.id; + st.enter_hook <- Some (Lwt_main.Enter_iter_hooks.add_last run_in_hook); + st.leave_hook <- Some (Lwt_main.Leave_iter_hooks.add_last run_in_hook); + (* notification used to wake lwt up *) + st.notification <- Lwt_unix.make_notification ~once:false run_in_hook end +let setup () : Scheduler_state.st = + let st = Scheduler_state.create_new () in + Ops.setup st; + let module Setup_lwt_hooks' = Setup_lwt_hooks (struct + let st = st + end) in + Setup_lwt_hooks'.setup (); + st + +let[@inline] is_setup () = Option.is_some @@ Atomic.get Scheduler_state.cur_st + let spawn_lwt f : _ Lwt.t = - if not (Setup_lwt_hooks.is_setup ()) then - failwith "spawn_lwt: scheduler was not setup"; + let st = Main_state.get_st () in let lwt_fut, lwt_prom = Lwt.wait () in - M.Runner.run_async Scheduler_state.st.as_runner (fun () -> + M.Runner.run_async st.as_runner (fun () -> try let x = f () in Lwt.wakeup lwt_prom x @@ -256,11 +298,13 @@ let spawn_lwt f : _ Lwt.t = lwt_fut let lwt_main (f : _ -> 'a) : 'a = - Setup_lwt_hooks.setup (); - let fut = spawn_lwt (fun () -> f Scheduler_state.st.as_runner) in + let st = setup () in + (* make sure to cleanup *) + let finally () = Scheduler_state.cleanup st in + Fun.protect ~finally @@ fun () -> + let fut = spawn_lwt (fun () -> f st.as_runner) in Lwt_main.run fut let[@inline] lwt_main_runner () = - if not (Setup_lwt_hooks.is_setup ()) then - failwith "lwt_main_runner: scheduler was not setup"; - Scheduler_state.st.as_runner + let st = Main_state.get_st () in + st.as_runner diff --git a/src/lwt/moonpool_lwt.mli b/src/lwt/moonpool_lwt.mli index 7696246b..4553d5b7 100644 --- a/src/lwt/moonpool_lwt.mli +++ b/src/lwt/moonpool_lwt.mli @@ -45,3 +45,5 @@ val lwt_main_runner : unit -> Moonpool.Runner.t (** The runner from {!lwt_main}. The runner is only going to work if {!lwt_main} is currently running in some thread. @raise Failure if {!lwt_main} was not called. *) + +val is_setup : unit -> bool