ws pool: random stealing; rework main state machine

In the worker state machine, after waiting we first check the main queue;
if it is empty, we go directly to stealing.
This commit is contained in:
Simon Cruanes 2023-10-27 16:05:52 -04:00
parent aa7906eb2c
commit aba0d84ecf
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -8,7 +8,7 @@ let ( let@ ) = ( @@ )
type worker_state = {
mutable thread: Thread.t;
q: task WSQ.t; (** Work stealing queue *)
mutable work_steal_offset: int; (** Current offset for work stealing *)
rng: Random.State.t; (** Random state used to pick the index at which a steal attempt starts *)
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
@ -111,39 +111,26 @@ let[@inline] wait_ (self : state) : unit =
self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false
(** Try to steal a task from the worker [w] *)
(* Raised inside the scan loop of [try_to_steal_work_once_] to leave the
   [for] loop early while carrying the stolen task; it is caught by the
   [with Got_task t] handler of that same function. *)
exception Got_task of task
(** Try to steal a task *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers;
(* the scan over [self.workers] starts at a random index drawn from the
   worker's own [rng] *)
let init = Random.State.int w.rng (Array.length self.workers) in
(* if we're pointing to [w], skip to the next worker as
it's useless to steal from oneself *)
if Array.unsafe_get self.workers w.work_steal_offset == w then
w.work_steal_offset <-
(w.work_steal_offset + 1) mod Array.length self.workers;
try
for i = 0 to Array.length self.workers - 1 do
let w' =
Array.unsafe_get self.workers ((i + init) mod Array.length self.workers)
in
let w' = Array.unsafe_get self.workers w.work_steal_offset in
WSQ.steal w'.q
(** Try to steal work from several other workers. *)
let try_to_steal_work_loop (self : state) ~runner w : bool =
if size_ self = 1 then
(* no stealing for single thread pool *)
false
else (
let has_stolen = ref false in
let n_retries_left = ref (size_ self - 1) in
while !n_retries_left > 0 do
match try_to_steal_work_once_ self w with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner task;
has_stolen := true;
n_retries_left := 0
| None -> decr n_retries_left
done;
!has_stolen
(* physical inequality: only probe queues of workers other than [w] *)
if w != w' then (
match WSQ.steal w'.q with
| Some t -> raise_notrace (Got_task t)
| None -> ()
)
done;
None
with Got_task t -> Some t
(** Worker runs tasks from its queue until none remains *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
@ -160,29 +147,41 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
(** Body run by a worker: stores [w] in the [k_worker_state] TLS slot, then
    runs the scheduling loop through [Dla_.using]. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.get k_worker_state := Some w;
let main_loop () : unit =
let continue = ref true in
while !continue && A.get self.active do
(* State machine made of mutually recursive functions:
   [main] runs the worker's own tasks and then calls [try_steal];
   [try_steal] makes one stealing attempt and calls [wait] on failure;
   [wait] takes a task from [self.main_q] or blocks in [wait_];
   [run_task] runs one task and goes back to [main]. *)
let rec main () : unit =
if A.get self.active then (
worker_run_self_tasks_ self ~runner w;
let did_steal = try_to_steal_work_loop self ~runner w in
if not did_steal then (
try_steal ()
)
and run_task task : unit =
run_task_now_ self ~runner task;
main ()
and try_steal () =
if A.get self.active then (
match try_to_steal_work_once_ self w with
| Some task -> run_task task
| None -> wait ()
)
(* [self.main_q] is only accessed between [Mutex.lock self.mutex] and
   [Mutex.unlock self.mutex]; the mutex is released before running a task *)
and wait () =
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task_now_ self ~runner task
run_task task
| exception Queue.Empty ->
(* wait here *)
if A.get self.active then wait_ self;
Mutex.unlock self.mutex
)
done;
assert (WSQ.size w.q = 0)
(* see if a task became available *)
let task = try Some (Queue.pop self.main_q) with Queue.Empty -> None in
Mutex.unlock self.mutex;
(* main queue still empty after waiting: go directly to stealing *)
(match task with
| Some t -> run_task t
| None -> try_steal ())
in
(* handle domain-local await *)
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await
~while_running:main_loop
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main
(** No-op callback that ignores [dom_id] and [t_id]; used as the default
    value of [on_init_thread] in [create]. *)
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
@ -226,11 +225,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let workers : worker_state array =
let dummy = Thread.self () in
Array.init num_threads (fun i ->
{
thread = dummy;
q = WSQ.create ();
work_steal_offset = (i + 1) mod num_threads;
})
{ thread = dummy; q = WSQ.create (); rng = Random.State.make [| i |] })
in
let pool =