mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-06 03:05:30 -05:00
feat pool: rewrite main pool to use work stealing
There is a single blocking main queue, plus one work-stealing queue (WS_queue) per worker. Scheduling a task into the pool from one of its own workers (e.g. via fork_join, or explicitly) pushes it onto that worker's WS_queue; otherwise the task goes into the main blocking queue. Workers always try to empty their local queue first, then try to steal work from the other workers, and finally block on the main queue.
This commit is contained in:
parent
f2e9f99b36
commit
e67ab53f9f
5 changed files with 120 additions and 237 deletions
2
src/dune
2
src/dune
|
|
@ -1,7 +1,7 @@
|
|||
(library
|
||||
(public_name moonpool)
|
||||
(name moonpool)
|
||||
(private_modules d_pool_)
|
||||
(private_modules d_pool_ dla_)
|
||||
(preprocess
|
||||
(action
|
||||
(run %{project_root}/src/cpp/cpp.exe %{input-file})))
|
||||
|
|
|
|||
|
|
@ -11,4 +11,9 @@ module Fut = Fut
|
|||
module Lock = Lock
|
||||
module Pool = Pool
|
||||
module Runner = Runner
|
||||
module Suspend_ = Suspend_
|
||||
module Simple_pool = Simple_pool
|
||||
|
||||
module Private = struct
|
||||
module Ws_deque_ = Ws_deque_
|
||||
module Suspend_ = Suspend_
|
||||
end
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
*)
|
||||
|
||||
module Pool = Pool
|
||||
module Simple_pool = Simple_pool
|
||||
module Runner = Runner
|
||||
|
||||
val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t
|
||||
|
|
@ -141,12 +142,19 @@ module Atomic = Atomic_
|
|||
This is either a shim using [ref], on pre-OCaml 5, or the
|
||||
standard [Atomic] module on OCaml 5. *)
|
||||
|
||||
(** {2 Suspensions} *)
|
||||
(**/**)
|
||||
|
||||
module Suspend_ = Suspend_
|
||||
[@@alert unstable "this module is an implementation detail of moonpool for now"]
|
||||
(** Suspensions.
|
||||
module Private : sig
|
||||
module Ws_deque_ = Ws_deque_
|
||||
|
||||
(** {2 Suspensions} *)
|
||||
|
||||
module Suspend_ = Suspend_
|
||||
[@@alert
|
||||
unstable "this module is an implementation detail of moonpool for now"]
|
||||
(** Suspensions.
|
||||
|
||||
This is only going to work on OCaml 5.x.
|
||||
|
||||
{b NOTE}: this is not stable for now. *)
|
||||
end
|
||||
|
|
|
|||
327
src/pool.ml
327
src/pool.ml
|
|
@ -1,204 +1,76 @@
|
|||
(* TODO: use a better queue for the tasks *)
|
||||
|
||||
module A = Atomic_
|
||||
module WSQ = Ws_deque_
|
||||
include Runner
|
||||
|
||||
let ( let@ ) = ( @@ )
|
||||
|
||||
(** Thread safe queue, non blocking *)
|
||||
module TS_queue = struct
|
||||
type 'a t = {
|
||||
mutex: Mutex.t;
|
||||
q: 'a Queue.t;
|
||||
}
|
||||
|
||||
let create () : _ t = { mutex = Mutex.create (); q = Queue.create () }
|
||||
|
||||
let try_push (self : _ t) x : bool =
|
||||
if Mutex.try_lock self.mutex then (
|
||||
Queue.push x self.q;
|
||||
Mutex.unlock self.mutex;
|
||||
true
|
||||
) else
|
||||
false
|
||||
|
||||
let push (self : _ t) x : unit =
|
||||
Mutex.lock self.mutex;
|
||||
Queue.push x self.q;
|
||||
Mutex.unlock self.mutex
|
||||
|
||||
let try_pop ~force_lock (self : _ t) : _ option =
|
||||
let has_lock =
|
||||
if force_lock then (
|
||||
Mutex.lock self.mutex;
|
||||
true
|
||||
) else
|
||||
Mutex.try_lock self.mutex
|
||||
in
|
||||
if has_lock then (
|
||||
match Queue.pop self.q with
|
||||
| x ->
|
||||
Mutex.unlock self.mutex;
|
||||
Some x
|
||||
| exception Queue.Empty ->
|
||||
Mutex.unlock self.mutex;
|
||||
None
|
||||
) else
|
||||
None
|
||||
end
|
||||
|
||||
type thread_loop_wrapper =
|
||||
thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit
|
||||
|
||||
let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make []
|
||||
|
||||
let add_global_thread_loop_wrapper f : unit =
|
||||
while
|
||||
let l = A.get global_thread_wrappers_ in
|
||||
not (A.compare_and_set global_thread_wrappers_ l (f :: l))
|
||||
do
|
||||
Domain_.relax ()
|
||||
done
|
||||
type worker_state = {
|
||||
mutable thread: Thread.t;
|
||||
q: task WSQ.t; (** Work stealing queue *)
|
||||
}
|
||||
|
||||
type state = {
|
||||
active: bool A.t;
|
||||
threads: Thread.t array;
|
||||
qs: task TS_queue.t array;
|
||||
num_tasks: int A.t;
|
||||
mutex: Mutex.t;
|
||||
cond: Condition.t;
|
||||
cur_q: int A.t; (** Selects queue into which to push *)
|
||||
workers: worker_state array;
|
||||
main_q: task Bb_queue.t; (** Main queue to block on *)
|
||||
}
|
||||
(** internal state *)
|
||||
|
||||
let[@inline] size_ (self : state) = Array.length self.threads
|
||||
let[@inline] num_tasks_ (self : state) : int = A.get self.num_tasks
|
||||
let[@inline] size_ (self : state) = Array.length self.workers
|
||||
|
||||
let awake_workers_ (self : state) : unit =
|
||||
Mutex.lock self.mutex;
|
||||
Condition.broadcast self.cond;
|
||||
Mutex.unlock self.mutex
|
||||
let num_tasks_ (self : state) : int =
|
||||
let n = ref (Bb_queue.size self.main_q) in
|
||||
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
||||
!n
|
||||
|
||||
exception Got_worker of worker_state
|
||||
|
||||
let find_current_worker_ (self : state) : worker_state option =
|
||||
let self_id = Thread.id @@ Thread.self () in
|
||||
try
|
||||
(* see if we're in one of the worker threads *)
|
||||
for i = 0 to Array.length self.workers - 1 do
|
||||
let w = self.workers.(i) in
|
||||
if Thread.id w.thread = self_id then raise_notrace (Got_worker w)
|
||||
done;
|
||||
None
|
||||
with Got_worker w -> Some w
|
||||
|
||||
(** Run [task] as is, on the pool. *)
|
||||
let run_direct_ (self : state) (task : task) : unit =
|
||||
let n_qs = Array.length self.qs in
|
||||
let offset = A.fetch_and_add self.cur_q 1 in
|
||||
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
||||
match w with
|
||||
| Some w -> WSQ.push w.q task
|
||||
| None -> Bb_queue.push self.main_q task
|
||||
|
||||
(* push that forces lock acquisition, last resort *)
|
||||
let[@inline] push_wait f =
|
||||
let q_idx = offset mod Array.length self.qs in
|
||||
let q = self.qs.(q_idx) in
|
||||
TS_queue.push q f
|
||||
let run_async_ (self : state) (task : task) : unit =
|
||||
(* stay on current worker if possible *)
|
||||
let w = find_current_worker_ self in
|
||||
|
||||
let rec run_async_rec_ (task : task) =
|
||||
let task_with_suspend_ () =
|
||||
(* run [f()] and handle [suspend] in it *)
|
||||
Suspend_.with_suspend task ~run:(fun ~with_handler task' ->
|
||||
if with_handler then
|
||||
run_async_rec_ task'
|
||||
else
|
||||
run_direct_ self w task')
|
||||
in
|
||||
run_direct_ self w task_with_suspend_
|
||||
in
|
||||
|
||||
(try
|
||||
(* try each queue with a round-robin initial offset *)
|
||||
for _retry = 1 to 10 do
|
||||
for i = 0 to n_qs - 1 do
|
||||
let q_idx = (i + offset) mod Array.length self.qs in
|
||||
let q = self.qs.(q_idx) in
|
||||
|
||||
if TS_queue.try_push q task then raise_notrace Exit
|
||||
done
|
||||
done;
|
||||
push_wait task
|
||||
with Exit -> ());
|
||||
|
||||
(* successfully pushed, now see if we need to wakeup workers *)
|
||||
let old_num_tasks = A.fetch_and_add self.num_tasks 1 in
|
||||
if old_num_tasks < size_ self then awake_workers_ self
|
||||
|
||||
let rec run_async_ (self : state) (task : task) : unit =
|
||||
let task' () =
|
||||
(* run [f()] and handle [suspend] in it *)
|
||||
Suspend_.with_suspend task ~run:(fun ~with_handler task ->
|
||||
if with_handler then
|
||||
run_async_ self task
|
||||
else
|
||||
run_direct_ self task)
|
||||
in
|
||||
run_direct_ self task'
|
||||
run_async_rec_ task
|
||||
|
||||
let run = run_async
|
||||
|
||||
[@@@ifge 5.0]
|
||||
|
||||
(* DLA interop *)
|
||||
let prepare_for_await () : Dla_.t =
|
||||
(* current state *)
|
||||
let st :
|
||||
((with_handler:bool -> task -> unit) * Suspend_.suspension) option A.t =
|
||||
A.make None
|
||||
in
|
||||
|
||||
let release () : unit =
|
||||
match A.exchange st None with
|
||||
| None -> ()
|
||||
| Some (run, k) -> run ~with_handler:true (fun () -> k (Ok ()))
|
||||
and await () : unit =
|
||||
Suspend_.suspend
|
||||
{ Suspend_.handle = (fun ~run k -> A.set st (Some (run, k))) }
|
||||
in
|
||||
|
||||
let t = { Dla_.release; await } in
|
||||
t
|
||||
|
||||
[@@@else_]
|
||||
|
||||
let prepare_for_await () = { Dla_.release = ignore; await = ignore }
|
||||
|
||||
[@@@endif]
|
||||
|
||||
exception Got_task of task
|
||||
|
||||
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||
|
||||
exception Closed
|
||||
|
||||
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task
|
||||
~(offset : int) : unit =
|
||||
let num_qs = Array.length self.qs in
|
||||
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||
~around_task : unit =
|
||||
let (AT_pair (before_task, after_task)) = around_task in
|
||||
|
||||
(* try to get a task that is already in one of the queues.
|
||||
@param force_lock if true, we force acquisition of the queue's mutex,
|
||||
which is slower but always succeeds in getting a task if there is one. *)
|
||||
let get_task_already_in_queues ~force_lock () : _ option =
|
||||
try
|
||||
for _retry = 1 to 3 do
|
||||
for i = 0 to num_qs - 1 do
|
||||
let q = self.qs.((offset + i) mod num_qs) in
|
||||
match TS_queue.try_pop ~force_lock q with
|
||||
| Some f -> raise_notrace (Got_task f)
|
||||
| None -> ()
|
||||
done
|
||||
done;
|
||||
None
|
||||
with Got_task f ->
|
||||
A.decr self.num_tasks;
|
||||
Some f
|
||||
in
|
||||
|
||||
(* slow path: force locking when trying to get tasks,
|
||||
and wait on [self.cond] if no task is currently available. *)
|
||||
let pop_blocking () : task =
|
||||
try
|
||||
while A.get self.active do
|
||||
match get_task_already_in_queues ~force_lock:true () with
|
||||
| Some t -> raise_notrace (Got_task t)
|
||||
| None ->
|
||||
Mutex.lock self.mutex;
|
||||
(* NOTE: be careful about race conditions: we must only
|
||||
block if the [shutdown] that sets [active] to [false]
|
||||
has not broadcast over this condition first. Otherwise
|
||||
we might miss the signal and wait here forever. *)
|
||||
if A.get self.active then Condition.wait self.cond self.mutex;
|
||||
Mutex.unlock self.mutex
|
||||
done;
|
||||
raise Closed
|
||||
with Got_task t -> t
|
||||
in
|
||||
|
||||
(* run this task. *)
|
||||
let run_task task : unit =
|
||||
let _ctx = before_task runner in
|
||||
(* run the task now, catching errors *)
|
||||
|
|
@ -209,51 +81,69 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task
|
|||
after_task runner _ctx
|
||||
in
|
||||
|
||||
(* drain the queues from existing tasks. If [force_lock=false]
|
||||
then it is best effort. *)
|
||||
let run_tasks_already_present ~force_lock () =
|
||||
let run_self_tasks_ () =
|
||||
let continue = ref true in
|
||||
let pop_retries = ref 0 in
|
||||
while !continue do
|
||||
match get_task_already_in_queues ~force_lock () with
|
||||
| None -> continue := false
|
||||
| Some task -> run_task task
|
||||
match WSQ.pop w.q with
|
||||
| Some task ->
|
||||
pop_retries := 0;
|
||||
run_task task
|
||||
| None ->
|
||||
incr pop_retries;
|
||||
if !pop_retries > 10 then continue := false
|
||||
done
|
||||
in
|
||||
|
||||
(* get a task from another worker *)
|
||||
let try_to_steal_work () : task option =
|
||||
try
|
||||
for _retry = 1 to 3 do
|
||||
Array.iter
|
||||
(fun w' ->
|
||||
if w != w' then (
|
||||
match WSQ.steal w'.q with
|
||||
| None -> ()
|
||||
| Some task -> raise_notrace (Got_task task)
|
||||
))
|
||||
self.workers
|
||||
done;
|
||||
None
|
||||
with Got_task task -> Some task
|
||||
in
|
||||
|
||||
let main_loop () =
|
||||
while A.get self.active do
|
||||
run_tasks_already_present ~force_lock:false ();
|
||||
let steal_attempts = ref 0 in
|
||||
while true do
|
||||
run_self_tasks_ ();
|
||||
|
||||
(* no task available, block until one comes *)
|
||||
match pop_blocking () with
|
||||
| exception Closed -> ()
|
||||
| task -> run_task task
|
||||
done;
|
||||
match try_to_steal_work () with
|
||||
| Some task ->
|
||||
steal_attempts := 0;
|
||||
run_task task
|
||||
| None ->
|
||||
incr steal_attempts;
|
||||
Domain_.relax ();
|
||||
|
||||
(* cleanup *)
|
||||
run_tasks_already_present ~force_lock:true ()
|
||||
if !steal_attempts > 10 then (
|
||||
steal_attempts := 0;
|
||||
let task = Bb_queue.pop self.main_q in
|
||||
run_task task
|
||||
)
|
||||
done
|
||||
in
|
||||
|
||||
try
|
||||
(* handle domain-local await *)
|
||||
Dla_.using ~prepare_for_await ~while_running:main_loop
|
||||
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await
|
||||
~while_running:main_loop
|
||||
with Bb_queue.Closed -> ()
|
||||
|
||||
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
|
||||
|
||||
(** We want a reasonable number of queues. Even if your system is
|
||||
a beast with hundreds of cores, trying
|
||||
to work-steal through hundreds of queues will have a cost.
|
||||
|
||||
Hence, we limit the number of queues to at most 32 (number picked
|
||||
via the ancestral technique of the pifomètre). *)
|
||||
let max_queues = 32
|
||||
|
||||
let shutdown_ ~wait (self : state) : unit =
|
||||
let was_active = A.exchange self.active false in
|
||||
(* wake up the subset of [self.threads] that are waiting on new tasks *)
|
||||
if was_active then awake_workers_ self;
|
||||
if wait then Array.iter Thread.join self.threads
|
||||
Bb_queue.close self.main_q;
|
||||
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
|
||||
|
||||
type ('a, 'b) create_args =
|
||||
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
|
||||
|
|
@ -286,24 +176,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
|
||||
let offset = Random.int num_domains in
|
||||
|
||||
let active = A.make true in
|
||||
let qs =
|
||||
let num_qs = min (min num_domains num_threads) max_queues in
|
||||
Array.init num_qs (fun _ -> TS_queue.create ())
|
||||
let workers : worker_state array =
|
||||
let dummy = Thread.self () in
|
||||
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
|
||||
in
|
||||
|
||||
let pool =
|
||||
let dummy = Thread.self () in
|
||||
{
|
||||
active;
|
||||
threads = Array.make num_threads dummy;
|
||||
num_tasks = A.make 0;
|
||||
qs;
|
||||
mutex = Mutex.create ();
|
||||
cond = Condition.create ();
|
||||
cur_q = A.make 0;
|
||||
}
|
||||
in
|
||||
let pool = { workers; main_q = Bb_queue.create () } in
|
||||
|
||||
let runner =
|
||||
Runner.For_runner_implementors.create
|
||||
|
|
@ -320,6 +198,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
|
||||
(* start the thread with index [i] *)
|
||||
let start_thread_with_idx i =
|
||||
let w = pool.workers.(i) in
|
||||
let dom_idx = (offset + i) mod num_domains in
|
||||
|
||||
(* function run in the thread itself *)
|
||||
|
|
@ -328,17 +207,13 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
let t_id = Thread.id thread in
|
||||
on_init_thread ~dom_id:dom_idx ~t_id ();
|
||||
|
||||
let all_wrappers =
|
||||
List.rev_append thread_wrappers (A.get global_thread_wrappers_)
|
||||
in
|
||||
|
||||
let run () = worker_thread_ pool runner ~on_exn ~around_task ~offset:i in
|
||||
let run () = worker_thread_ pool runner w ~on_exn ~around_task in
|
||||
(* the actual worker loop is [worker_thread_], with all
|
||||
wrappers for this pool and for all pools (global_thread_wrappers_) *)
|
||||
let run' =
|
||||
List.fold_left
|
||||
(fun run f -> f ~thread ~pool:runner run)
|
||||
run all_wrappers
|
||||
run thread_wrappers
|
||||
in
|
||||
|
||||
(* now run the main loop *)
|
||||
|
|
@ -368,7 +243,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
(* receive the newly created threads back from domains *)
|
||||
for _j = 1 to num_threads do
|
||||
let i, th = Bb_queue.pop receive_threads in
|
||||
pool.threads.(i) <- th
|
||||
pool.workers.(i).thread <- th
|
||||
done;
|
||||
|
||||
runner
|
||||
|
|
|
|||
|
|
@ -23,11 +23,6 @@ type thread_loop_wrapper =
|
|||
By default it just returns the same loop function but it can be used
|
||||
to install tracing, effect handlers, etc. *)
|
||||
|
||||
val add_global_thread_loop_wrapper : thread_loop_wrapper -> unit
|
||||
(** [add_global_thread_loop_wrapper f] registers [f] to be installed in every new pool worker
|
||||
thread, for all existing pools, and all new pools created with [create].
|
||||
These wrappers accumulate: they all apply, but their order is not specified. *)
|
||||
|
||||
type ('a, 'b) create_args =
|
||||
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
|
||||
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue