From e67ab53f9f12de163cba6b0684836d88020f0939 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 25 Oct 2023 00:19:34 -0400 Subject: [PATCH] feat pool: rewrite main pool to use work stealing. There's a single blocking queue, and one work-stealing queue (WS queue) per worker. Scheduling into the pool from a worker (e.g. via fork_join or explicitly) will push into this WS queue; otherwise it goes into the main blocking queue. Workers will always try to empty their local queue first, then try to steal work from other workers, then block on the main queue. --- src/dune | 2 +- src/moonpool.ml | 7 +- src/moonpool.mli | 16 ++- src/pool.ml | 327 +++++++++++++++-------------------------------- src/pool.mli | 5 - 5 files changed, 120 insertions(+), 237 deletions(-) diff --git a/src/dune b/src/dune index d65920e8..313191a5 100644 --- a/src/dune +++ b/src/dune @@ -1,7 +1,7 @@ (library (public_name moonpool) (name moonpool) - (private_modules d_pool_) + (private_modules d_pool_ dla_) (preprocess (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) diff --git a/src/moonpool.ml b/src/moonpool.ml index 83ae22a8..97da4d2a 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -11,4 +11,9 @@ module Fut = Fut module Lock = Lock module Pool = Pool module Runner = Runner -module Suspend_ = Suspend_ +module Simple_pool = Simple_pool + +module Private = struct + module Ws_deque_ = Ws_deque_ + module Suspend_ = Suspend_ +end diff --git a/src/moonpool.mli b/src/moonpool.mli index 1d300665..74b48772 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -5,6 +5,7 @@ *) module Pool = Pool +module Simple_pool = Simple_pool module Runner = Runner val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t @@ -141,12 +142,19 @@ module Atomic = Atomic_ This is either a shim using [ref], on pre-OCaml 5, or the standard [Atomic] module on OCaml 5. *) -(** {2 Suspensions} *) +(**/**) -module Suspend_ = Suspend_ -[@@alert unstable "this module is an implementation detail of moonpool for now"] -(** Suspensions. 
+module Private : sig + module Ws_deque_ = Ws_deque_ + + (** {2 Suspensions} *) + + module Suspend_ = Suspend_ + [@@alert + unstable "this module is an implementation detail of moonpool for now"] + (** Suspensions. This is only going to work on OCaml 5.x. {b NOTE}: this is not stable for now. *) +end diff --git a/src/pool.ml b/src/pool.ml index 590a9586..4ce08f76 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -1,204 +1,76 @@ -(* TODO: use a better queue for the tasks *) - -module A = Atomic_ +module WSQ = Ws_deque_ include Runner let ( let@ ) = ( @@ ) -(** Thread safe queue, non blocking *) -module TS_queue = struct - type 'a t = { - mutex: Mutex.t; - q: 'a Queue.t; - } - - let create () : _ t = { mutex = Mutex.create (); q = Queue.create () } - - let try_push (self : _ t) x : bool = - if Mutex.try_lock self.mutex then ( - Queue.push x self.q; - Mutex.unlock self.mutex; - true - ) else - false - - let push (self : _ t) x : unit = - Mutex.lock self.mutex; - Queue.push x self.q; - Mutex.unlock self.mutex - - let try_pop ~force_lock (self : _ t) : _ option = - let has_lock = - if force_lock then ( - Mutex.lock self.mutex; - true - ) else - Mutex.try_lock self.mutex - in - if has_lock then ( - match Queue.pop self.q with - | x -> - Mutex.unlock self.mutex; - Some x - | exception Queue.Empty -> - Mutex.unlock self.mutex; - None - ) else - None -end - type thread_loop_wrapper = thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit -let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make [] - -let add_global_thread_loop_wrapper f : unit = - while - let l = A.get global_thread_wrappers_ in - not (A.compare_and_set global_thread_wrappers_ l (f :: l)) - do - Domain_.relax () - done +type worker_state = { + mutable thread: Thread.t; + q: task WSQ.t; (** Work stealing queue *) +} type state = { - active: bool A.t; - threads: Thread.t array; - qs: task TS_queue.t array; - num_tasks: int A.t; - mutex: Mutex.t; - cond: Condition.t; - cur_q: int A.t; (** Selects 
queue into which to push *) + workers: worker_state array; + main_q: task Bb_queue.t; (** Main queue to block on *) } (** internal state *) -let[@inline] size_ (self : state) = Array.length self.threads -let[@inline] num_tasks_ (self : state) : int = A.get self.num_tasks +let[@inline] size_ (self : state) = Array.length self.workers -let awake_workers_ (self : state) : unit = - Mutex.lock self.mutex; - Condition.broadcast self.cond; - Mutex.unlock self.mutex +let num_tasks_ (self : state) : int = + let n = ref (Bb_queue.size self.main_q) in + Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; + !n + +exception Got_worker of worker_state + +let find_current_worker_ (self : state) : worker_state option = + let self_id = Thread.id @@ Thread.self () in + try + (* see if we're in one of the worker threads *) + for i = 0 to Array.length self.workers - 1 do + let w = self.workers.(i) in + if Thread.id w.thread = self_id then raise_notrace (Got_worker w) + done; + None + with Got_worker w -> Some w (** Run [task] as is, on the pool. 
*) -let run_direct_ (self : state) (task : task) : unit = - let n_qs = Array.length self.qs in - let offset = A.fetch_and_add self.cur_q 1 in +let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = + match w with + | Some w -> WSQ.push w.q task + | None -> Bb_queue.push self.main_q task - (* push that forces lock acquisition, last resort *) - let[@inline] push_wait f = - let q_idx = offset mod Array.length self.qs in - let q = self.qs.(q_idx) in - TS_queue.push q f +let run_async_ (self : state) (task : task) : unit = + (* stay on current worker if possible *) + let w = find_current_worker_ self in + + let rec run_async_rec_ (task : task) = + let task_with_suspend_ () = + (* run [f()] and handle [suspend] in it *) + Suspend_.with_suspend task ~run:(fun ~with_handler task' -> + if with_handler then + run_async_rec_ task' + else + run_direct_ self w task') + in + run_direct_ self w task_with_suspend_ in - - (try - (* try each queue with a round-robin initial offset *) - for _retry = 1 to 10 do - for i = 0 to n_qs - 1 do - let q_idx = (i + offset) mod Array.length self.qs in - let q = self.qs.(q_idx) in - - if TS_queue.try_push q task then raise_notrace Exit - done - done; - push_wait task - with Exit -> ()); - - (* successfully pushed, now see if we need to wakeup workers *) - let old_num_tasks = A.fetch_and_add self.num_tasks 1 in - if old_num_tasks < size_ self then awake_workers_ self - -let rec run_async_ (self : state) (task : task) : unit = - let task' () = - (* run [f()] and handle [suspend] in it *) - Suspend_.with_suspend task ~run:(fun ~with_handler task -> - if with_handler then - run_async_ self task - else - run_direct_ self task) - in - run_direct_ self task' + run_async_rec_ task let run = run_async -[@@@ifge 5.0] - -(* DLA interop *) -let prepare_for_await () : Dla_.t = - (* current state *) - let st : - ((with_handler:bool -> task -> unit) * Suspend_.suspension) option A.t = - A.make None - in - - let release () : unit = - 
match A.exchange st None with - | None -> () - | Some (run, k) -> run ~with_handler:true (fun () -> k (Ok ())) - and await () : unit = - Suspend_.suspend - { Suspend_.handle = (fun ~run k -> A.set st (Some (run, k))) } - in - - let t = { Dla_.release; await } in - t - -[@@@else_] - -let prepare_for_await () = { Dla_.release = ignore; await = ignore } - -[@@@endif] - exception Got_task of task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task -exception Closed - -let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task - ~(offset : int) : unit = - let num_qs = Array.length self.qs in +let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn + ~around_task : unit = let (AT_pair (before_task, after_task)) = around_task in - (* try to get a task that is already in one of the queues. - @param force_lock if true, we force acquisition of the queue's mutex, - which is slower but always succeeds to get a task if there's one. *) - let get_task_already_in_queues ~force_lock () : _ option = - try - for _retry = 1 to 3 do - for i = 0 to num_qs - 1 do - let q = self.qs.((offset + i) mod num_qs) in - match TS_queue.try_pop ~force_lock q with - | Some f -> raise_notrace (Got_task f) - | None -> () - done - done; - None - with Got_task f -> - A.decr self.num_tasks; - Some f - in - - (* slow path: force locking when trying to get tasks, - and wait on [self.cond] if no task is currently available. *) - let pop_blocking () : task = - try - while A.get self.active do - match get_task_already_in_queues ~force_lock:true () with - | Some t -> raise_notrace (Got_task t) - | None -> - Mutex.lock self.mutex; - (* NOTE: be careful about race conditions: we must only - block if the [shutdown] that sets [active] to [false] - has not broadcast over this condition first. Otherwise - we might miss the signal and wait here forever. 
*) - if A.get self.active then Condition.wait self.cond self.mutex; - Mutex.unlock self.mutex - done; - raise Closed - with Got_task t -> t - in - + (* run this task. *) let run_task task : unit = let _ctx = before_task runner in (* run the task now, catching errors *) @@ -209,51 +81,69 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task after_task runner _ctx in - (* drain the queues from existing tasks. If [force_lock=false] - then it is best effort. *) - let run_tasks_already_present ~force_lock () = + let run_self_tasks_ () = let continue = ref true in + let pop_retries = ref 0 in while !continue do - match get_task_already_in_queues ~force_lock () with - | None -> continue := false - | Some task -> run_task task + match WSQ.pop w.q with + | Some task -> + pop_retries := 0; + run_task task + | None -> + incr pop_retries; + if !pop_retries > 10 then continue := false done in + (* get a task from another worker *) + let try_to_steal_work () : task option = + try + for _retry = 1 to 3 do + Array.iter + (fun w' -> + if w != w' then ( + match WSQ.steal w'.q with + | None -> () + | Some task -> raise_notrace (Got_task task) + )) + self.workers + done; + None + with Got_task task -> Some task + in + let main_loop () = - while A.get self.active do - run_tasks_already_present ~force_lock:false (); + let steal_attempts = ref 0 in + while true do + run_self_tasks_ (); - (* no task available, block until one comes *) - match pop_blocking () with - | exception Closed -> () - | task -> run_task task - done; + match try_to_steal_work () with + | Some task -> + steal_attempts := 0; + run_task task + | None -> + incr steal_attempts; + Domain_.relax (); - (* cleanup *) - run_tasks_already_present ~force_lock:true () + if !steal_attempts > 10 then ( + steal_attempts := 0; + let task = Bb_queue.pop self.main_q in + run_task task + ) + done in try (* handle domain-local await *) - Dla_.using ~prepare_for_await ~while_running:main_loop + Dla_.using 
~prepare_for_await:Suspend_.prepare_for_await + ~while_running:main_loop with Bb_queue.Closed -> () let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () -(** We want a reasonable number of queues. Even if your system is - a beast with hundreds of cores, trying - to work-steal through hundreds of queues will have a cost. - - Hence, we limit the number of queues to at most 32 (number picked - via the ancestral technique of the pifomètre). *) -let max_queues = 32 - let shutdown_ ~wait (self : state) : unit = - let was_active = A.exchange self.active false in - (* wake up the subset of [self.threads] that are waiting on new tasks *) - if was_active then awake_workers_ self; - if wait then Array.iter Thread.join self.threads + Bb_queue.close self.main_q; + if wait then Array.iter (fun w -> Thread.join w.thread) self.workers type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> @@ -286,24 +176,12 @@ let create ?(on_init_thread = default_thread_init_exit_) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in - let active = A.make true in - let qs = - let num_qs = min (min num_domains num_threads) max_queues in - Array.init num_qs (fun _ -> TS_queue.create ()) + let workers : worker_state array = + let dummy = Thread.self () in + Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () }) in - let pool = - let dummy = Thread.self () in - { - active; - threads = Array.make num_threads dummy; - num_tasks = A.make 0; - qs; - mutex = Mutex.create (); - cond = Condition.create (); - cur_q = A.make 0; - } - in + let pool = { workers; main_q = Bb_queue.create () } in let runner = Runner.For_runner_implementors.create @@ -320,6 +198,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* start the thread with index [i] *) let start_thread_with_idx i = + let w = pool.workers.(i) in let dom_idx = (offset + i) mod num_domains in (* function run in the thread itself *) 
@@ -328,17 +207,13 @@ let create ?(on_init_thread = default_thread_init_exit_) let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); - let all_wrappers = - List.rev_append thread_wrappers (A.get global_thread_wrappers_) - in - - let run () = worker_thread_ pool runner ~on_exn ~around_task ~offset:i in + let run () = worker_thread_ pool runner w ~on_exn ~around_task in (* the actual worker loop is [worker_thread_], with all wrappers for this pool and for all pools (global_thread_wrappers_) *) let run' = List.fold_left (fun run f -> f ~thread ~pool:runner run) - run all_wrappers + run thread_wrappers in (* now run the main loop *) @@ -368,7 +243,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* receive the newly created threads back from domains *) for _j = 1 to num_threads do let i, th = Bb_queue.pop receive_threads in - pool.threads.(i) <- th + pool.workers.(i).thread <- th done; runner diff --git a/src/pool.mli b/src/pool.mli index 11cac88b..f7a42633 100644 --- a/src/pool.mli +++ b/src/pool.mli @@ -23,11 +23,6 @@ type thread_loop_wrapper = By default it just returns the same loop function but it can be used to install tracing, effect handlers, etc. *) -val add_global_thread_loop_wrapper : thread_loop_wrapper -> unit -(** [add_global_thread_loop_wrapper f] installs [f] to be installed in every new pool worker - thread, for all existing pools, and all new pools created with [create]. - These wrappers accumulate: they all apply, but their order is not specified. *) - type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->