wip: better work stealing pool

This commit is contained in:
Simon Cruanes 2023-10-25 12:11:41 -04:00
parent e0d3a18562
commit db33bec13f
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 84 additions and 30 deletions

View file

@ -1,4 +1,5 @@
module WSQ = Ws_deque_ module WSQ = Ws_deque_
module A = Atomic_
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
@ -11,20 +12,30 @@ type worker_state = {
q: task WSQ.t; (** Work stealing queue *) q: task WSQ.t; (** Work stealing queue *)
} }
type mut_cond = {
mutex: Mutex.t;
cond: Condition.t;
}
type state = { type state = {
active: bool Atomic.t;
workers: worker_state array; workers: worker_state array;
main_q: task Bb_queue.t; (** Main queue to block on *) main_q: task Queue.t; (** Main queue to block on *)
mc: mut_cond;
} }
(** internal state *) (** internal state *)
let[@inline] size_ (self : state) = Array.length self.workers let[@inline] size_ (self : state) = Array.length self.workers
let num_tasks_ (self : state) : int = let num_tasks_ (self : state) : int =
let n = ref (Bb_queue.size self.main_q) in Mutex.lock self.mc.mutex;
let n = ref (Queue.length self.main_q) in
Mutex.unlock self.mc.mutex;
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
!n !n
exception Got_worker of worker_state exception Got_worker of worker_state
exception Closed = Bb_queue.Closed
let find_current_worker_ (self : state) : worker_state option = let find_current_worker_ (self : state) : worker_state option =
let self_id = Thread.id @@ Thread.self () in let self_id = Thread.id @@ Thread.self () in
@ -41,11 +52,22 @@ let find_current_worker_ (self : state) : worker_state option =
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
match w with match w with
| Some w -> | Some w ->
print_endline "push local"; WSQ.push w.q task;
WSQ.push w.q task
(* see if we need to wakeup other workers to come and steal from us *)
Mutex.lock self.mc.mutex;
if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex
| None -> | None ->
print_endline "push blocking"; if A.get self.active then (
Bb_queue.push self.main_q task (* push into the main queue *)
Mutex.lock self.mc.mutex;
let was_empty = Queue.is_empty self.main_q in
Queue.push task self.main_q;
if was_empty then Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex
) else
raise Bb_queue.Closed
let run_async_ (self : state) (task : task) : unit = let run_async_ (self : state) (task : task) : unit =
(* stay on current worker if possible *) (* stay on current worker if possible *)
@ -74,7 +96,7 @@ type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
let run_self_task_max_retry = 5 let run_self_task_max_retry = 5
(** How many times in a row do we try to do work-stealing? *) (** How many times in a row do we try to do work-stealing? *)
let steal_attempt_max_retry = 5 let steal_attempt_max_retry = 7
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
~around_task : unit = ~around_task : unit =
@ -92,7 +114,6 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
in in
let run_self_tasks_ () = let run_self_tasks_ () =
print_endline "run self tasks";
let continue = ref true in let continue = ref true in
let pop_retries = ref 0 in let pop_retries = ref 0 in
while !continue do while !continue do
@ -107,23 +128,34 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
done done
in in
let work_steal_offset = ref 0 in
(* get a task from another worker *) (* get a task from another worker *)
let try_to_steal_work () : task option = let rec try_to_steal_work () : task option =
print_endline "try to steal work"; let i = !work_steal_offset in
work_steal_offset := (i + 1) mod Array.length self.workers;
let w' = self.workers.(i) in
if w == w' then
try_to_steal_work ()
else
WSQ.steal w'.q
in
(*
try try
for _retry = 1 to 3 do for _retry = 1 to 1 do
Array.iter for i = 0 to Array.length self.workers - 1 do
(fun w' -> let w' = self.workers.(i) in
if w != w' then ( if w != w' then (
match WSQ.steal w'.q with match WSQ.steal w'.q with
| None -> () | None -> ()
| Some task -> raise_notrace (Got_task task) | Some task -> raise_notrace (Got_task task)
)) )
self.workers done
done; done;
None None
with Got_task task -> Some task with Got_task task -> Some task
in *)
(* try to steal work multiple times *) (* try to steal work multiple times *)
let try_to_steal_work_loop () : bool = let try_to_steal_work_loop () : bool =
@ -142,19 +174,28 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
with Exit -> true with Exit -> true
in in
let get_task_from_main_queue_block () : task =
try
Mutex.lock self.mc.mutex;
while A.get self.active do
match Queue.pop self.main_q with
| exception Queue.Empty -> Condition.wait self.mc.cond self.mc.mutex
| task ->
Mutex.unlock self.mc.mutex;
raise_notrace (Got_task task)
done;
Mutex.unlock self.mc.mutex;
raise Bb_queue.Closed
with Got_task t -> t
in
let main_loop () = let main_loop () =
(try (try
while true do while true do
run_self_tasks_ (); run_self_tasks_ ();
if not (try_to_steal_work_loop ()) then ( if not (try_to_steal_work_loop ()) then (
Array.iteri let task = get_task_from_main_queue_block () in
(fun i w -> Printf.printf "w[%d].q.size=%d\n" i (WSQ.size w.q))
self.workers;
Printf.printf "bq.size=%d\n%!" (Bb_queue.size self.main_q);
print_endline "wait block";
let task = Bb_queue.pop self.main_q in
run_task task run_task task
) )
done done
@ -169,8 +210,12 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
let shutdown_ ~wait (self : state) : unit = let shutdown_ ~wait (self : state) : unit =
Bb_queue.close self.main_q; if A.exchange self.active false then (
Mutex.lock self.mc.mutex;
Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex;
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
)
type ('a, 'b) create_args = type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
@ -208,7 +253,14 @@ let create ?(on_init_thread = default_thread_init_exit_)
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () }) Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
in in
let pool = { workers; main_q = Bb_queue.create () } in let pool =
{
active = A.make true;
workers;
main_q = Queue.create ();
mc = { mutex = Mutex.create (); cond = Condition.create () };
}
in
let runner = let runner =
Runner.For_runner_implementors.create Runner.For_runner_implementors.create

View file

@ -34,6 +34,8 @@ type ('a, 'b) create_args =
'a 'a
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)
exception Closed
val create : (unit -> t, _) create_args val create : (unit -> t, _) create_args
(** [create ()] makes a new thread pool. (** [create ()] makes a new thread pool.
@param on_init_thread called at the beginning of each new thread @param on_init_thread called at the beginning of each new thread