diff --git a/src/lwt/moonpool_lwt.ml b/src/lwt/moonpool_lwt.ml index 5b2039dd..b563fa52 100644 --- a/src/lwt/moonpool_lwt.ml +++ b/src/lwt/moonpool_lwt.ml @@ -43,7 +43,7 @@ module Scheduler_state = struct has_notified = Atomic.make false; } - let add_action_from_another_thread_ (self : st) f : unit = + let[@inline never] add_action_from_another_thread_ (self : st) f : unit = Mutex.lock st.mutex; Queue.push f self.actions_from_other_threads; Mutex.unlock st.mutex; @@ -56,12 +56,17 @@ module Ops = struct let around_task _ = default_around_task_ - let schedule (self : st) t = - if Thread.id (Thread.self ()) = self.thread then - Queue.push t self.tasks + let[@inline] on_lwt_thread_ (self : st) : bool = + Thread.id (Thread.self ()) = self.thread + + let[@inline] run_on_lwt_thread_ (self : st) (f : unit -> unit) : unit = + if on_lwt_thread_ self then + f () else - Scheduler_state.add_action_from_another_thread_ self (fun () -> - Queue.push t self.tasks) + Scheduler_state.add_action_from_another_thread_ self f + + let schedule (self : st) t = + run_on_lwt_thread_ self (fun () -> Queue.push t self.tasks) let get_next_task (self : st) = if self.closed then raise WL.No_more_tasks; @@ -74,8 +79,10 @@ module Ops = struct Moonpool.Runner.For_runner_implementors.create ~size:(fun () -> 1) ~num_tasks:(fun () -> - (* FIXME: thread safety. use an atomic?? *) - Queue.length self.tasks) + Mutex.lock self.mutex; + let n = Queue.length self.tasks in + Mutex.unlock self.mutex; + n) ~run_async:(fun ~fiber f -> schedule self @@ WL.T_start { fiber; f }) ~shutdown:(fun ~wait:_ () -> self.closed <- true) () @@ -113,20 +120,41 @@ open struct () end -let await_lwt (fut : _ Lwt.t) = +(** Resolve [prom] with the result of [lwt_fut] *) +let transfer_lwt_to_fut (lwt_fut : 'a Lwt.t) (prom : 'a Fut.promise) : unit = + Lwt.on_any lwt_fut + (fun x -> M.Fut.fulfill prom (Ok x)) + (fun exn -> + let bt = Printexc.get_callstack 10 in + M.Fut.fulfill prom (Error (Exn_bt.make exn bt))) + +let[@inline] register_trigger_on_lwt_termination (lwt_fut : _ Lwt.t) + (tr : M.Trigger.t) : unit = + Lwt.on_termination lwt_fut (fun _ -> M.Trigger.signal tr) + +let[@inline] await_lwt_terminated (fut : _ Lwt.t) = match Lwt.state fut with | Return x -> x | Fail exn -> raise exn - | Sleep -> - (* suspend fiber, wake it up when [fut] resolves *) - let trigger = M.Trigger.create () in - Lwt.on_termination fut (fun _ -> M.Trigger.signal trigger); - M.Trigger.await trigger |> Option.iter Exn_bt.raise; + | Sleep -> assert false - (match Lwt.state fut with +let await_lwt (fut : _ Lwt.t) = + if Ops.on_lwt_thread_ Scheduler_state.st then ( + match Lwt.state fut with | Return x -> x | Fail exn -> raise exn - | Sleep -> assert false) + | Sleep -> + let tr = M.Trigger.create () in + register_trigger_on_lwt_termination fut tr; + M.Trigger.await_exn tr; + await_lwt_terminated fut + ) else ( + let tr = M.Trigger.create () in + Scheduler_state.add_action_from_another_thread_ Scheduler_state.st + (fun () -> register_trigger_on_lwt_termination fut tr); + M.Trigger.await_exn tr; + await_lwt_terminated fut + ) let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t = let lwt_fut, lwt_prom = Lwt.wait () in @@ -140,62 +168,83 @@ let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t = in M.Fut.on_result fut (fun res -> - if Thread.id (Thread.self ()) = Scheduler_state.st.thread then - (* can safely wakeup from the lwt thread *) - wakeup_using_res res - else - Scheduler_state.add_action_from_another_thread_ Scheduler_state.st - (fun () -> wakeup_using_res res)); + Ops.run_on_lwt_thread_ Scheduler_state.st (fun () -> + (* can safely wakeup from the lwt thread *) + wakeup_using_res res)); lwt_fut let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t = - match Lwt.poll lwt_fut with - | Some x -> M.Fut.return x - | None -> + if Ops.on_lwt_thread_ Scheduler_state.st then ( + match Lwt.state lwt_fut with + | Return x -> M.Fut.return x + | _ -> + let fut, prom = M.Fut.make () in + transfer_lwt_to_fut lwt_fut prom; + fut + ) else ( let fut, prom = M.Fut.make () in - Lwt.on_any lwt_fut - (fun x -> M.Fut.fulfill prom (Ok x)) - (fun exn -> - let bt = Printexc.get_callstack 10 in - M.Fut.fulfill prom (Error (Exn_bt.make exn bt))); + Scheduler_state.add_action_from_another_thread_ Scheduler_state.st + (fun () -> transfer_lwt_to_fut lwt_fut prom); fut - -let run_in_hook () = - (* execute actions sent from other threads; first transfer them - all atomically to a local queue to reduce contention *) - let local_acts = Queue.create () in - Mutex.lock Scheduler_state.st.mutex; - Queue.transfer Scheduler_state.st.actions_from_other_threads local_acts; - Atomic.set Scheduler_state.st.has_notified false; - Mutex.unlock Scheduler_state.st.mutex; - - Queue.iter (fun f -> f ()) local_acts; - - (* run tasks *) - FG.run ~max_tasks:1000 (); - - if not (Queue.is_empty Scheduler_state.st.tasks) then - ignore (Lwt.pause () : unit Lwt.t); - () - -let is_setup_ = Atomic.make false - -let setup () = - if not (Atomic.exchange is_setup_ true) then ( - (* only one thread does this *) - FG.setup ~block_signals:false (); - Scheduler_state.st.enter_hook <- - Some (Lwt_main.Enter_iter_hooks.add_last run_in_hook); - Scheduler_state.st.leave_hook <- - Some (Lwt_main.Leave_iter_hooks.add_last run_in_hook); - (* notification used to wake lwt up *) - Scheduler_state.st.notification <- - Lwt_unix.make_notification ~once:false run_in_hook ) +let run_in_lwt_and_await (f : unit -> 'a Lwt.t) : 'a = + if Ops.on_lwt_thread_ Scheduler_state.st then ( + let fut = f () in + await_lwt fut + ) else ( + let fut, prom = Fut.make () in + Scheduler_state.add_action_from_another_thread_ Scheduler_state.st + (fun () -> + let lwt_fut = f () in + transfer_lwt_to_fut lwt_fut prom); + Fut.await fut + ) + +module Setup_lwt_hooks = struct + let run_in_hook () = + (* execute actions sent from other threads; first transfer them + all atomically to a local queue to reduce contention *) + let local_acts = Queue.create () in + Mutex.lock Scheduler_state.st.mutex; + Queue.transfer Scheduler_state.st.actions_from_other_threads local_acts; + Atomic.set Scheduler_state.st.has_notified false; + Mutex.unlock Scheduler_state.st.mutex; + + Queue.iter (fun f -> f ()) local_acts; + + (* run tasks *) + FG.run ~max_tasks:1000 (); + + if not (Queue.is_empty Scheduler_state.st.tasks) then + ignore (Lwt.pause () : unit Lwt.t); + () + + let is_setup_ = Atomic.make false + let[@inline] is_setup () : bool = Atomic.get is_setup_ + + let setup () = + if not (Atomic.exchange is_setup_ true) then ( + (* only one thread does this *) + FG.setup ~block_signals:false (); + + Scheduler_state.st.thread <- Thread.self () |> Thread.id; + Scheduler_state.st.enter_hook <- + Some (Lwt_main.Enter_iter_hooks.add_last run_in_hook); + Scheduler_state.st.leave_hook <- + Some (Lwt_main.Leave_iter_hooks.add_last run_in_hook); + (* notification used to wake lwt up *) + Scheduler_state.st.notification <- + Lwt_unix.make_notification ~once:false run_in_hook + ) else if not (Ops.on_lwt_thread_ Scheduler_state.st) then + (* sanity check failed *) + failwith "moonpool-lwt.setup: called again on a different thread" +end + let spawn_lwt f : _ Lwt.t = - setup (); + if not (Setup_lwt_hooks.is_setup ()) then + failwith "spawn_lwt: scheduler was not setup"; let lwt_fut, lwt_prom = Lwt.wait () in M.Runner.run_async Scheduler_state.st.as_runner (fun () -> try @@ -205,11 +254,11 @@ let spawn_lwt f : _ Lwt.t = lwt_fut let lwt_main (f : _ -> 'a) : 'a = - setup (); - Scheduler_state.st.thread <- Thread.self () |> Thread.id; + Setup_lwt_hooks.setup (); let fut = spawn_lwt (fun () -> f Scheduler_state.st.as_runner) in Lwt_main.run fut -let lwt_main_runner () = - if not (Atomic.get is_setup_) then failwith "lwt_main_runner: not setup yet"; +let[@inline] lwt_main_runner () = + if not (Setup_lwt_hooks.is_setup ()) then + failwith "lwt_main_runner: scheduler was not setup"; Scheduler_state.st.as_runner diff --git a/src/lwt/moonpool_lwt.mli b/src/lwt/moonpool_lwt.mli index 79d9e61d..c8ca9485 100644 --- a/src/lwt/moonpool_lwt.mli +++ b/src/lwt/moonpool_lwt.mli @@ -23,13 +23,16 @@ val lwt_of_fut : 'a Moonpool.Fut.t -> 'a Lwt.t (** {2 Helpers on the moonpool side} *) val spawn_lwt : (unit -> 'a) -> 'a Lwt.t -(** This spawns a task that runs in the Lwt scheduler *) +(** This spawns a task that runs in the Lwt scheduler. + @raise Failure if {!lwt_main} was not called. *) val await_lwt : 'a Lwt.t -> 'a (** [await_lwt fut] awaits a Lwt future from inside a task running on a moonpool runner. This must be run from within a Moonpool runner so that the await-ing effect is handled. *) +val run_in_lwt_and_await : (unit -> 'a Lwt.t) -> 'a + (** {2 Wrappers around Lwt_main} *) val on_uncaught_exn : (Moonpool.Exn_bt.t -> unit) ref @@ -39,4 +42,5 @@ val lwt_main : (Moonpool.Runner.t -> 'a) -> 'a val lwt_main_runner : unit -> Moonpool.Runner.t (** The runner from {!lwt_main}. The runner is only going to work if {!lwt_main} - is currently running in some thread. *) + is currently running in some thread. + @raise Failure if {!lwt_main} was not called. *)