mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-07 11:45:39 -05:00
fix ws_pool: no work stealing for pools of 1 worker
there would be a loop because it'd try to find the index of another worker to steal from, but loop forever because there is no other worker.
This commit is contained in:
parent
d9da7844e2
commit
3956fb6566
1 changed files with 86 additions and 54 deletions
140
src/pool.ml
140
src/pool.ml
|
|
@ -1,5 +1,13 @@
|
||||||
module WSQ = Ws_deque_
|
module WSQ = Ws_deque_
|
||||||
module A = Atomic_
|
module A = Atomic_
|
||||||
|
|
||||||
|
module Int_tbl = Hashtbl.Make (struct
|
||||||
|
type t = int
|
||||||
|
|
||||||
|
let equal : t -> t -> bool = ( = )
|
||||||
|
let hash : t -> int = Hashtbl.hash
|
||||||
|
end)
|
||||||
|
|
||||||
include Runner
|
include Runner
|
||||||
|
|
||||||
let ( let@ ) = ( @@ )
|
let ( let@ ) = ( @@ )
|
||||||
|
|
@ -11,6 +19,9 @@ type worker_state = {
|
||||||
mutable thread: Thread.t;
|
mutable thread: Thread.t;
|
||||||
q: task WSQ.t; (** Work stealing queue *)
|
q: task WSQ.t; (** Work stealing queue *)
|
||||||
}
|
}
|
||||||
|
(** State for a given worker. Only this worker is
|
||||||
|
allowed to push into the queue, but other workers
|
||||||
|
can come and steal from it if they're idle. *)
|
||||||
|
|
||||||
type mut_cond = {
|
type mut_cond = {
|
||||||
mutex: Mutex.t;
|
mutex: Mutex.t;
|
||||||
|
|
@ -18,35 +29,27 @@ type mut_cond = {
|
||||||
}
|
}
|
||||||
|
|
||||||
type state = {
|
type state = {
|
||||||
active: bool Atomic.t;
|
active: bool Atomic.t; (** Becomes [false] when the pool is shutdown. *)
|
||||||
workers: worker_state array;
|
workers: worker_state array; (** Fixed set of workers. *)
|
||||||
main_q: task Queue.t; (** Main queue to block on *)
|
worker_by_id: worker_state Int_tbl.t;
|
||||||
mc: mut_cond;
|
main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
|
||||||
|
mc: mut_cond; (** Used to block on [main_q] *)
|
||||||
}
|
}
|
||||||
(** internal state *)
|
(** internal state *)
|
||||||
|
|
||||||
let[@inline] size_ (self : state) = Array.length self.workers
|
let[@inline] size_ (self : state) = Array.length self.workers
|
||||||
|
|
||||||
let num_tasks_ (self : state) : int =
|
let num_tasks_ (self : state) : int =
|
||||||
|
let n = ref 0 in
|
||||||
Mutex.lock self.mc.mutex;
|
Mutex.lock self.mc.mutex;
|
||||||
let n = ref (Queue.length self.main_q) in
|
n := Queue.length self.main_q;
|
||||||
Mutex.unlock self.mc.mutex;
|
Mutex.unlock self.mc.mutex;
|
||||||
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
||||||
!n
|
!n
|
||||||
|
|
||||||
exception Got_worker of worker_state
|
let[@inline] find_current_worker_ (self : state) : worker_state option =
|
||||||
exception Closed = Bb_queue.Closed
|
|
||||||
|
|
||||||
let find_current_worker_ (self : state) : worker_state option =
|
|
||||||
let self_id = Thread.id @@ Thread.self () in
|
let self_id = Thread.id @@ Thread.self () in
|
||||||
try
|
Int_tbl.find_opt self.worker_by_id self_id
|
||||||
(* see if we're in one of the worker threads *)
|
|
||||||
for i = 0 to Array.length self.workers - 1 do
|
|
||||||
let w = self.workers.(i) in
|
|
||||||
if Thread.id w.thread = self_id then raise_notrace (Got_worker w)
|
|
||||||
done;
|
|
||||||
None
|
|
||||||
with Got_worker w -> Some w
|
|
||||||
|
|
||||||
(** Run [task] as is, on the pool. *)
|
(** Run [task] as is, on the pool. *)
|
||||||
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
||||||
|
|
@ -133,14 +136,20 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
let work_steal_offset = ref 0 in
|
let work_steal_offset = ref 0 in
|
||||||
|
|
||||||
(* get a task from another worker *)
|
(* get a task from another worker *)
|
||||||
let rec try_to_steal_work () : task option =
|
let try_to_steal_work () : task option =
|
||||||
let i = !work_steal_offset in
|
assert (size_ self > 1);
|
||||||
work_steal_offset := (i + 1) mod Array.length self.workers;
|
|
||||||
let w' = self.workers.(i) in
|
work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers;
|
||||||
if w == w' then
|
|
||||||
try_to_steal_work ()
|
(* if we're pointing to [w], skip to the next worker as
|
||||||
else
|
it's useless to steal from oneself *)
|
||||||
WSQ.steal w'.q
|
if self.workers.(!work_steal_offset) == w then
|
||||||
|
work_steal_offset :=
|
||||||
|
(!work_steal_offset + 1) mod Array.length self.workers;
|
||||||
|
|
||||||
|
let w' = self.workers.(!work_steal_offset) in
|
||||||
|
assert (w != w');
|
||||||
|
WSQ.steal w'.q
|
||||||
in
|
in
|
||||||
|
|
||||||
(*
|
(*
|
||||||
|
|
@ -161,48 +170,65 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
|
|
||||||
(* try to steal work multiple times *)
|
(* try to steal work multiple times *)
|
||||||
let try_to_steal_work_loop () : bool =
|
let try_to_steal_work_loop () : bool =
|
||||||
try
|
if size_ self = 1 then
|
||||||
let unsuccessful_steal_attempts = ref 0 in
|
(* no stealing for single thread pool *)
|
||||||
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
|
|
||||||
match try_to_steal_work () with
|
|
||||||
| Some task ->
|
|
||||||
run_task task;
|
|
||||||
raise_notrace Exit
|
|
||||||
| None ->
|
|
||||||
incr unsuccessful_steal_attempts;
|
|
||||||
Domain_.relax ()
|
|
||||||
done;
|
|
||||||
false
|
false
|
||||||
with Exit -> true
|
else (
|
||||||
|
try
|
||||||
|
let unsuccessful_steal_attempts = ref 0 in
|
||||||
|
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
|
||||||
|
match try_to_steal_work () with
|
||||||
|
| Some task ->
|
||||||
|
run_task task;
|
||||||
|
raise_notrace Exit
|
||||||
|
| None ->
|
||||||
|
incr unsuccessful_steal_attempts;
|
||||||
|
Domain_.relax ()
|
||||||
|
done;
|
||||||
|
false
|
||||||
|
with Exit -> true
|
||||||
|
)
|
||||||
in
|
in
|
||||||
|
|
||||||
let get_task_from_main_queue_block () : task =
|
let get_task_from_main_queue_block () : task option =
|
||||||
try
|
try
|
||||||
Mutex.lock self.mc.mutex;
|
Mutex.lock self.mc.mutex;
|
||||||
while A.get self.active do
|
while true do
|
||||||
match Queue.pop self.main_q with
|
match Queue.pop self.main_q with
|
||||||
| exception Queue.Empty -> Condition.wait self.mc.cond self.mc.mutex
|
| exception Queue.Empty ->
|
||||||
|
if A.get self.active then
|
||||||
|
Condition.wait self.mc.cond self.mc.mutex
|
||||||
|
else (
|
||||||
|
(* empty queue and we're closed, time to exit *)
|
||||||
|
Mutex.unlock self.mc.mutex;
|
||||||
|
raise_notrace Exit
|
||||||
|
)
|
||||||
| task ->
|
| task ->
|
||||||
Mutex.unlock self.mc.mutex;
|
Mutex.unlock self.mc.mutex;
|
||||||
raise_notrace (Got_task task)
|
raise_notrace (Got_task task)
|
||||||
done;
|
done;
|
||||||
Mutex.unlock self.mc.mutex;
|
(* unreachable *)
|
||||||
raise Shutdown
|
assert false
|
||||||
with Got_task t -> t
|
with
|
||||||
|
| Got_task t -> Some t
|
||||||
|
| Exit -> None
|
||||||
in
|
in
|
||||||
|
|
||||||
let main_loop () =
|
let main_loop () =
|
||||||
(try
|
let continue = ref true in
|
||||||
while true do
|
while !continue do
|
||||||
run_self_tasks_ ();
|
run_self_tasks_ ();
|
||||||
|
|
||||||
if not (try_to_steal_work_loop ()) then (
|
let did_steal = try_to_steal_work_loop () in
|
||||||
let task = get_task_from_main_queue_block () in
|
if not did_steal then (
|
||||||
run_task task
|
match get_task_from_main_queue_block () with
|
||||||
)
|
| None ->
|
||||||
done
|
(* main queue is closed *)
|
||||||
with Shutdown -> ());
|
continue := false
|
||||||
run_self_tasks_ ()
|
| Some task -> run_task task
|
||||||
|
)
|
||||||
|
done;
|
||||||
|
assert (WSQ.size w.q = 0)
|
||||||
in
|
in
|
||||||
|
|
||||||
(* handle domain-local await *)
|
(* handle domain-local await *)
|
||||||
|
|
@ -259,6 +285,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
{
|
{
|
||||||
active = A.make true;
|
active = A.make true;
|
||||||
workers;
|
workers;
|
||||||
|
worker_by_id = Int_tbl.create 8;
|
||||||
main_q = Queue.create ();
|
main_q = Queue.create ();
|
||||||
mc = { mutex = Mutex.create (); cond = Condition.create () };
|
mc = { mutex = Mutex.create (); cond = Condition.create () };
|
||||||
}
|
}
|
||||||
|
|
@ -324,7 +351,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
(* receive the newly created threads back from domains *)
|
(* receive the newly created threads back from domains *)
|
||||||
for _j = 1 to num_threads do
|
for _j = 1 to num_threads do
|
||||||
let i, th = Bb_queue.pop receive_threads in
|
let i, th = Bb_queue.pop receive_threads in
|
||||||
pool.workers.(i).thread <- th
|
let worker_state = pool.workers.(i) in
|
||||||
|
worker_state.thread <- th;
|
||||||
|
|
||||||
|
Mutex.lock pool.mc.mutex;
|
||||||
|
Int_tbl.add pool.worker_by_id (Thread.id th) worker_state;
|
||||||
|
Mutex.unlock pool.mc.mutex
|
||||||
done;
|
done;
|
||||||
|
|
||||||
runner
|
runner
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue