ws pool: random stealing; rework main state machine

in the state machine, after waiting, we check the main queue, else we
directly go to stealing.
This commit is contained in:
Simon Cruanes 2023-10-27 16:05:52 -04:00
parent aa7906eb2c
commit aba0d84ecf
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -8,7 +8,7 @@ let ( let@ ) = ( @@ )
type worker_state = { type worker_state = {
mutable thread: Thread.t; mutable thread: Thread.t;
q: task WSQ.t; (** Work stealing queue *) q: task WSQ.t; (** Work stealing queue *)
mutable work_steal_offset: int; (** Current offset for work stealing *) rng: Random.State.t;
} }
(** State for a given worker. Only this worker is (** State for a given worker. Only this worker is
allowed to push into the queue, but other workers allowed to push into the queue, but other workers
@ -111,39 +111,26 @@ let[@inline] wait_ (self : state) : unit =
self.n_waiting <- self.n_waiting - 1; self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false if self.n_waiting = 0 then self.n_waiting_nonzero <- false
(** Try to steal a task from the worker [w] *) exception Got_task of task
(** Try to steal a task *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option = let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers; let init = Random.State.int w.rng (Array.length self.workers) in
(* if we're pointing to [w], skip to the next worker as try
it's useless to steal from oneself *) for i = 0 to Array.length self.workers - 1 do
if Array.unsafe_get self.workers w.work_steal_offset == w then let w' =
w.work_steal_offset <- Array.unsafe_get self.workers ((i + init) mod Array.length self.workers)
(w.work_steal_offset + 1) mod Array.length self.workers; in
let w' = Array.unsafe_get self.workers w.work_steal_offset in if w != w' then (
WSQ.steal w'.q match WSQ.steal w'.q with
| Some t -> raise_notrace (Got_task t)
(** Try to steal work from several other workers. *) | None -> ()
let try_to_steal_work_loop (self : state) ~runner w : bool = )
if size_ self = 1 then
(* no stealing for single thread pool *)
false
else (
let has_stolen = ref false in
let n_retries_left = ref (size_ self - 1) in
while !n_retries_left > 0 do
match try_to_steal_work_once_ self w with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner task;
has_stolen := true;
n_retries_left := 0
| None -> decr n_retries_left
done; done;
!has_stolen None
) with Got_task t -> Some t
(** Worker runs tasks from its queue until none remains *) (** Worker runs tasks from its queue until none remains *)
let worker_run_self_tasks_ (self : state) ~runner w : unit = let worker_run_self_tasks_ (self : state) ~runner w : unit =
@ -160,29 +147,41 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.get k_worker_state := Some w; TLS.get k_worker_state := Some w;
let main_loop () : unit = let rec main () : unit =
let continue = ref true in if A.get self.active then (
while !continue && A.get self.active do
worker_run_self_tasks_ self ~runner w; worker_run_self_tasks_ self ~runner w;
try_steal ()
)
and run_task task : unit =
run_task_now_ self ~runner task;
main ()
and try_steal () =
if A.get self.active then (
match try_to_steal_work_once_ self w with
| Some task -> run_task task
| None -> wait ()
)
and wait () =
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task task
| exception Queue.Empty ->
(* wait here *)
if A.get self.active then wait_ self;
let did_steal = try_to_steal_work_loop self ~runner w in (* see if a task became available *)
if not did_steal then ( let task = try Some (Queue.pop self.main_q) with Queue.Empty -> None in
Mutex.lock self.mutex; Mutex.unlock self.mutex;
match Queue.pop self.main_q with
| task -> (match task with
Mutex.unlock self.mutex; | Some t -> run_task t
run_task_now_ self ~runner task | None -> try_steal ())
| exception Queue.Empty ->
if A.get self.active then wait_ self;
Mutex.unlock self.mutex
)
done;
assert (WSQ.size w.q = 0)
in in
(* handle domain-local await *) (* handle domain-local await *)
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main
~while_running:main_loop
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
@ -226,11 +225,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let workers : worker_state array = let workers : worker_state array =
let dummy = Thread.self () in let dummy = Thread.self () in
Array.init num_threads (fun i -> Array.init num_threads (fun i ->
{ { thread = dummy; q = WSQ.create (); rng = Random.State.make [| i |] })
thread = dummy;
q = WSQ.create ();
work_steal_offset = (i + 1) mod num_threads;
})
in in
let pool = let pool =