diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 38a07d3d..3e70e06b 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -28,7 +28,6 @@ type worker_state = { let[@inline] size_ (self : state) = Array.length self.threads let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q -let k_worker_state : worker_state TLS.t = TLS.create () (* get_thread_state = TLS.get_opt k_worker_state @@ -71,12 +70,6 @@ let schedule_w (self : worker_state) (task : task_full) : unit = let get_next_task (self : worker_state) = try Bb_queue.pop self.st.q with Bb_queue.Closed -> raise WL.No_more_tasks -let get_thread_state () = - match TLS.get_exn k_worker_state with - | st -> st - | exception TLS.Not_set -> - failwith "Moonpool: get_thread_state called from outside a runner." - let before_start (self : worker_state) = let t_id = Thread.id @@ Thread.self () in self.st.on_init_thread ~dom_id:self.dom_idx ~t_id (); @@ -103,7 +96,6 @@ let worker_ops : worker_state WL.ops = WL.schedule = schedule_w; runner; get_next_task; - get_thread_state; around_task; on_exn; before_start; diff --git a/src/core/worker_loop_.ml b/src/core/worker_loop_.ml index bd2cd5ca..c375e8a2 100644 --- a/src/core/worker_loop_.ml +++ b/src/core/worker_loop_.ml @@ -21,8 +21,6 @@ exception No_more_tasks type 'st ops = { schedule: 'st -> task_full -> unit; get_next_task: 'st -> task_full; (** @raise No_more_tasks *) - get_thread_state: unit -> 'st; - (** Access current thread's worker state from any worker *) around_task: 'st -> around_task; on_exn: 'st -> Exn_bt.t -> unit; runner: 'st -> Runner.t; @@ -98,31 +96,28 @@ let with_handler (type st arg) ~(ops : st ops) (self : st) : let handler = Effect.Deep.{ retc = Fun.id; exnc = raise_with_bt; effc } in fun f -> Effect.Deep.match_with f () handler -let worker_loop (type st) ~block_signals ~(ops : st ops) (self : st) : unit = - if block_signals then ( - try - ignore - (Unix.sigprocmask SIG_BLOCK - [ - Sys.sigterm; - Sys.sigpipe; - Sys.sigint; - Sys.sigchld; - Sys.sigalrm; - Sys.sigusr1; - Sys.sigusr2; - ] - : _ list) - with _ -> () - ); +module type FINE_GRAINED_ARGS = sig + type st - let cur_fiber : fiber ref = ref _dummy_fiber in - let runner = ops.runner self in - TLS.set Runner.For_runner_implementors.k_cur_runner runner; + val ops : st ops + val st : st +end - let (AT_pair (before_task, after_task)) = ops.around_task self in +module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct + open Args + + let cur_fiber : fiber ref = ref _dummy_fiber + let runner = ops.runner st + + type state = + | New + | Ready + | Torn_down + + let state = ref New let run_task (task : task_full) : unit = + let (AT_pair (before_task, after_task)) = ops.around_task st in let fiber = match task with | T_start { fiber; _ } | T_resume { fiber; _ } -> fiber @@ -136,32 +131,82 @@ let worker_loop (type st) ~block_signals ~(ops : st ops) (self : st) : unit = assert (task != _dummy_task); (try match task with - | T_start { fiber = _; f } -> with_handler ~ops self f + | T_start { fiber = _; f } -> with_handler ~ops st f | T_resume { fiber = _; k } -> (* this is already in an effect handler *) k () with e -> let bt = Printexc.get_raw_backtrace () in let ebt = Exn_bt.make e bt in - ops.on_exn self ebt); + ops.on_exn st ebt); after_task runner _ctx; cur_fiber := _dummy_fiber; TLS.set k_cur_fiber _dummy_fiber - in - ops.before_start self; + let setup (type st) ~block_signals () : unit = + if !state <> New then invalid_arg "worker_loop.setup: not a new instance"; + state := Ready; - let continue = ref true in - try - while !continue do - match ops.get_next_task self with - | task -> run_task task + if block_signals then ( + try + ignore + (Unix.sigprocmask SIG_BLOCK + [ + Sys.sigterm; + Sys.sigpipe; + Sys.sigint; + Sys.sigchld; + Sys.sigalrm; + Sys.sigusr1; + Sys.sigusr2; + ] + : _ list) + with _ -> () + ); + + TLS.set Runner.For_runner_implementors.k_cur_runner runner; + + ops.before_start st + + let run ?(max_tasks = max_int) () : unit = + if !state <> Ready then invalid_arg "worker_loop.run: not setup"; + + let continue = ref true in + let n_tasks = ref 0 in + while !continue && !n_tasks < max_tasks do + match ops.get_next_task st with + | task -> + incr n_tasks; + run_task task | exception No_more_tasks -> continue := false - done; - ops.cleanup self + done + + let teardown () = + if !state <> Torn_down then ( + state := Torn_down; + cur_fiber := _dummy_fiber; + ops.cleanup st + ) +end + +let worker_loop (type st) ~block_signals ~(ops : st ops) (self : st) : unit = + let module FG = + Fine_grained + (struct + type nonrec st = st + + let ops = ops + let st = self + end) + () + in + FG.setup ~block_signals (); + try + FG.run (); + FG.teardown () with exn -> let bt = Printexc.get_raw_backtrace () in - ops.cleanup self; + FG.teardown (); Printexc.raise_with_backtrace exn bt diff --git a/src/core/worker_loop_.mli b/src/core/worker_loop_.mli index 7098deb8..3041b0dd 100644 --- a/src/core/worker_loop_.mli +++ b/src/core/worker_loop_.mli @@ -26,7 +26,6 @@ exception No_more_tasks type 'st ops = { schedule: 'st -> task_full -> unit; get_next_task: 'st -> task_full; - get_thread_state: unit -> 'st; around_task: 'st -> around_task; on_exn: 'st -> Exn_bt.t -> unit; runner: 'st -> Runner.t; @@ -34,4 +33,23 @@ type 'st ops = { cleanup: 'st -> unit; } +module type FINE_GRAINED_ARGS = sig + type st + + val ops : st ops + val st : st +end + +module Fine_grained (_ : FINE_GRAINED_ARGS) () : sig + val setup : block_signals:bool -> unit -> unit + (** Just initialize the loop *) + + val run : ?max_tasks:int -> unit -> unit + (** Run the loop until no task remains or until [max_tasks] tasks have been + run *) + + val teardown : unit -> unit + (** Tear down the loop *) +end + val worker_loop : block_signals:bool -> ops:'st ops -> 'st -> unit diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 1b95cd16..153f4f06 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -62,12 +62,6 @@ let k_worker_state : worker_state TLS.t = TLS.create () let[@inline] get_current_worker_ () : worker_state option = TLS.get_opt k_worker_state -let[@inline] get_current_worker_exn () : worker_state = - match TLS.get_exn k_worker_state with - | w -> w - | exception TLS.Not_set -> - failwith "Moonpool: get_current_runner was called from outside a pool." - (** Try to wake up a waiter, if there's any. *) let[@inline] try_wake_someone_ (self : state) : unit = if self.n_waiting_nonzero then ( @@ -212,7 +206,6 @@ let worker_ops : worker_state WL.ops = WL.schedule = schedule_from_w; runner; get_next_task; - get_thread_state = get_current_worker_exn; around_task; on_exn; before_start;