diff --git a/src/lwt/moonpool_lwt.ml b/src/lwt/moonpool_lwt.ml index 1e1086c0..cbb40758 100644 --- a/src/lwt/moonpool_lwt.ml +++ b/src/lwt/moonpool_lwt.ml @@ -21,6 +21,7 @@ module Scheduler_state = struct mutex: Mutex.t; mutable thread: int; closed: bool Atomic.t; + cleanup_done: bool Atomic.t; mutable as_runner: Moonpool.Runner.t; mutable enter_hook: Lwt_main.Enter_iter_hooks.hook option; mutable leave_hook: Lwt_main.Leave_iter_hooks.hook option; @@ -39,6 +40,7 @@ module Scheduler_state = struct mutex = Mutex.create (); thread = Thread.id (Thread.self ()); closed = Atomic.make false; + cleanup_done = Atomic.make false; as_runner = Moonpool.Runner.dummy; enter_hook = None; leave_hook = None; @@ -46,12 +48,15 @@ module Scheduler_state = struct has_notified = Atomic.make false; } + let[@inline] notify_ (self : st) : unit = + if not (Atomic.exchange self.has_notified true) then + Lwt_unix.send_notification self.notification + let[@inline never] add_action_from_another_thread_ (self : st) f : unit = Mutex.lock self.mutex; Queue.push f self.actions_from_other_threads; Mutex.unlock self.mutex; - if not (Atomic.exchange self.has_notified true) then - Lwt_unix.send_notification self.notification + notify_ self let[@inline] on_lwt_thread_ (self : st) : bool = Thread.id (Thread.self ()) = self.thread @@ -71,9 +76,12 @@ module Scheduler_state = struct one!)"; if not (on_lwt_thread_ st) then failwith "moonpool-lwt: cleanup from the wrong thread"; - Option.iter Lwt_main.Enter_iter_hooks.remove st.enter_hook; - Option.iter Lwt_main.Leave_iter_hooks.remove st.leave_hook; - Lwt_unix.stop_notification st.notification; + Atomic.set st.closed true; + if not (Atomic.exchange st.cleanup_done true) then ( + Option.iter Lwt_main.Enter_iter_hooks.remove st.enter_hook; + Option.iter Lwt_main.Leave_iter_hooks.remove st.leave_hook; + Lwt_unix.stop_notification st.notification + ); Atomic.set cur_st None | None -> failwith "moonpool-lwt: cleanup failed (no current active state)" @@ -304,6 +312,8 @@ let lwt_main (f : _ -> 'a) : 'a = let finally () = Scheduler_state.cleanup st in Fun.protect ~finally @@ fun () -> let fut = spawn_lwt (fun () -> f st.as_runner) in + (* make sure the scheduler isn't already sleeping *) + Scheduler_state.notify_ st; Lwt_main.run fut let[@inline] lwt_main_runner () =