feat pool: rewrite main pool to use work stealing

There is a single main blocking queue, plus one work-stealing (WS) queue
per worker. Scheduling into the pool from a worker thread (e.g. via
fork_join or explicitly) pushes the task onto that worker's WS queue;
otherwise the task goes into the main blocking queue.

Workers always try to empty their local queue first, then try to steal
work from other workers, and only then block on the main queue. A sketch
of this policy follows.
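
A minimal, self-contained sketch of this policy. The names below (task,
worker, pool, schedule, next_task) are illustrative only: the real
implementation uses the lock-free Ws_deque_ and the blocking Bb_queue
shown in the diff below, not Queue + Mutex, and the real main queue
blocks instead of polling.

(* Sketch only: plain Queue/Mutex stand in for the real work-stealing
   deque and blocking queue. *)
type task = unit -> unit

type worker = { local_q: task Queue.t }

type pool = {
  workers: worker array;
  main_q: task Queue.t; (* stands in for the blocking main queue *)
  main_lock: Mutex.t;
}

(* Scheduling: from a worker, push onto its own queue; from outside the
   pool, push onto the main queue. *)
let schedule (pool : pool) ~from_worker (t : task) : unit =
  match from_worker with
  | Some w -> Queue.push t w.local_q
  | None ->
    Mutex.lock pool.main_lock;
    Queue.push t pool.main_q;
    Mutex.unlock pool.main_lock

(* How a worker finds its next task:
   1. its own local queue;
   2. steal from some other worker;
   3. fall back to the main queue (blocking in the real pool, a simple
      poll here). *)
let next_task (pool : pool) (w : worker) : task option =
  match Queue.take_opt w.local_q with
  | Some _ as t -> t
  | None ->
    let stolen = ref None in
    Array.iter
      (fun w' ->
        if Option.is_none !stolen && w' != w then
          stolen := Queue.take_opt w'.local_q)
      pool.workers;
    (match !stolen with
     | Some _ as t -> t
     | None ->
       Mutex.lock pool.main_lock;
       let t = Queue.take_opt pool.main_q in
       Mutex.unlock pool.main_lock;
       t)

Keeping a worker's own tasks local preserves locality for fork_join
children and keeps the contended main queue as a last resort.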
Author: Simon Cruanes
Date:   2023-10-25 00:19:34 -04:00
Parent: f2e9f99b36
Commit: e67ab53f9f

5 changed files with 120 additions and 237 deletions


@@ -1,7 +1,7 @@
 (library
  (public_name moonpool)
  (name moonpool)
- (private_modules d_pool_)
+ (private_modules d_pool_ dla_)
  (preprocess
   (action
    (run %{project_root}/src/cpp/cpp.exe %{input-file})))


@@ -11,4 +11,9 @@ module Fut = Fut
 module Lock = Lock
 module Pool = Pool
 module Runner = Runner
-module Suspend_ = Suspend_
+module Simple_pool = Simple_pool
+
+module Private = struct
+  module Ws_deque_ = Ws_deque_
+  module Suspend_ = Suspend_
+end


@@ -5,6 +5,7 @@
 *)

 module Pool = Pool
+module Simple_pool = Simple_pool
 module Runner = Runner

 val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t

@@ -141,12 +142,19 @@ module Atomic = Atomic_
     This is either a shim using [ref], on pre-OCaml 5, or the
     standard [Atomic] module on OCaml 5. *)

-(** {2 Suspensions} *)
-
-module Suspend_ = Suspend_
-[@@alert unstable "this module is an implementation detail of moonpool for now"]
-(** Suspensions.
-
-    This is only going to work on OCaml 5.x.
-
-    {b NOTE}: this is not stable for now. *)
+(**/**)
+
+module Private : sig
+  module Ws_deque_ = Ws_deque_
+
+  (** {2 Suspensions} *)
+
+  module Suspend_ = Suspend_
+  [@@alert
+    unstable "this module is an implementation detail of moonpool for now"]
+  (** Suspensions.
+
+      This is only going to work on OCaml 5.x.
+
+      {b NOTE}: this is not stable for now. *)
+end


@@ -1,204 +1,76 @@
-(* TODO: use a better queue for the tasks *)
-module A = Atomic_
+module WSQ = Ws_deque_
 include Runner

 let ( let@ ) = ( @@ )

-(** Thread safe queue, non blocking *)
-module TS_queue = struct
-  type 'a t = {
-    mutex: Mutex.t;
-    q: 'a Queue.t;
-  }
-
-  let create () : _ t = { mutex = Mutex.create (); q = Queue.create () }
-
-  let try_push (self : _ t) x : bool =
-    if Mutex.try_lock self.mutex then (
-      Queue.push x self.q;
-      Mutex.unlock self.mutex;
-      true
-    ) else
-      false
-
-  let push (self : _ t) x : unit =
-    Mutex.lock self.mutex;
-    Queue.push x self.q;
-    Mutex.unlock self.mutex
-
-  let try_pop ~force_lock (self : _ t) : _ option =
-    let has_lock =
-      if force_lock then (
-        Mutex.lock self.mutex;
-        true
-      ) else
-        Mutex.try_lock self.mutex
-    in
-    if has_lock then (
-      match Queue.pop self.q with
-      | x ->
-        Mutex.unlock self.mutex;
-        Some x
-      | exception Queue.Empty ->
-        Mutex.unlock self.mutex;
-        None
-    ) else
-      None
-end
-
 type thread_loop_wrapper =
   thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit

-let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make []
-
-let add_global_thread_loop_wrapper f : unit =
-  while
-    let l = A.get global_thread_wrappers_ in
-    not (A.compare_and_set global_thread_wrappers_ l (f :: l))
-  do
-    Domain_.relax ()
-  done
+type worker_state = {
+  mutable thread: Thread.t;
+  q: task WSQ.t;  (** Work stealing queue *)
+}

 type state = {
-  active: bool A.t;
-  threads: Thread.t array;
-  qs: task TS_queue.t array;
-  num_tasks: int A.t;
-  mutex: Mutex.t;
-  cond: Condition.t;
-  cur_q: int A.t;  (** Selects queue into which to push *)
+  workers: worker_state array;
+  main_q: task Bb_queue.t;  (** Main queue to block on *)
 }
 (** internal state *)

-let[@inline] size_ (self : state) = Array.length self.threads
-let[@inline] num_tasks_ (self : state) : int = A.get self.num_tasks
-
-let awake_workers_ (self : state) : unit =
-  Mutex.lock self.mutex;
-  Condition.broadcast self.cond;
-  Mutex.unlock self.mutex
+let[@inline] size_ (self : state) = Array.length self.workers
+
+let num_tasks_ (self : state) : int =
+  let n = ref (Bb_queue.size self.main_q) in
+  Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
+  !n
+
+exception Got_worker of worker_state
+
+let find_current_worker_ (self : state) : worker_state option =
+  let self_id = Thread.id @@ Thread.self () in
+  try
+    (* see if we're in one of the worker threads *)
+    for i = 0 to Array.length self.workers - 1 do
+      let w = self.workers.(i) in
+      if Thread.id w.thread = self_id then raise_notrace (Got_worker w)
+    done;
+    None
+  with Got_worker w -> Some w

 (** Run [task] as is, on the pool. *)
-let run_direct_ (self : state) (task : task) : unit =
-  let n_qs = Array.length self.qs in
-  let offset = A.fetch_and_add self.cur_q 1 in
-
-  (* push that forces lock acquisition, last resort *)
-  let[@inline] push_wait f =
-    let q_idx = offset mod Array.length self.qs in
-    let q = self.qs.(q_idx) in
-    TS_queue.push q f
-  in
-
-  (try
-     (* try each queue with a round-robin initial offset *)
-     for _retry = 1 to 10 do
-       for i = 0 to n_qs - 1 do
-         let q_idx = (i + offset) mod Array.length self.qs in
-         let q = self.qs.(q_idx) in
-         if TS_queue.try_push q task then raise_notrace Exit
-       done
-     done;
-     push_wait task
-   with Exit -> ());
-
-  (* successfully pushed, now see if we need to wakeup workers *)
-  let old_num_tasks = A.fetch_and_add self.num_tasks 1 in
-  if old_num_tasks < size_ self then awake_workers_ self
-
-let rec run_async_ (self : state) (task : task) : unit =
-  let task' () =
-    (* run [f()] and handle [suspend] in it *)
-    Suspend_.with_suspend task ~run:(fun ~with_handler task ->
-        if with_handler then
-          run_async_ self task
-        else
-          run_direct_ self task)
-  in
-  run_direct_ self task'
+let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
+  match w with
+  | Some w -> WSQ.push w.q task
+  | None -> Bb_queue.push self.main_q task
+
+let run_async_ (self : state) (task : task) : unit =
+  (* stay on current worker if possible *)
+  let w = find_current_worker_ self in
+
+  let rec run_async_rec_ (task : task) =
+    let task_with_suspend_ () =
+      (* run [f()] and handle [suspend] in it *)
+      Suspend_.with_suspend task ~run:(fun ~with_handler task' ->
+          if with_handler then
+            run_async_rec_ task'
+          else
+            run_direct_ self w task')
+    in
+    run_direct_ self w task_with_suspend_
+  in
+  run_async_rec_ task

 let run = run_async

-[@@@ifge 5.0]
-
-(* DLA interop *)
-let prepare_for_await () : Dla_.t =
-  (* current state *)
-  let st :
-      ((with_handler:bool -> task -> unit) * Suspend_.suspension) option A.t =
-    A.make None
-  in
-
-  let release () : unit =
-    match A.exchange st None with
-    | None -> ()
-    | Some (run, k) -> run ~with_handler:true (fun () -> k (Ok ()))
-  and await () : unit =
-    Suspend_.suspend
-      { Suspend_.handle = (fun ~run k -> A.set st (Some (run, k))) }
-  in
-  let t = { Dla_.release; await } in
-  t
-
-[@@@else_]
-
-let prepare_for_await () = { Dla_.release = ignore; await = ignore }
-
-[@@@endif]
-
 exception Got_task of task

 type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task

-exception Closed
-
-let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task
-    ~(offset : int) : unit =
-  let num_qs = Array.length self.qs in
+let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
+    ~around_task : unit =
   let (AT_pair (before_task, after_task)) = around_task in

-  (* try to get a task that is already in one of the queues.
-     @param force_lock if true, we force acquisition of the queue's mutex,
-       which is slower but always succeeds to get a task if there's one. *)
-  let get_task_already_in_queues ~force_lock () : _ option =
-    try
-      for _retry = 1 to 3 do
-        for i = 0 to num_qs - 1 do
-          let q = self.qs.((offset + i) mod num_qs) in
-          match TS_queue.try_pop ~force_lock q with
-          | Some f -> raise_notrace (Got_task f)
-          | None -> ()
-        done
-      done;
-      None
-    with Got_task f ->
-      A.decr self.num_tasks;
-      Some f
-  in
-
-  (* slow path: force locking when trying to get tasks,
-     and wait on [self.cond] if no task is currently available. *)
-  let pop_blocking () : task =
-    try
-      while A.get self.active do
-        match get_task_already_in_queues ~force_lock:true () with
-        | Some t -> raise_notrace (Got_task t)
-        | None ->
-          Mutex.lock self.mutex;
-          (* NOTE: be careful about race conditions: we must only
-             block if the [shutdown] that sets [active] to [false]
-             has not broadcast over this condition first. Otherwise
-             we might miss the signal and wait here forever. *)
-          if A.get self.active then Condition.wait self.cond self.mutex;
-          Mutex.unlock self.mutex
-      done;
-      raise Closed
-    with Got_task t -> t
-  in
-
+  (* run this task. *)
   let run_task task : unit =
     let _ctx = before_task runner in
     (* run the task now, catching errors *)
@@ -209,51 +81,69 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task
     after_task runner _ctx
   in

-  (* drain the queues from existing tasks. If [force_lock=false]
-     then it is best effort. *)
-  let run_tasks_already_present ~force_lock () =
+  let run_self_tasks_ () =
     let continue = ref true in
+    let pop_retries = ref 0 in
     while !continue do
-      match get_task_already_in_queues ~force_lock () with
-      | None -> continue := false
-      | Some task -> run_task task
+      match WSQ.pop w.q with
+      | Some task ->
+        pop_retries := 0;
+        run_task task
+      | None ->
+        incr pop_retries;
+        if !pop_retries > 10 then continue := false
     done
   in

+  (* get a task from another worker *)
+  let try_to_steal_work () : task option =
+    try
+      for _retry = 1 to 3 do
+        Array.iter
+          (fun w' ->
+            if w != w' then (
+              match WSQ.steal w'.q with
+              | None -> ()
+              | Some task -> raise_notrace (Got_task task)
+            ))
+          self.workers
+      done;
+      None
+    with Got_task task -> Some task
+  in
+
   let main_loop () =
-    while A.get self.active do
-      run_tasks_already_present ~force_lock:false ();
-
-      (* no task available, block until one comes *)
-      match pop_blocking () with
-      | exception Closed -> ()
-      | task -> run_task task
-    done;
-
-    (* cleanup *)
-    run_tasks_already_present ~force_lock:true ()
+    let steal_attempts = ref 0 in
+    while true do
+      run_self_tasks_ ();
+      match try_to_steal_work () with
+      | Some task ->
+        steal_attempts := 0;
+        run_task task
+      | None ->
+        incr steal_attempts;
+        Domain_.relax ();
+        if !steal_attempts > 10 then (
+          steal_attempts := 0;
+          let task = Bb_queue.pop self.main_q in
+          run_task task
+        )
+    done
   in

   try
     (* handle domain-local await *)
-    Dla_.using ~prepare_for_await ~while_running:main_loop
+    Dla_.using ~prepare_for_await:Suspend_.prepare_for_await
+      ~while_running:main_loop
   with Bb_queue.Closed -> ()

 let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()

-(** We want a reasonable number of queues. Even if your system is
-    a beast with hundreds of cores, trying
-    to work-steal through hundreds of queues will have a cost.
-    Hence, we limit the number of queues to at most 32 (number picked
-    via the ancestral technique of the pifomètre). *)
-let max_queues = 32
-
 let shutdown_ ~wait (self : state) : unit =
-  let was_active = A.exchange self.active false in
-  (* wake up the subset of [self.threads] that are waiting on new tasks *)
-  if was_active then awake_workers_ self;
-  if wait then Array.iter Thread.join self.threads
+  Bb_queue.close self.main_q;
+  if wait then Array.iter (fun w -> Thread.join w.thread) self.workers

 type ('a, 'b) create_args =
   ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
@@ -286,24 +176,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
   (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
   let offset = Random.int num_domains in

-  let active = A.make true in
-  let qs =
-    let num_qs = min (min num_domains num_threads) max_queues in
-    Array.init num_qs (fun _ -> TS_queue.create ())
+  let workers : worker_state array =
+    let dummy = Thread.self () in
+    Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
   in

-  let pool =
-    let dummy = Thread.self () in
-    {
-      active;
-      threads = Array.make num_threads dummy;
-      num_tasks = A.make 0;
-      qs;
-      mutex = Mutex.create ();
-      cond = Condition.create ();
-      cur_q = A.make 0;
-    }
-  in
+  let pool = { workers; main_q = Bb_queue.create () } in

   let runner =
     Runner.For_runner_implementors.create
@@ -320,6 +198,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
   (* start the thread with index [i] *)
   let start_thread_with_idx i =
+    let w = pool.workers.(i) in
     let dom_idx = (offset + i) mod num_domains in

     (* function run in the thread itself *)
@@ -328,17 +207,13 @@ let create ?(on_init_thread = default_thread_init_exit_)
       let t_id = Thread.id thread in
       on_init_thread ~dom_id:dom_idx ~t_id ();

-      let all_wrappers =
-        List.rev_append thread_wrappers (A.get global_thread_wrappers_)
-      in
-
-      let run () = worker_thread_ pool runner ~on_exn ~around_task ~offset:i in
+      let run () = worker_thread_ pool runner w ~on_exn ~around_task in

       (* the actual worker loop is [worker_thread_], with all
          wrappers for this pool and for all pools (global_thread_wrappers_) *)
       let run' =
         List.fold_left
           (fun run f -> f ~thread ~pool:runner run)
-          run all_wrappers
+          run thread_wrappers
       in

       (* now run the main loop *)
@@ -368,7 +243,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
   (* receive the newly created threads back from domains *)
   for _j = 1 to num_threads do
     let i, th = Bb_queue.pop receive_threads in
-    pool.threads.(i) <- th
+    pool.workers.(i).thread <- th
   done;
   runner
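
For reference, here is a sketch of the work-stealing queue interface that
the new pool relies on. The signatures are inferred from how WSQ
(i.e. Ws_deque_) is used in the diff above; they are not copied from the
real module and may differ in detail.

(* Interface sketch; inferred from usage: push/pop by the owning worker,
   steal by other workers, size for bookkeeping. *)
module type WS_DEQUE = sig
  type 'a t

  val create : unit -> 'a t
  (** A new empty queue, owned by a single worker. *)

  val push : 'a t -> 'a -> unit
  (** The owner pushes a task onto its own queue. *)

  val pop : 'a t -> 'a option
  (** The owner takes a task from its own queue; [None] if empty. *)

  val steal : 'a t -> 'a option
  (** Another worker tries to take a task; [None] if empty. *)

  val size : 'a t -> int
  (** Number of queued tasks (approximate under concurrency). *)
end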


@@ -23,11 +23,6 @@ type thread_loop_wrapper
     By default it just returns the same loop function but it can be used
     to install tracing, effect handlers, etc. *)

-val add_global_thread_loop_wrapper : thread_loop_wrapper -> unit
-(** [add_global_thread_loop_wrapper f] installs [f] to be installed in every new pool worker
-    thread, for all existing pools, and all new pools created with [create].
-    These wrappers accumulate: they all apply, but their order is not specified. *)
-
 type ('a, 'b) create_args =
   ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
   ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->