ws pool: use a non-atomic boolean to reduce the number of wakeups; refactor

This commit is contained in:
Simon Cruanes 2023-10-27 14:48:13 -04:00
parent 359ec0352b
commit b4ddd82ee8
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -7,21 +7,24 @@ let ( let@ ) = ( @@ )
type worker_state = { type worker_state = {
mutable thread: Thread.t; mutable thread: Thread.t;
q: task WSQ.t; (** Work stealing queue *) q: task WSQ.t; (** Work stealing queue *)
mutable work_steal_offset: int; (** Current offset for work stealing *)
} }
(** State for a given worker. Only this worker is (** State for a given worker. Only this worker is
allowed to push into the queue, but other workers allowed to push into the queue, but other workers
can come and steal from it if they're idle. *) can come and steal from it if they're idle. *)
type mut_cond = { type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
mutex: Mutex.t;
cond: Condition.t;
}
type state = { type state = {
active: bool A.t; (** Becomes [false] when the pool is shutdown. *) active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; (** Fixed set of workers. *) workers: worker_state array; (** Fixed set of workers. *)
main_q: task Queue.t; (** Main queue for tasks coming from the outside *) main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
mc: mut_cond; (** Used to block on [main_q] *) mutable n_waiting: int; (* protected by mutex *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
mutex: Mutex.t;
cond: Condition.t;
on_exn: exn -> Printexc.raw_backtrace -> unit;
around_task: around_task;
} }
(** internal state *) (** internal state *)
@ -29,14 +32,13 @@ let[@inline] size_ (self : state) = Array.length self.workers
let num_tasks_ (self : state) : int = let num_tasks_ (self : state) : int =
let n = ref 0 in let n = ref 0 in
Mutex.lock self.mc.mutex;
n := Queue.length self.main_q; n := Queue.length self.main_q;
Mutex.unlock self.mc.mutex;
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
!n !n
exception Got_worker of worker_state exception Got_worker of worker_state
(* FIXME: replace with TLS *)
let[@inline] find_current_worker_ (self : state) : worker_state option = let[@inline] find_current_worker_ (self : state) : worker_state option =
let self_id = Thread.id @@ Thread.self () in let self_id = Thread.id @@ Thread.self () in
try try
@ -48,159 +50,130 @@ let[@inline] find_current_worker_ (self : state) : worker_state option =
None None
with Got_worker w -> Some w with Got_worker w -> Some w
(** Wake up sleeping workers, if the pool believes there are any.

    The [n_waiting_nonzero] flag is read {b without} taking the mutex, so
    this is only a cheap hint: when the flag is unset nothing is done (no
    lock, no signal). When it is set, the mutex is taken and the condition
    is broadcast, i.e. every thread currently blocked on [self.cond] is
    woken, not just one. *)
let[@inline] try_wake_someone_ (self : state) : unit =
  match self.n_waiting_nonzero with
  | false -> ()
  | true ->
    Mutex.lock self.mutex;
    Condition.broadcast self.cond;
    Mutex.unlock self.mutex
(** Run [task] as is, on the pool. *) (** Run [task] as is, on the pool. *)
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
=
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
match w with match w with
| Some w -> | Some w ->
WSQ.push w.q task; WSQ.push w.q task;
try_wake_someone_ self
(* see if we need to wakeup other workers to come and steal from us *)
Mutex.lock self.mc.mutex;
if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond;
Mutex.unlock self.mc.mutex
| None -> | None ->
if A.get self.active then ( if A.get self.active then (
(* push into the main queue *) (* push into the main queue *)
Mutex.lock self.mc.mutex; Mutex.lock self.mutex;
let was_empty = Queue.is_empty self.main_q in
Queue.push task self.main_q; Queue.push task self.main_q;
if was_empty then Condition.broadcast self.mc.cond; if self.n_waiting_nonzero then Condition.broadcast self.cond;
Mutex.unlock self.mc.mutex Mutex.unlock self.mutex
) else ) else
(* notify the caller that scheduling tasks is no (* notify the caller that scheduling tasks is no
longer permitted *) longer permitted *)
raise Shutdown raise Shutdown
let run_async_ (self : state) (task : task) : unit = (** Run this task, now. Must be called from a worker. *)
(* run [task] inside a suspension handler *) let run_task_now_ (self : state) ~runner task : unit =
let rec run_async_in_suspend_rec_ (task : task) = (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let task_with_suspend_ () = let (AT_pair (before_task, after_task)) = self.around_task in
(* run [f()] and handle [suspend] in it *) let _ctx = before_task runner in
Suspend_.with_suspend task ~run:(fun ~with_handler task' -> (* run the task now, catching errors *)
if with_handler then (try
run_async_in_suspend_rec_ task' (* run [task()] and handle [suspend] in it *)
else ( Suspend_.with_suspend task ~run:(fun task' ->
let w = find_current_worker_ self in let w = find_current_worker_ self in
run_direct_ self w task' schedule_task_ self w task')
)) with e ->
in let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt);
after_task runner _ctx
(* schedule on current worker, if run from a worker *) let[@inline] run_async_ (self : state) (task : task) : unit =
let w = find_current_worker_ self in let w = find_current_worker_ self in
run_direct_ self w task_with_suspend_ schedule_task_ self w task
in
run_async_in_suspend_rec_ task (* TODO: function to schedule many tasks from the outside.
- build a queue
- lock
- queue transfer
- wakeup all (broadcast)
- unlock *)
let run = run_async let run = run_async
exception Got_task of task (** Wait on condition. Precondition: we hold the mutex. *)
let[@inline] wait_ (self : state) : unit =
  (* register ourselves as a waiter; the first waiter raises the flag *)
  let waiting_before = self.n_waiting in
  self.n_waiting <- waiting_before + 1;
  if waiting_before = 0 then self.n_waiting_nonzero <- true;
  (* block on the condition variable associated with [self.mutex] *)
  Condition.wait self.cond self.mutex;
  (* unregister; the last waiter to leave lowers the flag *)
  let waiting_after = self.n_waiting - 1 in
  self.n_waiting <- waiting_after;
  if waiting_after = 0 then self.n_waiting_nonzero <- false
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task (** Try to steal a task from the worker [w] *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers;
(** How many times in a row do we try to do work-stealing? *) (* if we're pointing to [w], skip to the next worker as
let steal_attempt_max_retry = 2 it's useless to steal from oneself *)
if Array.unsafe_get self.workers w.work_steal_offset == w then
w.work_steal_offset <-
(w.work_steal_offset + 1) mod Array.length self.workers;
let w' = Array.unsafe_get self.workers w.work_steal_offset in
WSQ.steal w'.q
(** Attempt to steal one task from the other workers and run it.

    For a pool of [n] workers, at most [n - 1] steal attempts are made
    (one call to [try_to_steal_work_once_] per attempt). The first task
    successfully stolen is run immediately via [run_task_now_], and the
    search stops there.

    @return [true] if a task was stolen and run, [false] if every attempt
    failed or if the pool has a single worker (nobody to steal from). *)
let try_to_steal_work_loop (self : state) ~runner w : bool =
  let n_workers = size_ self in
  (* [attempt k] performs up to [k] more steal attempts *)
  let rec attempt remaining =
    if remaining <= 0 then
      false
    else (
      match try_to_steal_work_once_ self w with
      | Some task ->
        run_task_now_ self ~runner task;
        true
      | None -> attempt (remaining - 1)
    )
  in
  if n_workers = 1 then
    (* no stealing for single thread pool *)
    false
  else
    attempt (n_workers - 1)
(** Drain the worker's own queue: pop tasks from [w.q] and run each one
    with [run_task_now_] until the queue yields nothing.

    The pool's [active] flag is checked before each pop, so the loop also
    stops (possibly leaving tasks in the queue) once the pool has been
    marked inactive. *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
  let rec drain () =
    if A.get self.active then (
      match WSQ.pop w.q with
      | Some task ->
        run_task_now_ self ~runner task;
        drain ()
      | None -> ()
    )
  in
  drain ()
(** Main loop for a worker thread. *) (** Main loop for a worker thread. *)
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
~around_task : unit = let main_loop () : unit =
let (AT_pair (before_task, after_task)) = around_task in
(* run this task. *)
let run_task task : unit =
let _ctx = before_task runner in
(* run the task now, catching errors *)
(try task ()
with e ->
let bt = Printexc.get_raw_backtrace () in
on_exn e bt);
after_task runner _ctx
in
let run_self_tasks_ () =
let continue = ref true in let continue = ref true in
while !continue do while !continue && A.get self.active do
match WSQ.pop w.q with worker_run_self_tasks_ self ~runner w;
| Some task -> run_task task
| None -> continue := false
done
in
let work_steal_offset = ref 0 in let did_steal = try_to_steal_work_loop self ~runner w in
(* get a task from another worker *)
let try_to_steal_work () : task option =
assert (size_ self > 1);
work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers;
(* if we're pointing to [w], skip to the next worker as
it's useless to steal from oneself *)
if self.workers.(!work_steal_offset) == w then
work_steal_offset :=
(!work_steal_offset + 1) mod Array.length self.workers;
let w' = self.workers.(!work_steal_offset) in
assert (w != w');
WSQ.steal w'.q
in
(* try to steal work multiple times *)
let try_to_steal_work_loop () : bool =
if size_ self = 1 then
(* no stealing for single thread pool *)
false
else (
try
let unsuccessful_steal_attempts = ref 0 in
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
match try_to_steal_work () with
| Some task ->
run_task task;
raise_notrace Exit
| None -> incr unsuccessful_steal_attempts
(* Domain_.relax () *)
done;
false
with Exit -> true
)
in
let get_task_from_main_queue_block () : task option =
try
Mutex.lock self.mc.mutex;
while true do
match Queue.pop self.main_q with
| exception Queue.Empty ->
if A.get self.active then
Condition.wait self.mc.cond self.mc.mutex
else (
(* empty queue and we're closed, time to exit *)
Mutex.unlock self.mc.mutex;
raise_notrace Exit
)
| task ->
Mutex.unlock self.mc.mutex;
raise_notrace (Got_task task)
done;
(* unreachable *)
assert false
with
| Got_task t -> Some t
| Exit -> None
in
let main_loop () =
let continue = ref true in
while !continue do
run_self_tasks_ ();
let did_steal = try_to_steal_work_loop () in
if not did_steal then ( if not did_steal then (
match get_task_from_main_queue_block () with Mutex.lock self.mutex;
| None -> match Queue.pop self.main_q with
(* main queue is closed *) | task ->
continue := false Mutex.unlock self.mutex;
| Some task -> run_task task run_task_now_ self ~runner task
| exception Queue.Empty ->
wait_ self;
Mutex.unlock self.mutex
) )
done; done;
assert (WSQ.size w.q = 0) assert (WSQ.size w.q = 0)
@ -214,9 +187,9 @@ let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
let shutdown_ ~wait (self : state) : unit = let shutdown_ ~wait (self : state) : unit =
if A.exchange self.active false then ( if A.exchange self.active false then (
Mutex.lock self.mc.mutex; Mutex.lock self.mutex;
Condition.broadcast self.mc.cond; Condition.broadcast self.cond;
Mutex.unlock self.mc.mutex; Mutex.unlock self.mutex;
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
) )
@ -251,7 +224,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
let workers : worker_state array = let workers : worker_state array =
let dummy = Thread.self () in let dummy = Thread.self () in
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () }) Array.init num_threads (fun i ->
{
thread = dummy;
q = WSQ.create ();
work_steal_offset = (i + 1) mod num_threads;
})
in in
let pool = let pool =
@ -259,7 +237,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
active = A.make true; active = A.make true;
workers; workers;
main_q = Queue.create (); main_q = Queue.create ();
mc = { mutex = Mutex.create (); cond = Condition.create () }; n_waiting = 0;
n_waiting_nonzero = true;
mutex = Mutex.create ();
cond = Condition.create ();
around_task;
on_exn;
} }
in in
@ -287,7 +270,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let t_id = Thread.id thread in let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id (); on_init_thread ~dom_id:dom_idx ~t_id ();
let run () = worker_thread_ pool runner w ~on_exn ~around_task in let run () = worker_thread_ pool ~runner w in
(* now run the main loop *) (* now run the main loop *)
Fun.protect run ~finally:(fun () -> Fun.protect run ~finally:(fun () ->