core: cleanup, and add a fined grained API for worker loop

This commit is contained in:
Simon Cruanes 2025-07-09 17:24:29 -04:00
parent 1a64e7345e
commit 55e3e77a66
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
4 changed files with 99 additions and 51 deletions

View file

@ -28,7 +28,6 @@ type worker_state = {
let[@inline] size_ (self : state) = Array.length self.threads
let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q
let k_worker_state : worker_state TLS.t = TLS.create ()
(*
get_thread_state = TLS.get_opt k_worker_state
@ -71,12 +70,6 @@ let schedule_w (self : worker_state) (task : task_full) : unit =
let get_next_task (self : worker_state) =
try Bb_queue.pop self.st.q with Bb_queue.Closed -> raise WL.No_more_tasks
let get_thread_state () =
match TLS.get_exn k_worker_state with
| st -> st
| exception TLS.Not_set ->
failwith "Moonpool: get_thread_state called from outside a runner."
let before_start (self : worker_state) =
let t_id = Thread.id @@ Thread.self () in
self.st.on_init_thread ~dom_id:self.dom_idx ~t_id ();
@ -103,7 +96,6 @@ let worker_ops : worker_state WL.ops =
WL.schedule = schedule_w;
runner;
get_next_task;
get_thread_state;
around_task;
on_exn;
before_start;

View file

@ -21,8 +21,6 @@ exception No_more_tasks
type 'st ops = {
schedule: 'st -> task_full -> unit;
get_next_task: 'st -> task_full; (** @raise No_more_tasks *)
get_thread_state: unit -> 'st;
(** Access current thread's worker state from any worker *)
around_task: 'st -> around_task;
on_exn: 'st -> Exn_bt.t -> unit;
runner: 'st -> Runner.t;
@ -98,7 +96,59 @@ let with_handler (type st arg) ~(ops : st ops) (self : st) :
let handler = Effect.Deep.{ retc = Fun.id; exnc = raise_with_bt; effc } in
fun f -> Effect.Deep.match_with f () handler
let worker_loop (type st) ~block_signals ~(ops : st ops) (self : st) : unit =
module type FINE_GRAINED_ARGS = sig
type st
val ops : st ops
val st : st
end
module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct
open Args
let cur_fiber : fiber ref = ref _dummy_fiber
let runner = ops.runner st
type state =
| New
| Ready
| Torn_down
let state = ref New
let run_task (task : task_full) : unit =
let (AT_pair (before_task, after_task)) = ops.around_task st in
let fiber =
match task with
| T_start { fiber; _ } | T_resume { fiber; _ } -> fiber
in
cur_fiber := fiber;
TLS.set k_cur_fiber fiber;
let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *)
assert (task != _dummy_task);
(try
match task with
| T_start { fiber = _; f } -> with_handler ~ops st f
| T_resume { fiber = _; k } ->
(* this is already in an effect handler *)
k ()
with e ->
let bt = Printexc.get_raw_backtrace () in
let ebt = Exn_bt.make e bt in
ops.on_exn st ebt);
after_task runner _ctx;
cur_fiber := _dummy_fiber;
TLS.set k_cur_fiber _dummy_fiber
let setup (type st) ~block_signals () : unit =
if !state <> New then invalid_arg "worker_loop.setup: not a new instance";
state := Ready;
if block_signals then (
try
ignore
@ -116,52 +166,47 @@ let worker_loop (type st) ~block_signals ~(ops : st ops) (self : st) : unit =
with _ -> ()
);
let cur_fiber : fiber ref = ref _dummy_fiber in
let runner = ops.runner self in
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
let (AT_pair (before_task, after_task)) = ops.around_task self in
ops.before_start st
let run_task (task : task_full) : unit =
let fiber =
match task with
| T_start { fiber; _ } | T_resume { fiber; _ } -> fiber
in
cur_fiber := fiber;
TLS.set k_cur_fiber fiber;
let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *)
assert (task != _dummy_task);
(try
match task with
| T_start { fiber = _; f } -> with_handler ~ops self f
| T_resume { fiber = _; k } ->
(* this is already in an effect handler *)
k ()
with e ->
let bt = Printexc.get_raw_backtrace () in
let ebt = Exn_bt.make e bt in
ops.on_exn self ebt);
after_task runner _ctx;
cur_fiber := _dummy_fiber;
TLS.set k_cur_fiber _dummy_fiber
in
ops.before_start self;
let run ?(max_tasks = max_int) () : unit =
if !state <> Ready then invalid_arg "worker_loop.run: not setup";
let continue = ref true in
try
while !continue do
match ops.get_next_task self with
| task -> run_task task
let n_tasks = ref 0 in
while !continue && !n_tasks < max_tasks do
match ops.get_next_task st with
| task ->
incr n_tasks;
run_task task
| exception No_more_tasks -> continue := false
done;
ops.cleanup self
done
let teardown () =
if !state <> Torn_down then (
state := Torn_down;
cur_fiber := _dummy_fiber;
ops.cleanup st
)
end
let worker_loop (type st) ~block_signals ~(ops : st ops) (self : st) : unit =
let module FG =
Fine_grained
(struct
type nonrec st = st
let ops = ops
let st = self
end)
()
in
FG.setup ~block_signals ();
try
FG.run ();
FG.teardown ()
with exn ->
let bt = Printexc.get_raw_backtrace () in
ops.cleanup self;
FG.teardown ();
Printexc.raise_with_backtrace exn bt

View file

@ -26,7 +26,6 @@ exception No_more_tasks
type 'st ops = {
schedule: 'st -> task_full -> unit;
get_next_task: 'st -> task_full;
get_thread_state: unit -> 'st;
around_task: 'st -> around_task;
on_exn: 'st -> Exn_bt.t -> unit;
runner: 'st -> Runner.t;
@ -34,4 +33,23 @@ type 'st ops = {
cleanup: 'st -> unit;
}
module type FINE_GRAINED_ARGS = sig
type st
val ops : st ops
val st : st
end
module Fine_grained (_ : FINE_GRAINED_ARGS) () : sig
val setup : block_signals:bool -> unit -> unit
(** Just initialize the loop *)
val run : ?max_tasks:int -> unit -> unit
(** Run the loop until no task remains or until [max_tasks] tasks have been
run *)
val teardown : unit -> unit
(** Tear down the loop *)
end
val worker_loop : block_signals:bool -> ops:'st ops -> 'st -> unit

View file

@ -62,12 +62,6 @@ let k_worker_state : worker_state TLS.t = TLS.create ()
let[@inline] get_current_worker_ () : worker_state option =
TLS.get_opt k_worker_state
let[@inline] get_current_worker_exn () : worker_state =
match TLS.get_exn k_worker_state with
| w -> w
| exception TLS.Not_set ->
failwith "Moonpool: get_current_runner was called from outside a pool."
(** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit =
if self.n_waiting_nonzero then (
@ -212,7 +206,6 @@ let worker_ops : worker_state WL.ops =
WL.schedule = schedule_from_w;
runner;
get_next_task;
get_thread_state = get_current_worker_exn;
around_task;
on_exn;
before_start;