ws pool: use non atomic boolean to reduce number of wakeups; refactor

This commit is contained in:
Simon Cruanes 2023-10-27 14:48:13 -04:00
parent 359ec0352b
commit b4ddd82ee8
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -7,21 +7,24 @@ let ( let@ ) = ( @@ )
type worker_state = {
mutable thread: Thread.t;
q: task WSQ.t; (** Work stealing queue *)
mutable work_steal_offset: int; (** Current offset for work stealing *)
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
type mut_cond = {
mutex: Mutex.t;
cond: Condition.t;
}
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type state = {
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; (** Fixed set of workers. *)
main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
mc: mut_cond; (** Used to block on [main_q] *)
mutable n_waiting: int; (* protected by mutex *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
mutex: Mutex.t;
cond: Condition.t;
on_exn: exn -> Printexc.raw_backtrace -> unit;
around_task: around_task;
}
(** internal state *)
@ -29,14 +32,13 @@ let[@inline] size_ (self : state) = Array.length self.workers
let num_tasks_ (self : state) : int =
let n = ref 0 in
Mutex.lock self.mc.mutex;
n := Queue.length self.main_q;
Mutex.unlock self.mc.mutex;
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
!n
exception Got_worker of worker_state
(* FIXME: replace with TLS *)
let[@inline] find_current_worker_ (self : state) : worker_state option =
let self_id = Thread.id @@ Thread.self () in
try
@ -48,159 +50,130 @@ let[@inline] find_current_worker_ (self : state) : worker_state option =
None
with Got_worker w -> Some w
(** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit =
if self.n_waiting_nonzero then (
Mutex.lock self.mutex;
Condition.broadcast self.cond;
Mutex.unlock self.mutex
)
(** Run [task] as is, on the pool. *)
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
=
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
match w with
| Some w ->
WSQ.push w.q task;
(* see if we need to wakeup other workers to come and steal from us *)
Mutex.lock self.mc.mutex;
if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex
try_wake_someone_ self
| None ->
if A.get self.active then (
(* push into the main queue *)
Mutex.lock self.mc.mutex;
let was_empty = Queue.is_empty self.main_q in
Mutex.lock self.mutex;
Queue.push task self.main_q;
if was_empty then Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex
if self.n_waiting_nonzero then Condition.broadcast self.cond;
Mutex.unlock self.mutex
) else
(* notify the caller that scheduling tasks is no
longer permitted *)
raise Shutdown
let run_async_ (self : state) (task : task) : unit =
(* run [task] inside a suspension handler *)
let rec run_async_in_suspend_rec_ (task : task) =
let task_with_suspend_ () =
(* run [f()] and handle [suspend] in it *)
Suspend_.with_suspend task ~run:(fun ~with_handler task' ->
if with_handler then
run_async_in_suspend_rec_ task'
else (
let w = find_current_worker_ self in
run_direct_ self w task'
))
in
(** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner task : unit =
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let (AT_pair (before_task, after_task)) = self.around_task in
let _ctx = before_task runner in
(* run the task now, catching errors *)
(try
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~run:(fun task' ->
let w = find_current_worker_ self in
schedule_task_ self w task')
with e ->
let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt);
after_task runner _ctx
(* schedule on current worker, if run from a worker *)
let w = find_current_worker_ self in
run_direct_ self w task_with_suspend_
in
run_async_in_suspend_rec_ task
let[@inline] run_async_ (self : state) (task : task) : unit =
let w = find_current_worker_ self in
schedule_task_ self w task
(* TODO: function to schedule many tasks from the outside.
- build a queue
- lock
- queue transfer
- wakeup all (broadcast)
- unlock *)
let run = run_async
exception Got_task of task
(** Wait on condition. Precondition: we hold the mutex. *)
let[@inline] wait_ (self : state) : unit =
self.n_waiting <- self.n_waiting + 1;
if self.n_waiting = 1 then self.n_waiting_nonzero <- true;
Condition.wait self.cond self.mutex;
self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
(** Try to steal a task from the worker [w] *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers;
(** How many times in a row do we try to do work-stealing? *)
let steal_attempt_max_retry = 2
(* if we're pointing to [w], skip to the next worker as
it's useless to steal from oneself *)
if Array.unsafe_get self.workers w.work_steal_offset == w then
w.work_steal_offset <-
(w.work_steal_offset + 1) mod Array.length self.workers;
let w' = Array.unsafe_get self.workers w.work_steal_offset in
WSQ.steal w'.q
(** Try to steal work from several other workers. *)
let try_to_steal_work_loop (self : state) ~runner w : bool =
if size_ self = 1 then
(* no stealing for single thread pool *)
false
else (
let has_stolen = ref false in
let n_retries_left = ref (size_ self - 1) in
while !n_retries_left > 0 do
match try_to_steal_work_once_ self w with
| Some task ->
run_task_now_ self ~runner task;
has_stolen := true;
n_retries_left := 0
| None -> decr n_retries_left
done;
!has_stolen
)
(** Worker runs tasks from its queue until none remains *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
let continue = ref true in
while !continue && A.get self.active do
match WSQ.pop w.q with
| Some task -> run_task_now_ self ~runner task
| None -> continue := false
done
(** Main loop for a worker thread. *)
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
~around_task : unit =
let (AT_pair (before_task, after_task)) = around_task in
(* run this task. *)
let run_task task : unit =
let _ctx = before_task runner in
(* run the task now, catching errors *)
(try task ()
with e ->
let bt = Printexc.get_raw_backtrace () in
on_exn e bt);
after_task runner _ctx
in
let run_self_tasks_ () =
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
let main_loop () : unit =
let continue = ref true in
while !continue do
match WSQ.pop w.q with
| Some task -> run_task task
| None -> continue := false
done
in
while !continue && A.get self.active do
worker_run_self_tasks_ self ~runner w;
let work_steal_offset = ref 0 in
(* get a task from another worker *)
let try_to_steal_work () : task option =
assert (size_ self > 1);
work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers;
(* if we're pointing to [w], skip to the next worker as
it's useless to steal from oneself *)
if self.workers.(!work_steal_offset) == w then
work_steal_offset :=
(!work_steal_offset + 1) mod Array.length self.workers;
let w' = self.workers.(!work_steal_offset) in
assert (w != w');
WSQ.steal w'.q
in
(* try to steal work multiple times *)
let try_to_steal_work_loop () : bool =
if size_ self = 1 then
(* no stealing for single thread pool *)
false
else (
try
let unsuccessful_steal_attempts = ref 0 in
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
match try_to_steal_work () with
| Some task ->
run_task task;
raise_notrace Exit
| None -> incr unsuccessful_steal_attempts
(* Domain_.relax () *)
done;
false
with Exit -> true
)
in
let get_task_from_main_queue_block () : task option =
try
Mutex.lock self.mc.mutex;
while true do
match Queue.pop self.main_q with
| exception Queue.Empty ->
if A.get self.active then
Condition.wait self.mc.cond self.mc.mutex
else (
(* empty queue and we're closed, time to exit *)
Mutex.unlock self.mc.mutex;
raise_notrace Exit
)
| task ->
Mutex.unlock self.mc.mutex;
raise_notrace (Got_task task)
done;
(* unreachable *)
assert false
with
| Got_task t -> Some t
| Exit -> None
in
let main_loop () =
let continue = ref true in
while !continue do
run_self_tasks_ ();
let did_steal = try_to_steal_work_loop () in
let did_steal = try_to_steal_work_loop self ~runner w in
if not did_steal then (
match get_task_from_main_queue_block () with
| None ->
(* main queue is closed *)
continue := false
| Some task -> run_task task
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task_now_ self ~runner task
| exception Queue.Empty ->
wait_ self;
Mutex.unlock self.mutex
)
done;
assert (WSQ.size w.q = 0)
@ -214,9 +187,9 @@ let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
let shutdown_ ~wait (self : state) : unit =
if A.exchange self.active false then (
Mutex.lock self.mc.mutex;
Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex;
Mutex.lock self.mutex;
Condition.broadcast self.cond;
Mutex.unlock self.mutex;
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
)
@ -251,7 +224,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
let workers : worker_state array =
let dummy = Thread.self () in
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
Array.init num_threads (fun i ->
{
thread = dummy;
q = WSQ.create ();
work_steal_offset = (i + 1) mod num_threads;
})
in
let pool =
@ -259,7 +237,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
active = A.make true;
workers;
main_q = Queue.create ();
mc = { mutex = Mutex.create (); cond = Condition.create () };
n_waiting = 0;
n_waiting_nonzero = true;
mutex = Mutex.create ();
cond = Condition.create ();
around_task;
on_exn;
}
in
@ -287,7 +270,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
let run () = worker_thread_ pool runner w ~on_exn ~around_task in
let run () = worker_thread_ pool ~runner w in
(* now run the main loop *)
Fun.protect run ~finally:(fun () ->