From 9fb23bed4cee86b5baaac4d966ae9512e4361090 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 28 Aug 2024 12:39:15 -0400 Subject: [PATCH] refactor core: use picos for schedulers; add Worker_loop_ we factor most of the thread workers' logic in `Worker_loop_`, which is now shared between Ws_pool and Fifo_pool --- src/core/fifo_pool.ml | 193 +++++++------- src/core/fifo_pool.mli | 14 - src/core/moonpool.ml | 2 +- src/core/moonpool.mli | 18 +- src/core/runner.ml | 22 +- src/core/runner.mli | 9 +- src/core/suspend_.ml | 70 ----- src/core/suspend_.mli | 86 ------ src/core/task_local_storage.ml | 2 + src/core/task_local_storage.mli | 2 + src/core/types_.ml | 47 ++-- src/core/worker_loop_.ml | 153 +++++++++++ src/core/ws_pool.ml | 447 +++++++++++++------------------- src/private/ws_deque_.ml | 17 +- src/private/ws_deque_.mli | 4 + 15 files changed, 497 insertions(+), 589 deletions(-) delete mode 100644 src/core/suspend_.ml delete mode 100644 src/core/suspend_.mli create mode 100644 src/core/worker_loop_.ml diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 7c1b491b..1fe4b708 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -1,87 +1,39 @@ open Types_ include Runner +module WL = Worker_loop_ + +type fiber = Picos.Fiber.t +type task_full = WL.task_full let ( let@ ) = ( @@ ) -type task_full = - | T_start of { - ls: Task_local_storage.t; - f: task; - } - | T_resume : { - ls: Task_local_storage.t; - k: 'a -> unit; - x: 'a; - } - -> task_full - type state = { threads: Thread.t array; q: task_full Bb_queue.t; (** Queue for tasks. *) + around_task: WL.around_task; + as_runner: t lazy_t; + (* init options *) + name: string option; + on_init_thread: dom_id:int -> t_id:int -> unit -> unit; + on_exit_thread: dom_id:int -> t_id:int -> unit -> unit; + on_exn: exn -> Printexc.raw_backtrace -> unit; } (** internal state *) +type worker_state = { + idx: int; + dom_idx: int; + st: state; + mutable current: fiber; +} + let[@inline] size_ (self : state) = Array.length self.threads let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q - -(** Run [task] as is, on the pool. 
*) -let schedule_ (self : state) (task : task_full) : unit = - try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown - -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task -type worker_state = { mutable cur_ls: Task_local_storage.t option } - let k_worker_state : worker_state TLS.t = TLS.create () -let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = - let w = { cur_ls = None } in - TLS.set k_worker_state w; - TLS.set Runner.For_runner_implementors.k_cur_runner runner; - - let (AT_pair (before_task, after_task)) = around_task in - - let on_suspend () = - match TLS.get_opt k_worker_state with - | Some { cur_ls = Some ls; _ } -> ls - | _ -> assert false - in - let run_another_task ls task' = schedule_ self @@ T_start { f = task'; ls } in - let resume ls k res = schedule_ self @@ T_resume { ls; k; x = res } in - - let run_task (task : task_full) : unit = - let ls = - match task with - | T_start { ls; _ } | T_resume { ls; _ } -> ls - in - w.cur_ls <- Some ls; - TLS.set k_cur_storage ls; - let _ctx = before_task runner in - - (* run the task now, catching errors, handling effects *) - (try - match task with - | T_start { f = task; _ } -> - (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend - (WSH { on_suspend; run = run_another_task; resume }) - task - | T_resume { k; x; _ } -> - (* this is already in an effect handler *) - k x - with e -> - let bt = Printexc.get_raw_backtrace () in - on_exn e bt); - after_task runner _ctx; - w.cur_ls <- None; - TLS.set k_cur_storage _dummy_ls - in - - let continue = ref true in - while !continue do - match Bb_queue.pop self.q with - | task -> run_task task - | exception Bb_queue.Closed -> continue := false - done +(* +get_thread_state = TLS.get_opt k_worker_state + *) let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () @@ -98,10 +50,14 @@ type ('a, 'b) create_args = ?name:string -> 'a -let default_around_task_ : around_task = AT_pair (ignore, fun _ _ -> ()) +let default_around_task_ : WL.around_task = AT_pair (ignore, fun _ _ -> ()) + +(** Run [task] as is, on the pool. *) +let schedule_ (self : state) (task : task_full) : unit = + try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown let runner_of_state (pool : state) : t = - let run_async ~ls f = schedule_ pool @@ T_start { f; ls } in + let run_async ~fiber f = schedule_ pool @@ T_start { f; fiber } in Runner.For_runner_implementors.create ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) ~run_async @@ -109,13 +65,59 @@ let runner_of_state (pool : state) : t = ~num_tasks:(fun () -> num_tasks_ pool) () +(** Run [task] as is, on the pool. *) +let schedule_w (self : worker_state) (task : task_full) : unit = + try Bb_queue.push self.st.q task with Bb_queue.Closed -> raise Shutdown + +let get_next_task (self : worker_state) = + try Bb_queue.pop self.st.q with Bb_queue.Closed -> raise WL.No_more_tasks + +let get_thread_state () = + match TLS.get_exn k_worker_state with + | st -> st + | exception TLS.Not_set -> + failwith "Moonpool: get_thread_state called from outside a runner." 
+ +let before_start (self : worker_state) = + let t_id = Thread.id @@ Thread.self () in + self.st.on_init_thread ~dom_id:self.dom_idx ~t_id (); + + (* set thread name *) + Option.iter + (fun name -> + Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx)) + self.st.name + +let cleanup (self : worker_state) : unit = + (* on termination, decrease refcount of underlying domain *) + Domain_pool_.decr_on self.dom_idx; + let t_id = Thread.id @@ Thread.self () in + self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id () + +let worker_ops : worker_state WL.ops = + let runner (st : worker_state) = Lazy.force st.st.as_runner in + let around_task st = st.st.around_task in + let on_exn (st : worker_state) (ebt : Exn_bt.t) = + st.st.on_exn ebt.exn ebt.bt + in + { + WL.schedule = schedule_w; + runner; + get_next_task; + get_thread_state; + around_task; + on_exn; + before_start; + cleanup; + } + let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?around_task ?num_threads ?name () : t = (* wrapper *) let around_task = match around_task with - | Some (f, g) -> AT_pair (f, g) + | Some (f, g) -> WL.AT_pair (f, g) | None -> default_around_task_ in @@ -127,9 +129,18 @@ let create ?(on_init_thread = default_thread_init_exit_) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in - let pool = + let rec pool = let dummy_thread = Thread.self () in - { threads = Array.make num_threads dummy_thread; q = Bb_queue.create () } + { + threads = Array.make num_threads dummy_thread; + q = Bb_queue.create (); + around_task; + as_runner = lazy (runner_of_state pool); + name; + on_init_thread; + on_exit_thread; + on_exn; + } in let runner = runner_of_state pool in @@ -142,31 +153,11 @@ let create ?(on_init_thread = default_thread_init_exit_) let start_thread_with_idx i = let dom_idx = (offset + i) mod num_domains in - (* function run in the thread itself *) - let main_thread_fun () : unit = - let thread = Thread.self () in - let t_id = Thread.id thread in - on_init_thread ~dom_id:dom_idx ~t_id (); - - (* set thread name *) - Option.iter - (fun name -> - Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i)) - name; - - let run () = worker_thread_ pool runner ~on_exn ~around_task in - - (* now run the main loop *) - Fun.protect run ~finally:(fun () -> - (* on termination, decrease refcount of underlying domain *) - Domain_pool_.decr_on dom_idx); - on_exit_thread ~dom_id:dom_idx ~t_id () - in - (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = - let thread = Thread.create main_thread_fun () in + let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in + let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in (* send the thread from the domain back to us *) Bb_queue.push receive_threads (i, thread) in @@ -196,13 +187,3 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in f pool - -module Private_ = struct - type nonrec state = state - - let create_state ~threads () : state = { threads; q = Bb_queue.create () } - let runner_of_state = runner_of_state - - let run_thread (st : state) (self : t) ~on_exn : unit = - worker_thread_ st self ~on_exn ~around_task:default_around_task_ -end diff --git a/src/core/fifo_pool.mli b/src/core/fifo_pool.mli index d7d103cf..637586a9 100644 --- 
a/src/core/fifo_pool.mli +++ b/src/core/fifo_pool.mli @@ -44,17 +44,3 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args When [f pool] returns or fails, [pool] is shutdown and its resources are released. Most parameters are the same as in {!create}. *) - -(**/**) - -module Private_ : sig - type state - - val create_state : threads:Thread.t array -> unit -> state - val runner_of_state : state -> Runner.t - - val run_thread : - state -> t -> on_exn:(exn -> Printexc.raw_backtrace -> unit) -> unit -end - -(**/**) diff --git a/src/core/moonpool.ml b/src/core/moonpool.ml index 60edc833..cf91f3c3 100644 --- a/src/core/moonpool.ml +++ b/src/core/moonpool.ml @@ -36,7 +36,7 @@ module Ws_pool = Ws_pool module Private = struct module Ws_deque_ = Ws_deque_ - module Suspend_ = Suspend_ + module Worker_loop_ = Worker_loop_ module Domain_ = Domain_ module Tracing_ = Tracing_ diff --git a/src/core/moonpool.mli b/src/core/moonpool.mli index d6abc764..d4243491 100644 --- a/src/core/moonpool.mli +++ b/src/core/moonpool.mli @@ -33,13 +33,13 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t to run the thread. This ensures that we don't always pick the same domain to run all the various threads needed in an application (timers, event loops, etc.) *) -val run_async : ?ls:Task_local_storage.t -> Runner.t -> (unit -> unit) -> unit +val run_async : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> unit) -> unit (** [run_async runner task] schedules the task to run on the given runner. This means [task()] will be executed at some point in the future, possibly in another thread. @since 0.5 *) -val run_wait_block : ?ls:Task_local_storage.t -> Runner.t -> (unit -> 'a) -> 'a +val run_wait_block : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> 'a) -> 'a (** [run_wait_block runner f] schedules [f] for later execution on the runner, like {!run_async}. It then blocks the current thread until [f()] is done executing, @@ -212,16 +212,10 @@ module Private : sig module Ws_deque_ = Ws_deque_ (** A deque for work stealing, fixed size. *) - (** {2 Suspensions} *) - - module Suspend_ = Suspend_ - [@@alert - unstable "this module is an implementation detail of moonpool for now"] - (** Suspensions. - - This is only going to work on OCaml 5.x. - - {b NOTE}: this is not stable for now. *) + module Worker_loop_ = Worker_loop_ + (** Worker loop. This is useful to implement custom runners, it + should run on each thread of the runner. 
+ @since NEXT_RELEASE *) module Domain_ = Domain_ (** Utils for domains *) diff --git a/src/core/runner.ml b/src/core/runner.ml index f5d8c307..a95de289 100644 --- a/src/core/runner.ml +++ b/src/core/runner.ml @@ -1,9 +1,10 @@ open Types_ +type fiber = Picos.Fiber.t type task = unit -> unit type t = runner = { - run_async: ls:local_storage -> task -> unit; + run_async: fiber:fiber -> task -> unit; shutdown: wait:bool -> unit -> unit; size: unit -> int; num_tasks: unit -> int; @@ -11,8 +12,15 @@ type t = runner = { exception Shutdown -let[@inline] run_async ?(ls = create_local_storage ()) (self : t) f : unit = - self.run_async ~ls f +let[@inline] run_async ?fiber (self : t) f : unit = + let fiber = + match fiber with + | Some f -> f + | None -> + let comp = Picos.Computation.create () in + Picos.Fiber.create ~forbid:false comp + in + self.run_async ~fiber f let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true () @@ -22,9 +30,9 @@ let[@inline] shutdown_without_waiting (self : t) : unit = let[@inline] num_tasks (self : t) : int = self.num_tasks () let[@inline] size (self : t) : int = self.size () -let run_wait_block ?ls self (f : unit -> 'a) : 'a = +let run_wait_block ?fiber self (f : unit -> 'a) : 'a = let q = Bb_queue.create () in - run_async ?ls self (fun () -> + run_async ?fiber self (fun () -> try let x = f () in Bb_queue.push q (Ok x) @@ -47,9 +55,9 @@ let dummy : t = ~size:(fun () -> 0) ~num_tasks:(fun () -> 0) ~shutdown:(fun ~wait:_ () -> ()) - ~run_async:(fun ~ls:_ _ -> + ~run_async:(fun ~fiber:_ _ -> failwith "Runner.dummy: cannot actually run tasks") () let get_current_runner = get_current_runner -let get_current_storage = get_current_storage +let get_current_fiber = get_current_fiber diff --git a/src/core/runner.mli b/src/core/runner.mli index f454f598..d8c0ea0a 100644 --- a/src/core/runner.mli +++ b/src/core/runner.mli @@ -5,6 +5,7 @@ @since 0.3 *) +type fiber = Picos.Fiber.t type task = unit -> unit type t @@ -33,14 +34,14 @@ val shutdown_without_waiting : t -> unit exception Shutdown -val run_async : ?ls:Task_local_storage.t -> t -> task -> unit +val run_async : ?fiber:fiber -> t -> task -> unit (** [run_async pool f] schedules [f] for later execution on the runner in one of the threads. [f()] will run on one of the runner's worker threads/domains. @param ls if provided, run the task with this initial local storage @raise Shutdown if the runner was shut down before [run_async] was called. *) -val run_wait_block : ?ls:Task_local_storage.t -> t -> (unit -> 'a) -> 'a +val run_wait_block : ?fiber:fiber -> t -> (unit -> 'a) -> 'a (** [run_wait_block pool f] schedules [f] for later execution on the pool, like {!run_async}. It then blocks the current thread until [f()] is done executing, @@ -65,7 +66,7 @@ module For_runner_implementors : sig size:(unit -> int) -> num_tasks:(unit -> int) -> shutdown:(wait:bool -> unit -> unit) -> - run_async:(ls:Task_local_storage.t -> task -> unit) -> + run_async:(fiber:fiber -> task -> unit) -> unit -> t (** Create a new runner. @@ -85,6 +86,6 @@ val get_current_runner : unit -> t option happens on a thread that belongs in a runner. @since 0.5 *) -val get_current_storage : unit -> Task_local_storage.t option +val get_current_fiber : unit -> fiber option (** [get_current_storage runner] gets the local storage for the currently running task. 
*) diff --git a/src/core/suspend_.ml b/src/core/suspend_.ml deleted file mode 100644 index 0d62e6fb..00000000 --- a/src/core/suspend_.ml +++ /dev/null @@ -1,70 +0,0 @@ -type suspension = unit Exn_bt.result -> unit -type task = unit -> unit - -type suspension_handler = { - handle: - run:(task -> unit) -> - resume:(suspension -> unit Exn_bt.result -> unit) -> - suspension -> - unit; -} -[@@unboxed] - -type with_suspend_handler = - | WSH : { - on_suspend: unit -> 'state; - (** on_suspend called when [f()] suspends itself. *) - run: 'state -> task -> unit; (** run used to schedule new tasks *) - resume: 'state -> suspension -> unit Exn_bt.result -> unit; - (** resume run the suspension. Must be called exactly once. *) - } - -> with_suspend_handler - -[@@@ifge 5.0] -[@@@ocaml.alert "-unstable"] - -module A = Atomic_ - -type _ Effect.t += - | Suspend : suspension_handler -> unit Effect.t - | Yield : unit Effect.t - -let[@inline] yield () = Effect.perform Yield -let[@inline] suspend h = Effect.perform (Suspend h) - -let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit = - let module E = Effect.Deep in - (* effect handler *) - let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = - function - | Suspend h -> - (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) - Some - (fun k -> - let state = on_suspend () in - let k' : suspension = function - | Ok () -> E.continue k () - | Error ebt -> Exn_bt.discontinue k ebt - in - h.handle ~run:(run state) ~resume:(resume state) k') - | Yield -> - (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) - Some - (fun k -> - let state = on_suspend () in - let k' : suspension = function - | Ok () -> E.continue k () - | Error ebt -> Exn_bt.discontinue k ebt - in - resume state k' @@ Ok ()) - | _ -> None - in - - E.try_with f () { E.effc } - -[@@@ocaml.alert "+unstable"] -[@@@else_] - -let[@inline] with_suspend (WSH _) f = f () - -[@@@endif] diff --git a/src/core/suspend_.mli b/src/core/suspend_.mli deleted file mode 100644 index 7a71b36d..00000000 --- a/src/core/suspend_.mli +++ /dev/null @@ -1,86 +0,0 @@ -(** (Private) suspending tasks using Effects. - - This module is an implementation detail of Moonpool and should - not be used outside of it, except by experts to implement {!Runner}. *) - -type suspension = unit Exn_bt.result -> unit -(** A suspended computation *) - -type task = unit -> unit - -type suspension_handler = { - handle: - run:(task -> unit) -> - resume:(suspension -> unit Exn_bt.result -> unit) -> - suspension -> - unit; -} -[@@unboxed] -(** The handler that knows what to do with the suspended computation. - - The handler is given a few things: - - - the suspended computation (which can be resumed with a result - eventually); - - a [run] function that can be used to start tasks to perform some - computation. - - a [resume] function to resume the suspended computation. This - must be called exactly once, in all situations. - - This means that a fork-join primitive, for example, can use a single call - to {!suspend} to: - - suspend the caller until the fork-join is done - - use [run] to start all the tasks. Typically [run] is called multiple times, - which is where the "fork" part comes from. Each call to [run] potentially - runs in parallel with the other calls. The calls must coordinate so - that, once they are all done, the suspended caller is resumed with the - aggregated result of the computation. 
- - use [resume] exactly -*) - -[@@@ifge 5.0] -[@@@ocaml.alert "-unstable"] - -type _ Effect.t += - | Suspend : suspension_handler -> unit Effect.t - (** The effect used to suspend the current thread and pass it, suspended, - to the handler. The handler will ensure that the suspension is resumed later - once some computation has been done. *) - | Yield : unit Effect.t - (** The effect used to interrupt the current computation and immediately re-schedule - it on the same runner. *) - -[@@@ocaml.alert "+unstable"] - -val yield : unit -> unit -(** Interrupt current computation, and re-schedule it at the end of the - runner's job queue. *) - -val suspend : suspension_handler -> unit -(** [suspend h] jumps back to the nearest {!with_suspend} - and calls [h.handle] with the current continuation [k] - and a task runner function. -*) - -[@@@endif] - -type with_suspend_handler = - | WSH : { - on_suspend: unit -> 'state; - (** on_suspend called when [f()] suspends itself. *) - run: 'state -> task -> unit; (** run used to schedule new tasks *) - resume: 'state -> suspension -> unit Exn_bt.result -> unit; - (** resume run the suspension. Must be called exactly once. *) - } - -> with_suspend_handler - -val with_suspend : with_suspend_handler -> (unit -> unit) -> unit -(** [with_suspend wsh f] - runs [f()] in an environment where [suspend] will work. - - If [f()] suspends with suspension handler [h], - this calls [wsh.on_suspend()] to capture the current state [st]. - Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension. - The suspension should always be passed exactly once to - [resume]. [run] should be used to start other tasks. -*) diff --git a/src/core/task_local_storage.ml b/src/core/task_local_storage.ml index a1266304..f9dd98e6 100644 --- a/src/core/task_local_storage.ml +++ b/src/core/task_local_storage.ml @@ -1,4 +1,5 @@ open Types_ +(* module A = Atomic_ type 'a key = 'a ls_key @@ -79,3 +80,4 @@ let with_value key x f = Fun.protect ~finally:(fun () -> set key old) f let get_current = get_current_storage +*) diff --git a/src/core/task_local_storage.mli b/src/core/task_local_storage.mli index a1da0b0f..4fad8e0e 100644 --- a/src/core/task_local_storage.mli +++ b/src/core/task_local_storage.mli @@ -8,6 +8,7 @@ @since 0.6 *) +(* type t = Types_.local_storage (** Underlying storage for a task. This is mutable and not thread-safe. *) @@ -65,3 +66,4 @@ module Direct : sig val create : unit -> t val copy : t -> t end +*) diff --git a/src/core/types_.ml b/src/core/types_.ml index 141be6dd..20133108 100644 --- a/src/core/types_.ml +++ b/src/core/types_.ml @@ -1,36 +1,37 @@ module TLS = Thread_local_storage module Domain_pool_ = Moonpool_dpool -(* TODO: replace with Picos.Fiber.FLS *) -type ls_value = .. 
- -(** Key for task local storage *) -module type LS_KEY = sig - type t - type ls_value += V of t - - val offset : int - (** Unique offset *) - - val init : unit -> t -end - -type 'a ls_key = (module LS_KEY with type t = 'a) -(** A LS key (task local storage) *) - type task = unit -> unit -type local_storage = ls_value array ref +type fiber = Picos.Fiber.t type runner = { - run_async: ls:local_storage -> task -> unit; + run_async: fiber:fiber -> task -> unit; shutdown: wait:bool -> unit -> unit; size: unit -> int; num_tasks: unit -> int; } let k_cur_runner : runner TLS.t = TLS.create () -let k_cur_storage : local_storage TLS.t = TLS.create () -let _dummy_ls : local_storage = ref [||] +[@@alert todo "remove me asap, done via picos now"] + +let k_cur_fiber : fiber TLS.t = TLS.create () +[@@alert todo "remove me asap, done via picos now"] + +let _dummy_computation : Picos.Computation.packed = + let c = Picos.Computation.create () in + Picos.Computation.cancel c + { exn = Failure "dummy fiber"; bt = Printexc.get_callstack 0 }; + Picos.Computation.Packed c + +let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner -let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage -let[@inline] create_local_storage () = ref [||] + +let[@inline] get_current_fiber () : fiber option = + match TLS.get_exn k_cur_fiber with + | f when f != _dummy_fiber -> Some f + | _ -> None + +let[@inline] get_current_fiber_exn () : fiber = + match TLS.get_exn k_cur_fiber with + | f when f != _dummy_fiber -> f + | _ -> failwith "Moonpool: get_current_fiber was called outside of a fiber." diff --git a/src/core/worker_loop_.ml b/src/core/worker_loop_.ml new file mode 100644 index 00000000..df99e169 --- /dev/null +++ b/src/core/worker_loop_.ml @@ -0,0 +1,153 @@ +open Types_ + +type fiber = Picos.Fiber.t + +type task_full = + | T_start of { + fiber: fiber; + f: unit -> unit; + } + | T_resume : { + fiber: fiber; + k: unit -> unit; + } + -> task_full + +type around_task = + | AT_pair : (Runner.t -> 'a) * (Runner.t -> 'a -> unit) -> around_task + +exception No_more_tasks + +type 'st ops = { + schedule: 'st -> task_full -> unit; + get_next_task: 'st -> task_full; (** @raise No_more_tasks *) + get_thread_state: unit -> 'st; + (** Access current thread's worker state from any worker *) + around_task: 'st -> around_task; + on_exn: 'st -> Exn_bt.t -> unit; + runner: 'st -> Runner.t; + before_start: 'st -> unit; + cleanup: 'st -> unit; +} + +(** A dummy task. 
*) +let _dummy_task : task_full = T_start { f = ignore; fiber = _dummy_fiber } + +[@@@ifge 5.0] + +let[@inline] discontinue k exn = + let bt = Printexc.get_raw_backtrace () in + Effect.Deep.discontinue_with_backtrace k exn bt + +let with_handler (type st arg) ~(ops : st ops) (self : st) : + (unit -> unit) -> unit = + let current = + Some + (fun k -> + match get_current_fiber_exn () with + | fiber -> Effect.Deep.continue k fiber + | exception exn -> discontinue k exn) + and yield = + Some + (fun k -> + let fiber = get_current_fiber_exn () in + match + let k () = Effect.Deep.continue k () in + ops.schedule self @@ T_resume { fiber; k } + with + | () -> () + | exception exn -> discontinue k exn) + and reschedule trigger fiber k : unit = + ignore (Picos.Fiber.unsuspend fiber trigger : bool); + let k () = Picos.Fiber.resume fiber k in + let task = T_resume { fiber; k } in + ops.schedule self task + in + let effc (type a) : + a Effect.t -> ((a, _) Effect.Deep.continuation -> _) option = function + | Picos.Fiber.Current -> current + | Picos.Fiber.Yield -> yield + | Picos.Fiber.Spawn r -> + Some + (fun k -> + match + let f () = r.main r.fiber in + let task = T_start { fiber = r.fiber; f } in + ops.schedule self task + with + | unit -> Effect.Deep.continue k unit + | exception exn -> discontinue k exn) + | Picos.Trigger.Await trigger -> + Some + (fun k -> + let fiber = get_current_fiber_exn () in + (* when triggers is signaled, reschedule task *) + if not (Picos.Fiber.try_suspend fiber trigger fiber k reschedule) then + (* trigger was already signaled, run task now *) + Picos.Fiber.resume fiber k) + | Picos.Computation.Cancel_after _r -> + Some + (fun k -> + (* not implemented *) + let exn = Failure "Moonpool: cancel_after is not implemented" in + discontinue k exn) + | _ -> None + in + let handler = Effect.Deep.{ retc = Fun.id; exnc = raise; effc } in + fun f -> Effect.Deep.match_with f () handler + +[@@@else_] + +let with_handler ~ops:_ self f = f () + +[@@@endif] + +let worker_loop (type st) ~(ops : st ops) (self : st) : unit = + let cur_fiber : fiber ref = ref _dummy_fiber in + let runner = ops.runner self in + TLS.set Runner.For_runner_implementors.k_cur_runner runner; + + let (AT_pair (before_task, after_task)) = ops.around_task self in + + let run_task (task : task_full) : unit = + let fiber = + match task with + | T_start { fiber; _ } | T_resume { fiber; _ } -> fiber + in + + cur_fiber := fiber; + TLS.set k_cur_fiber fiber; + let _ctx = before_task runner in + + (* run the task now, catching errors, handling effects *) + assert (task != _dummy_task); + (try + match task with + | T_start { fiber = _; f } -> with_handler ~ops self f + | T_resume { fiber = _; k } -> + (* this is already in an effect handler *) + k () + with e -> + let ebt = Exn_bt.get e in + ops.on_exn self ebt); + + after_task runner _ctx; + + cur_fiber := _dummy_fiber; + TLS.set k_cur_fiber _dummy_fiber + in + + ops.before_start self; + + let continue = ref true in + try + while !continue do + match ops.get_next_task self with + | task -> run_task task + | exception No_more_tasks -> continue := false + done; + ops.cleanup self + with exn -> + let bt = Printexc.get_raw_backtrace () in + ops.cleanup self; + Printexc.raise_with_backtrace exn bt diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 13de7a00..ca31ef8a 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -1,6 +1,7 @@ open Types_ -module WSQ = Ws_deque_ module A = Atomic_ +module WSQ = Ws_deque_ +module WL = Worker_loop_ include Runner let ( let@ ) 
= ( @@ ) @@ -13,46 +14,39 @@ module Id = struct let equal : t -> t -> bool = ( == ) end -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task - -type task_full = - | T_start of { - ls: Task_local_storage.t; - f: task; - } - | T_resume : { - ls: Task_local_storage.t; - k: 'a -> unit; - x: 'a; - } - -> task_full - -type worker_state = { - pool_id_: Id.t; (** Unique per pool *) - mutable thread: Thread.t; - q: task_full WSQ.t; (** Work stealing queue *) - mutable cur_ls: Task_local_storage.t option; (** Task storage *) - rng: Random.State.t; -} -(** State for a given worker. Only this worker is - allowed to push into the queue, but other workers - can come and steal from it if they're idle. *) - type state = { id_: Id.t; + (** Unique per pool. Used to make sure tasks stay within the same pool. *) active: bool A.t; (** Becomes [false] when the pool is shutdown. *) - workers: worker_state array; (** Fixed set of workers. *) - main_q: task_full Queue.t; + mutable workers: worker_state array; (** Fixed set of workers. *) + main_q: WL.task_full Queue.t; (** Main queue for tasks coming from the outside *) mutable n_waiting: int; (* protected by mutex *) mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) mutex: Mutex.t; cond: Condition.t; + as_runner: t lazy_t; + (* init options *) + around_task: WL.around_task; + name: string option; + on_init_thread: dom_id:int -> t_id:int -> unit -> unit; + on_exit_thread: dom_id:int -> t_id:int -> unit -> unit; on_exn: exn -> Printexc.raw_backtrace -> unit; - around_task: around_task; } (** internal state *) +and worker_state = { + mutable thread: Thread.t; + idx: int; + dom_id: int; + st: state; + q: WL.task_full WSQ.t; (** Work stealing queue *) + rng: Random.State.t; +} +(** State for a given worker. Only this worker is + allowed to push into the queue, but other workers + can come and steal from it if they're idle. *) + let[@inline] size_ (self : state) = Array.length self.workers let num_tasks_ (self : state) : int = @@ -66,9 +60,15 @@ let num_tasks_ (self : state) : int = sub-tasks. *) let k_worker_state : worker_state TLS.t = TLS.create () -let[@inline] find_current_worker_ () : worker_state option = +let[@inline] get_current_worker_ () : worker_state option = TLS.get_opt k_worker_state +let[@inline] get_current_worker_exn () : worker_state = + match TLS.get_exn k_worker_state with + | w -> w + | exception TLS.Not_set -> + failwith "Moonpool: get_current_runner was called from outside a pool." + (** Try to wake up a waiter, if there's any. *) let[@inline] try_wake_someone_ (self : state) : unit = if self.n_waiting_nonzero then ( @@ -77,194 +77,144 @@ let[@inline] try_wake_someone_ (self : state) : unit = Mutex.unlock self.mutex ) -(** Run [task] as is, on the pool. *) -let schedule_task_ (self : state) ~w (task : task_full) : unit = - (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) - match w with - | Some w when Id.equal self.id_ w.pool_id_ -> - (* we're on this same pool, schedule in the worker's state. Otherwise - we might also be on pool A but asking to schedule on pool B, - so we have to check that identifiers match. *) - let pushed = WSQ.push w.q task in - if pushed then - try_wake_someone_ self - else ( - (* overflow into main queue *) - Mutex.lock self.mutex; - Queue.push task self.main_q; - if self.n_waiting_nonzero then Condition.signal self.cond; - Mutex.unlock self.mutex +let schedule_on_w (self : worker_state) task : unit = + (* we're on this same pool, schedule in the worker's state. 
Otherwise + we might also be on pool A but asking to schedule on pool B, + so we have to check that identifiers match. *) + let pushed = WSQ.push self.q task in + if pushed then + try_wake_someone_ self.st + else ( + (* overflow into main queue *) + Mutex.lock self.st.mutex; + Queue.push task self.st.main_q; + if self.st.n_waiting_nonzero then Condition.signal self.st.cond; + Mutex.unlock self.st.mutex + ) + +let schedule_on_main (self : state) task : unit = + if A.get self.active then ( + (* push into the main queue *) + Mutex.lock self.mutex; + Queue.push task self.main_q; + if self.n_waiting_nonzero then Condition.signal self.cond; + Mutex.unlock self.mutex + ) else + (* notify the caller that scheduling tasks is no + longer permitted *) + raise Shutdown + +let schedule_from_w (self : worker_state) (task : WL.task_full) : unit = + match get_current_worker_ () with + | Some w when Id.equal self.st.id_ w.st.id_ -> + (* use worker from the same pool *) + schedule_on_w w task + | _ -> schedule_on_main self.st task + +exception Got_task of WL.task_full + +(** Try to steal a task. + @raise Got_task if it finds one. *) +let try_to_steal_work_once_ (self : worker_state) : unit = + let init = Random.State.int self.rng (Array.length self.st.workers) in + for i = 0 to Array.length self.st.workers - 1 do + let w' = + Array.unsafe_get self.st.workers + ((i + init) mod Array.length self.st.workers) + in + + if self != w' then ( + match WSQ.steal w'.q with + | Some t -> raise_notrace (Got_task t) + | None -> () ) - | _ -> - if A.get self.active then ( - (* push into the main queue *) - Mutex.lock self.mutex; - Queue.push task self.main_q; - if self.n_waiting_nonzero then Condition.signal self.cond; - Mutex.unlock self.mutex - ) else - (* notify the caller that scheduling tasks is no - longer permitted *) - raise Shutdown - -(** Run this task, now. Must be called from a worker. *) -let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) - : unit = - (* Printf.printf "run task now (%d)\n%!" 
(Thread.id @@ Thread.self ()); *) - let (AT_pair (before_task, after_task)) = self.around_task in - - let ls = - match task with - | T_start { ls; _ } | T_resume { ls; _ } -> ls - in - - w.cur_ls <- Some ls; - TLS.set k_cur_storage ls; - let _ctx = before_task runner in - - let[@inline] on_suspend () : _ ref = - match find_current_worker_ () with - | Some { cur_ls = Some w; _ } -> w - | _ -> assert false - in - - let run_another_task ls (task' : task) = - let w = - match find_current_worker_ () with - | Some w when Id.equal w.pool_id_ self.id_ -> Some w - | _ -> None - in - let ls' = Task_local_storage.Direct.copy ls in - schedule_task_ self ~w @@ T_start { ls = ls'; f = task' } - in - - let resume ls k x = - let w = - match find_current_worker_ () with - | Some w when Id.equal w.pool_id_ self.id_ -> Some w - | _ -> None - in - schedule_task_ self ~w @@ T_resume { ls; k; x } - in - - (* run the task now, catching errors *) - (try - match task with - | T_start { f = task; _ } -> - (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend - (WSH { on_suspend; run = run_another_task; resume }) - task - | T_resume { k; x; _ } -> - (* this is already in an effect handler *) - k x - with e -> - let bt = Printexc.get_raw_backtrace () in - self.on_exn e bt); - - after_task runner _ctx; - w.cur_ls <- None; - TLS.set k_cur_storage _dummy_ls - -let run_async_ (self : state) ~ls (f : task) : unit = - let w = find_current_worker_ () in - schedule_task_ self ~w @@ T_start { f; ls } - -(* TODO: function to schedule many tasks from the outside. - - build a queue - - lock - - queue transfer - - wakeup all (broadcast) - - unlock *) + done (** Wait on condition. Precondition: we hold the mutex. *) -let[@inline] wait_ (self : state) : unit = +let[@inline] wait_for_condition_ (self : state) : unit = self.n_waiting <- self.n_waiting + 1; if self.n_waiting = 1 then self.n_waiting_nonzero <- true; Condition.wait self.cond self.mutex; self.n_waiting <- self.n_waiting - 1; if self.n_waiting = 0 then self.n_waiting_nonzero <- false -exception Got_task of task_full +let rec get_next_task (self : worker_state) : WL.task_full = + if not (A.get self.st.active) then raise WL.No_more_tasks; + match WSQ.pop_exn self.q with + | task -> + try_wake_someone_ self.st; + task + | exception WSQ.Empty -> try_steal_from_other_workers_ self -(** Try to steal a task *) -let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option - = - let init = Random.State.int w.rng (Array.length self.workers) in +and try_steal_from_other_workers_ (self : worker_state) = + match try_to_steal_work_once_ self with + | exception Got_task task -> task + | () -> wait_on_worker self - try - for i = 0 to Array.length self.workers - 1 do - let w' = - Array.unsafe_get self.workers ((i + init) mod Array.length self.workers) - in +and wait_on_worker (self : worker_state) : WL.task_full = + Mutex.lock self.st.mutex; + match Queue.pop self.st.main_q with + | task -> + Mutex.unlock self.st.mutex; + task + | exception Queue.Empty -> + (* wait here *) + if A.get self.st.active then ( + wait_for_condition_ self.st; - if w != w' then ( - match WSQ.steal w'.q with - | Some t -> raise_notrace (Got_task t) - | None -> () - ) - done; - None - with Got_task t -> Some t + (* see if a task became available *) + match Queue.pop self.st.main_q with + | task -> + Mutex.unlock self.st.mutex; + task + | exception Queue.Empty -> try_steal_from_other_workers_ self + ) else ( + (* do nothing more: no task in main queue, and we are shutting + down 
so no new task should arrive. + The exception is if another task is creating subtasks + that overflow into the main queue, but we can ignore that at + the price of slightly decreased performance for the last few + tasks *) + Mutex.unlock self.st.mutex; + raise WL.No_more_tasks + ) -(** Worker runs tasks from its queue until none remains *) -let worker_run_self_tasks_ (self : state) ~runner w : unit = - let continue = ref true in - while !continue && A.get self.active do - match WSQ.pop w.q with - | Some task -> - try_wake_someone_ self; - run_task_now_ self ~runner ~w task - | None -> continue := false - done +let before_start (self : worker_state) : unit = + let t_id = Thread.id @@ Thread.self () in + self.st.on_init_thread ~dom_id:self.dom_id ~t_id (); + TLS.set k_cur_fiber _dummy_fiber; + TLS.set Runner.For_runner_implementors.k_cur_runner + (Lazy.force self.st.as_runner); + TLS.set k_worker_state self; -(** Main loop for a worker thread. *) -let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = - TLS.set Runner.For_runner_implementors.k_cur_runner runner; - TLS.set k_worker_state w; + (* set thread name *) + Option.iter + (fun name -> + Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx)) + self.st.name - let rec main () : unit = - worker_run_self_tasks_ self ~runner w; - try_steal () - and run_task task : unit = - run_task_now_ self ~runner ~w task; - main () - and try_steal () = - match try_to_steal_work_once_ self w with - | Some task -> run_task task - | None -> wait () - and wait () = - Mutex.lock self.mutex; - match Queue.pop self.main_q with - | task -> - Mutex.unlock self.mutex; - run_task task - | exception Queue.Empty -> - (* wait here *) - if A.get self.active then ( - wait_ self; +let cleanup (self : worker_state) : unit = + (* on termination, decrease refcount of underlying domain *) + Domain_pool_.decr_on self.dom_id; + let t_id = Thread.id @@ Thread.self () in + self.st.on_exit_thread ~dom_id:self.dom_id ~t_id () - (* see if a task became available *) - let task = - try Some (Queue.pop self.main_q) with Queue.Empty -> None - in - Mutex.unlock self.mutex; - - match task with - | Some t -> run_task t - | None -> try_steal () - ) else - (* do nothing more: no task in main queue, and we are shutting - down so no new task should arrive. 
- The exception is if another task is creating subtasks - that overflow into the main queue, but we can ignore that at - the price of slightly decreased performance for the last few - tasks *) - Mutex.unlock self.mutex +let worker_ops : worker_state WL.ops = + let runner (st : worker_state) = Lazy.force st.st.as_runner in + let around_task st = st.st.around_task in + let on_exn (st : worker_state) (ebt : Exn_bt.t) = + st.st.on_exn ebt.exn ebt.bt in - - (* handle domain-local await *) - main () + { + WL.schedule = schedule_from_w; + runner; + get_next_task; + get_thread_state = get_current_worker_exn; + around_task; + on_exn; + before_start; + cleanup; + } let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () @@ -276,6 +226,14 @@ let shutdown_ ~wait (self : state) : unit = if wait then Array.iter (fun w -> Thread.join w.thread) self.workers ) +let as_runner_ (self : state) : t = + Runner.For_runner_implementors.create + ~shutdown:(fun ~wait () -> shutdown_ self ~wait) + ~run_async:(fun ~fiber f -> schedule_on_main self @@ T_start { fiber; f }) + ~size:(fun () -> size_ self) + ~num_tasks:(fun () -> num_tasks_ self) + () + type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> @@ -286,9 +244,6 @@ type ('a, 'b) create_args = 'a (** Arguments used in {!create}. See {!create} for explanations. *) -let dummy_task_ : task_full = - T_start { f = ignore; ls = Task_local_storage.dummy } - let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?around_task ?num_threads ?name () : t = @@ -296,8 +251,8 @@ let create ?(on_init_thread = default_thread_init_exit_) (* wrapper *) let around_task = match around_task with - | Some (f, g) -> AT_pair (f, g) - | None -> AT_pair (ignore, fun _ _ -> ()) + | Some (f, g) -> WL.AT_pair (f, g) + | None -> WL.AT_pair (ignore, fun _ _ -> ()) in let num_domains = Domain_pool_.max_number_of_domains () in @@ -306,23 +261,11 @@ let create ?(on_init_thread = default_thread_init_exit_) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in - let workers : worker_state array = - let dummy = Thread.self () in - Array.init num_threads (fun i -> - { - pool_id_; - thread = dummy; - q = WSQ.create ~dummy:dummy_task_ (); - rng = Random.State.make [| i |]; - cur_ls = None; - }) - in - - let pool = + let rec pool = { id_ = pool_id_; active = A.make true; - workers; + workers = [||]; main_q = Queue.create (); n_waiting = 0; n_waiting_nonzero = true; @@ -330,65 +273,47 @@ let create ?(on_init_thread = default_thread_init_exit_) cond = Condition.create (); around_task; on_exn; + on_init_thread; + on_exit_thread; + name; + as_runner = lazy (as_runner_ pool); } in - let runner = - Runner.For_runner_implementors.create - ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) - ~run_async:(fun ~ls f -> run_async_ pool ~ls f) - ~size:(fun () -> size_ pool) - ~num_tasks:(fun () -> num_tasks_ pool) - () - in - (* temporary queue used to obtain thread handles from domains on which the thread are started. 
*) let receive_threads = Bb_queue.create () in (* start the thread with index [i] *) - let start_thread_with_idx i = - let w = pool.workers.(i) in - let dom_idx = (offset + i) mod num_domains in - - (* function run in the thread itself *) - let main_thread_fun () : unit = - let thread = Thread.self () in - let t_id = Thread.id thread in - on_init_thread ~dom_id:dom_idx ~t_id (); - TLS.set k_cur_storage _dummy_ls; - - (* set thread name *) - Option.iter - (fun name -> - Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i)) - name; - - let run () = worker_thread_ pool ~runner w in - - (* now run the main loop *) - Fun.protect run ~finally:(fun () -> - (* on termination, decrease refcount of underlying domain *) - Domain_pool_.decr_on dom_idx); - on_exit_thread ~dom_id:dom_idx ~t_id () + let start_thread_with_idx idx = + let dom_id = (offset + idx) mod num_domains in + let st = + { + st = pool; + thread = (* dummy *) Thread.self (); + q = WSQ.create ~dummy:WL._dummy_task (); + rng = Random.State.make [| idx |]; + dom_id; + idx; + } in (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = - let thread = Thread.create main_thread_fun () in + let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in (* send the thread from the domain back to us *) - Bb_queue.push receive_threads (i, thread) + Bb_queue.push receive_threads (idx, thread) in - Domain_pool_.run_on dom_idx create_thread_in_domain + Domain_pool_.run_on dom_id create_thread_in_domain; + + st in - (* start all threads, placing them on the domains + (* start all worker threads, placing them on the domains according to their index and [offset] in a round-robin fashion. *) - for i = 0 to num_threads - 1 do - start_thread_with_idx i - done; + pool.workers <- Array.init num_threads start_thread_with_idx; (* receive the newly created threads back from domains *) for _j = 1 to num_threads do @@ -397,7 +322,7 @@ let create ?(on_init_thread = default_thread_init_exit_) worker_state.thread <- th done; - runner + Lazy.force pool.as_runner let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads ?name () f = diff --git a/src/private/ws_deque_.ml b/src/private/ws_deque_.ml index 6c5d1419..368cc8b0 100644 --- a/src/private/ws_deque_.ml +++ b/src/private/ws_deque_.ml @@ -72,7 +72,9 @@ let push (self : 'a t) (x : 'a) : bool = true with Full -> false -let pop (self : 'a t) : 'a option = +exception Empty + +let pop_exn (self : 'a t) : 'a = let b = A.get self.bottom in let b = b - 1 in A.set self.bottom b; @@ -84,11 +86,11 @@ let pop (self : 'a t) : 'a option = if size < 0 then ( (* reset to basic empty state *) A.set self.bottom t; - None + raise_notrace Empty ) else if size > 0 then ( (* can pop without modifying [top] *) let x = CA.get self.arr b in - Some x + x ) else ( assert (size = 0); (* there was exactly one slot, so we might be racing against stealers @@ -96,13 +98,18 @@ let pop (self : 'a t) : 'a option = if A.compare_and_set self.top t (t + 1) then ( let x = CA.get self.arr b in A.set self.bottom (t + 1); - Some x + x ) else ( A.set self.bottom (t + 1); - None + raise_notrace Empty ) ) +let[@inline] pop self : _ option = + match pop_exn self with + | exception Empty -> None + | t -> Some t + let steal (self : 'a t) : 'a option = (* read [top], but do not update [top_cached] as we're in another thread *) diff --git a/src/private/ws_deque_.mli b/src/private/ws_deque_.mli index b696224e..0b9fd84a 100644 --- 
a/src/private/ws_deque_.mli
+++ b/src/private/ws_deque_.mli
@@ -21,6 +21,13 @@ val pop : 'a t -> 'a option
 (** Pop value from the bottom of deque.
     This must be called only by the owner thread. *)
 
+exception Empty
+(** Raised by {!pop_exn} when the deque is empty. *)
+
+val pop_exn : 'a t -> 'a
+(** Same as {!pop}, but raises {!Empty} instead of returning [None].
+    This must be called only by the owner thread. *)
+
 val steal : 'a t -> 'a option
 (** Try to steal from the top of deque.
     This is thread-safe. *)
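
Note (editorial addition, not part of the diff above): the new Worker_loop_ module is
the piece both pools now share. The sketch below is a minimal, hypothetical example of
a custom runner built on it, following the same pattern as Fifo_pool in this patch: one
worker thread, a plain Bb_queue of tasks, and an ops record handed to WL.worker_loop.
Names such as Mini_fifo and k_state are illustrative; Bb_queue.close/Bb_queue.size,
TLS (= Thread_local_storage), Exn_bt and Runner.For_runner_implementors are assumed to
be available exactly as they are used elsewhere in this patch.

(* Minimal sketch of a single-threaded runner reusing Worker_loop_. *)
module Mini_fifo = struct
  open Types_
  module WL = Worker_loop_

  type st = {
    q: WL.task_full Bb_queue.t;  (** FIFO of pending tasks *)
    as_runner: Runner.t lazy_t;
  }

  (* each worker thread finds its own state through TLS, like the pools do *)
  let k_state : st TLS.t = TLS.create ()

  let ops : st WL.ops =
    {
      WL.schedule =
        (fun self task ->
          try Bb_queue.push self.q task
          with Bb_queue.Closed -> raise Runner.Shutdown);
      get_next_task =
        (fun self ->
          try Bb_queue.pop self.q
          with Bb_queue.Closed -> raise WL.No_more_tasks);
      get_thread_state = (fun () -> TLS.get_exn k_state);
      around_task = (fun _ -> WL.AT_pair (ignore, fun _ _ -> ()));
      on_exn =
        (fun _ (ebt : Exn_bt.t) ->
          (* just report uncaught task exceptions *)
          Printexc.default_uncaught_exception_handler ebt.exn ebt.bt);
      runner = (fun self -> Lazy.force self.as_runner);
      before_start = (fun self -> TLS.set k_state self);
      cleanup = ignore;
    }

  let create () : Runner.t =
    let rec self =
      {
        q = Bb_queue.create ();
        as_runner =
          lazy
            (Runner.For_runner_implementors.create
               (* note: [wait] is ignored in this sketch; a real runner
                  would join the worker thread when [wait=true] *)
               ~shutdown:(fun ~wait:_ () -> Bb_queue.close self.q)
               ~run_async:(fun ~fiber f ->
                 ops.WL.schedule self (WL.T_start { fiber; f }))
               ~size:(fun () -> 1)
               ~num_tasks:(fun () -> Bb_queue.size self.q)
               ())
      }
    in
    (* the worker loop installs the Picos effect handlers
       (Spawn, Yield, Current, Trigger.Await) around every task *)
    ignore (Thread.create (WL.worker_loop ~ops) self : Thread.t);
    Lazy.force self.as_runner
end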