fix ws_pool: no work stealing for pools of 1 worker

there would be a loop because it'd try to find the index of another
worker to steal from, but loop forever because there is no other worker.
This commit is contained in:
Simon Cruanes 2023-10-25 22:33:08 -04:00
parent d9da7844e2
commit 3956fb6566
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -1,5 +1,13 @@
module WSQ = Ws_deque_ module WSQ = Ws_deque_
module A = Atomic_ module A = Atomic_
module Int_tbl = Hashtbl.Make (struct
type t = int
let equal : t -> t -> bool = ( = )
let hash : t -> int = Hashtbl.hash
end)
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
@ -11,6 +19,9 @@ type worker_state = {
mutable thread: Thread.t; mutable thread: Thread.t;
q: task WSQ.t; (** Work stealing queue *) q: task WSQ.t; (** Work stealing queue *)
} }
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
type mut_cond = { type mut_cond = {
mutex: Mutex.t; mutex: Mutex.t;
@ -18,35 +29,27 @@ type mut_cond = {
} }
type state = { type state = {
active: bool Atomic.t; active: bool Atomic.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; workers: worker_state array; (** Fixed set of workers. *)
main_q: task Queue.t; (** Main queue to block on *) worker_by_id: worker_state Int_tbl.t;
mc: mut_cond; main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
mc: mut_cond; (** Used to block on [main_q] *)
} }
(** internal state *) (** internal state *)
let[@inline] size_ (self : state) = Array.length self.workers let[@inline] size_ (self : state) = Array.length self.workers
let num_tasks_ (self : state) : int = let num_tasks_ (self : state) : int =
let n = ref 0 in
Mutex.lock self.mc.mutex; Mutex.lock self.mc.mutex;
let n = ref (Queue.length self.main_q) in n := Queue.length self.main_q;
Mutex.unlock self.mc.mutex; Mutex.unlock self.mc.mutex;
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
!n !n
exception Got_worker of worker_state let[@inline] find_current_worker_ (self : state) : worker_state option =
exception Closed = Bb_queue.Closed
let find_current_worker_ (self : state) : worker_state option =
let self_id = Thread.id @@ Thread.self () in let self_id = Thread.id @@ Thread.self () in
try Int_tbl.find_opt self.worker_by_id self_id
(* see if we're in one of the worker threads *)
for i = 0 to Array.length self.workers - 1 do
let w = self.workers.(i) in
if Thread.id w.thread = self_id then raise_notrace (Got_worker w)
done;
None
with Got_worker w -> Some w
(** Run [task] as is, on the pool. *) (** Run [task] as is, on the pool. *)
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
@ -133,14 +136,20 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
let work_steal_offset = ref 0 in let work_steal_offset = ref 0 in
(* get a task from another worker *) (* get a task from another worker *)
let rec try_to_steal_work () : task option = let try_to_steal_work () : task option =
let i = !work_steal_offset in assert (size_ self > 1);
work_steal_offset := (i + 1) mod Array.length self.workers;
let w' = self.workers.(i) in work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers;
if w == w' then
try_to_steal_work () (* if we're pointing to [w], skip to the next worker as
else it's useless to steal from oneself *)
WSQ.steal w'.q if self.workers.(!work_steal_offset) == w then
work_steal_offset :=
(!work_steal_offset + 1) mod Array.length self.workers;
let w' = self.workers.(!work_steal_offset) in
assert (w != w');
WSQ.steal w'.q
in in
(* (*
@ -161,48 +170,65 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
(* try to steal work multiple times *) (* try to steal work multiple times *)
let try_to_steal_work_loop () : bool = let try_to_steal_work_loop () : bool =
try if size_ self = 1 then
let unsuccessful_steal_attempts = ref 0 in (* no stealing for single thread pool *)
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
match try_to_steal_work () with
| Some task ->
run_task task;
raise_notrace Exit
| None ->
incr unsuccessful_steal_attempts;
Domain_.relax ()
done;
false false
with Exit -> true else (
try
let unsuccessful_steal_attempts = ref 0 in
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
match try_to_steal_work () with
| Some task ->
run_task task;
raise_notrace Exit
| None ->
incr unsuccessful_steal_attempts;
Domain_.relax ()
done;
false
with Exit -> true
)
in in
let get_task_from_main_queue_block () : task = let get_task_from_main_queue_block () : task option =
try try
Mutex.lock self.mc.mutex; Mutex.lock self.mc.mutex;
while A.get self.active do while true do
match Queue.pop self.main_q with match Queue.pop self.main_q with
| exception Queue.Empty -> Condition.wait self.mc.cond self.mc.mutex | exception Queue.Empty ->
if A.get self.active then
Condition.wait self.mc.cond self.mc.mutex
else (
(* empty queue and we're closed, time to exit *)
Mutex.unlock self.mc.mutex;
raise_notrace Exit
)
| task -> | task ->
Mutex.unlock self.mc.mutex; Mutex.unlock self.mc.mutex;
raise_notrace (Got_task task) raise_notrace (Got_task task)
done; done;
Mutex.unlock self.mc.mutex; (* unreachable *)
raise Shutdown assert false
with Got_task t -> t with
| Got_task t -> Some t
| Exit -> None
in in
let main_loop () = let main_loop () =
(try let continue = ref true in
while true do while !continue do
run_self_tasks_ (); run_self_tasks_ ();
if not (try_to_steal_work_loop ()) then ( let did_steal = try_to_steal_work_loop () in
let task = get_task_from_main_queue_block () in if not did_steal then (
run_task task match get_task_from_main_queue_block () with
) | None ->
done (* main queue is closed *)
with Shutdown -> ()); continue := false
run_self_tasks_ () | Some task -> run_task task
)
done;
assert (WSQ.size w.q = 0)
in in
(* handle domain-local await *) (* handle domain-local await *)
@ -259,6 +285,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
{ {
active = A.make true; active = A.make true;
workers; workers;
worker_by_id = Int_tbl.create 8;
main_q = Queue.create (); main_q = Queue.create ();
mc = { mutex = Mutex.create (); cond = Condition.create () }; mc = { mutex = Mutex.create (); cond = Condition.create () };
} }
@ -324,7 +351,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* receive the newly created threads back from domains *) (* receive the newly created threads back from domains *)
for _j = 1 to num_threads do for _j = 1 to num_threads do
let i, th = Bb_queue.pop receive_threads in let i, th = Bb_queue.pop receive_threads in
pool.workers.(i).thread <- th let worker_state = pool.workers.(i) in
worker_state.thread <- th;
Mutex.lock pool.mc.mutex;
Int_tbl.add pool.worker_by_id (Thread.id th) worker_state;
Mutex.unlock pool.mc.mutex
done; done;
runner runner