mirror of
https://github.com/c-cube/moonpool.git
synced 2026-01-24 02:06:42 -05:00
ws pool: use non atomic boolean to reduce number of wakeups; refactor
This commit is contained in:
parent
359ec0352b
commit
b4ddd82ee8
1 changed files with 125 additions and 142 deletions
267
src/ws_pool.ml
267
src/ws_pool.ml
|
|
@ -7,21 +7,24 @@ let ( let@ ) = ( @@ )
|
|||
type worker_state = {
|
||||
mutable thread: Thread.t;
|
||||
q: task WSQ.t; (** Work stealing queue *)
|
||||
mutable work_steal_offset: int; (** Current offset for work stealing *)
|
||||
}
|
||||
(** State for a given worker. Only this worker is
|
||||
allowed to push into the queue, but other workers
|
||||
can come and steal from it if they're idle. *)
|
||||
|
||||
type mut_cond = {
|
||||
mutex: Mutex.t;
|
||||
cond: Condition.t;
|
||||
}
|
||||
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||
|
||||
type state = {
|
||||
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
|
||||
workers: worker_state array; (** Fixed set of workers. *)
|
||||
main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
|
||||
mc: mut_cond; (** Used to block on [main_q] *)
|
||||
mutable n_waiting: int; (* protected by mutex *)
|
||||
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
|
||||
mutex: Mutex.t;
|
||||
cond: Condition.t;
|
||||
on_exn: exn -> Printexc.raw_backtrace -> unit;
|
||||
around_task: around_task;
|
||||
}
|
||||
(** internal state *)
|
||||
|
||||
|
|
@ -29,14 +32,13 @@ let[@inline] size_ (self : state) = Array.length self.workers
|
|||
|
||||
let num_tasks_ (self : state) : int =
|
||||
let n = ref 0 in
|
||||
Mutex.lock self.mc.mutex;
|
||||
n := Queue.length self.main_q;
|
||||
Mutex.unlock self.mc.mutex;
|
||||
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
||||
!n
|
||||
|
||||
exception Got_worker of worker_state
|
||||
|
||||
(* FIXME: replace with TLS *)
|
||||
let[@inline] find_current_worker_ (self : state) : worker_state option =
|
||||
let self_id = Thread.id @@ Thread.self () in
|
||||
try
|
||||
|
|
@ -48,159 +50,130 @@ let[@inline] find_current_worker_ (self : state) : worker_state option =
|
|||
None
|
||||
with Got_worker w -> Some w
|
||||
|
||||
(** Try to wake up a waiter, if there's any. *)
|
||||
let[@inline] try_wake_someone_ (self : state) : unit =
|
||||
if self.n_waiting_nonzero then (
|
||||
Mutex.lock self.mutex;
|
||||
Condition.broadcast self.cond;
|
||||
Mutex.unlock self.mutex
|
||||
)
|
||||
|
||||
(** Run [task] as is, on the pool. *)
|
||||
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
||||
let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
|
||||
=
|
||||
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
|
||||
match w with
|
||||
| Some w ->
|
||||
WSQ.push w.q task;
|
||||
|
||||
(* see if we need to wakeup other workers to come and steal from us *)
|
||||
Mutex.lock self.mc.mutex;
|
||||
if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond;
|
||||
Mutex.unlock self.mc.mutex
|
||||
try_wake_someone_ self
|
||||
| None ->
|
||||
if A.get self.active then (
|
||||
(* push into the main queue *)
|
||||
Mutex.lock self.mc.mutex;
|
||||
let was_empty = Queue.is_empty self.main_q in
|
||||
Mutex.lock self.mutex;
|
||||
Queue.push task self.main_q;
|
||||
if was_empty then Condition.broadcast self.mc.cond;
|
||||
Mutex.unlock self.mc.mutex
|
||||
if self.n_waiting_nonzero then Condition.broadcast self.cond;
|
||||
Mutex.unlock self.mutex
|
||||
) else
|
||||
(* notify the caller that scheduling tasks is no
|
||||
longer permitted *)
|
||||
raise Shutdown
|
||||
|
||||
let run_async_ (self : state) (task : task) : unit =
|
||||
(* run [task] inside a suspension handler *)
|
||||
let rec run_async_in_suspend_rec_ (task : task) =
|
||||
let task_with_suspend_ () =
|
||||
(* run [f()] and handle [suspend] in it *)
|
||||
Suspend_.with_suspend task ~run:(fun ~with_handler task' ->
|
||||
if with_handler then
|
||||
run_async_in_suspend_rec_ task'
|
||||
else (
|
||||
let w = find_current_worker_ self in
|
||||
run_direct_ self w task'
|
||||
))
|
||||
in
|
||||
(** Run this task, now. Must be called from a worker. *)
|
||||
let run_task_now_ (self : state) ~runner task : unit =
|
||||
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
|
||||
let (AT_pair (before_task, after_task)) = self.around_task in
|
||||
let _ctx = before_task runner in
|
||||
(* run the task now, catching errors *)
|
||||
(try
|
||||
(* run [task()] and handle [suspend] in it *)
|
||||
Suspend_.with_suspend task ~run:(fun task' ->
|
||||
let w = find_current_worker_ self in
|
||||
schedule_task_ self w task')
|
||||
with e ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
self.on_exn e bt);
|
||||
after_task runner _ctx
|
||||
|
||||
(* schedule on current worker, if run from a worker *)
|
||||
let w = find_current_worker_ self in
|
||||
run_direct_ self w task_with_suspend_
|
||||
in
|
||||
run_async_in_suspend_rec_ task
|
||||
let[@inline] run_async_ (self : state) (task : task) : unit =
|
||||
let w = find_current_worker_ self in
|
||||
schedule_task_ self w task
|
||||
|
||||
(* TODO: function to schedule many tasks from the outside.
|
||||
- build a queue
|
||||
- lock
|
||||
- queue transfer
|
||||
- wakeup all (broadcast)
|
||||
- unlock *)
|
||||
|
||||
let run = run_async
|
||||
|
||||
exception Got_task of task
|
||||
(** Wait on condition. Precondition: we hold the mutex. *)
|
||||
let[@inline] wait_ (self : state) : unit =
|
||||
self.n_waiting <- self.n_waiting + 1;
|
||||
if self.n_waiting = 1 then self.n_waiting_nonzero <- true;
|
||||
Condition.wait self.cond self.mutex;
|
||||
self.n_waiting <- self.n_waiting - 1;
|
||||
if self.n_waiting = 0 then self.n_waiting_nonzero <- false
|
||||
|
||||
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||
(** Try to steal a task from the worker [w] *)
|
||||
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
|
||||
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers;
|
||||
|
||||
(** How many times in a row do we try to do work-stealing? *)
|
||||
let steal_attempt_max_retry = 2
|
||||
(* if we're pointing to [w], skip to the next worker as
|
||||
it's useless to steal from oneself *)
|
||||
if Array.unsafe_get self.workers w.work_steal_offset == w then
|
||||
w.work_steal_offset <-
|
||||
(w.work_steal_offset + 1) mod Array.length self.workers;
|
||||
|
||||
let w' = Array.unsafe_get self.workers w.work_steal_offset in
|
||||
WSQ.steal w'.q
|
||||
|
||||
(** Try to steal work from several other workers. *)
|
||||
let try_to_steal_work_loop (self : state) ~runner w : bool =
|
||||
if size_ self = 1 then
|
||||
(* no stealing for single thread pool *)
|
||||
false
|
||||
else (
|
||||
let has_stolen = ref false in
|
||||
let n_retries_left = ref (size_ self - 1) in
|
||||
|
||||
while !n_retries_left > 0 do
|
||||
match try_to_steal_work_once_ self w with
|
||||
| Some task ->
|
||||
run_task_now_ self ~runner task;
|
||||
has_stolen := true;
|
||||
n_retries_left := 0
|
||||
| None -> decr n_retries_left
|
||||
done;
|
||||
!has_stolen
|
||||
)
|
||||
|
||||
(** Worker runs tasks from its queue until none remains *)
|
||||
let worker_run_self_tasks_ (self : state) ~runner w : unit =
|
||||
let continue = ref true in
|
||||
while !continue && A.get self.active do
|
||||
match WSQ.pop w.q with
|
||||
| Some task -> run_task_now_ self ~runner task
|
||||
| None -> continue := false
|
||||
done
|
||||
|
||||
(** Main loop for a worker thread. *)
|
||||
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||
~around_task : unit =
|
||||
let (AT_pair (before_task, after_task)) = around_task in
|
||||
|
||||
(* run this task. *)
|
||||
let run_task task : unit =
|
||||
let _ctx = before_task runner in
|
||||
(* run the task now, catching errors *)
|
||||
(try task ()
|
||||
with e ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
on_exn e bt);
|
||||
after_task runner _ctx
|
||||
in
|
||||
|
||||
let run_self_tasks_ () =
|
||||
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
||||
let main_loop () : unit =
|
||||
let continue = ref true in
|
||||
while !continue do
|
||||
match WSQ.pop w.q with
|
||||
| Some task -> run_task task
|
||||
| None -> continue := false
|
||||
done
|
||||
in
|
||||
while !continue && A.get self.active do
|
||||
worker_run_self_tasks_ self ~runner w;
|
||||
|
||||
let work_steal_offset = ref 0 in
|
||||
|
||||
(* get a task from another worker *)
|
||||
let try_to_steal_work () : task option =
|
||||
assert (size_ self > 1);
|
||||
|
||||
work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers;
|
||||
|
||||
(* if we're pointing to [w], skip to the next worker as
|
||||
it's useless to steal from oneself *)
|
||||
if self.workers.(!work_steal_offset) == w then
|
||||
work_steal_offset :=
|
||||
(!work_steal_offset + 1) mod Array.length self.workers;
|
||||
|
||||
let w' = self.workers.(!work_steal_offset) in
|
||||
assert (w != w');
|
||||
WSQ.steal w'.q
|
||||
in
|
||||
|
||||
(* try to steal work multiple times *)
|
||||
let try_to_steal_work_loop () : bool =
|
||||
if size_ self = 1 then
|
||||
(* no stealing for single thread pool *)
|
||||
false
|
||||
else (
|
||||
try
|
||||
let unsuccessful_steal_attempts = ref 0 in
|
||||
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
|
||||
match try_to_steal_work () with
|
||||
| Some task ->
|
||||
run_task task;
|
||||
raise_notrace Exit
|
||||
| None -> incr unsuccessful_steal_attempts
|
||||
(* Domain_.relax () *)
|
||||
done;
|
||||
false
|
||||
with Exit -> true
|
||||
)
|
||||
in
|
||||
|
||||
let get_task_from_main_queue_block () : task option =
|
||||
try
|
||||
Mutex.lock self.mc.mutex;
|
||||
while true do
|
||||
match Queue.pop self.main_q with
|
||||
| exception Queue.Empty ->
|
||||
if A.get self.active then
|
||||
Condition.wait self.mc.cond self.mc.mutex
|
||||
else (
|
||||
(* empty queue and we're closed, time to exit *)
|
||||
Mutex.unlock self.mc.mutex;
|
||||
raise_notrace Exit
|
||||
)
|
||||
| task ->
|
||||
Mutex.unlock self.mc.mutex;
|
||||
raise_notrace (Got_task task)
|
||||
done;
|
||||
(* unreachable *)
|
||||
assert false
|
||||
with
|
||||
| Got_task t -> Some t
|
||||
| Exit -> None
|
||||
in
|
||||
|
||||
let main_loop () =
|
||||
let continue = ref true in
|
||||
while !continue do
|
||||
run_self_tasks_ ();
|
||||
|
||||
let did_steal = try_to_steal_work_loop () in
|
||||
let did_steal = try_to_steal_work_loop self ~runner w in
|
||||
if not did_steal then (
|
||||
match get_task_from_main_queue_block () with
|
||||
| None ->
|
||||
(* main queue is closed *)
|
||||
continue := false
|
||||
| Some task -> run_task task
|
||||
Mutex.lock self.mutex;
|
||||
match Queue.pop self.main_q with
|
||||
| task ->
|
||||
Mutex.unlock self.mutex;
|
||||
run_task_now_ self ~runner task
|
||||
| exception Queue.Empty ->
|
||||
wait_ self;
|
||||
Mutex.unlock self.mutex
|
||||
)
|
||||
done;
|
||||
assert (WSQ.size w.q = 0)
|
||||
|
|
@ -214,9 +187,9 @@ let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
|
|||
|
||||
let shutdown_ ~wait (self : state) : unit =
|
||||
if A.exchange self.active false then (
|
||||
Mutex.lock self.mc.mutex;
|
||||
Condition.broadcast self.mc.cond;
|
||||
Mutex.unlock self.mc.mutex;
|
||||
Mutex.lock self.mutex;
|
||||
Condition.broadcast self.cond;
|
||||
Mutex.unlock self.mutex;
|
||||
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
|
||||
)
|
||||
|
||||
|
|
@ -251,7 +224,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
|
||||
let workers : worker_state array =
|
||||
let dummy = Thread.self () in
|
||||
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
|
||||
Array.init num_threads (fun i ->
|
||||
{
|
||||
thread = dummy;
|
||||
q = WSQ.create ();
|
||||
work_steal_offset = (i + 1) mod num_threads;
|
||||
})
|
||||
in
|
||||
|
||||
let pool =
|
||||
|
|
@ -259,7 +237,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
active = A.make true;
|
||||
workers;
|
||||
main_q = Queue.create ();
|
||||
mc = { mutex = Mutex.create (); cond = Condition.create () };
|
||||
n_waiting = 0;
|
||||
n_waiting_nonzero = true;
|
||||
mutex = Mutex.create ();
|
||||
cond = Condition.create ();
|
||||
around_task;
|
||||
on_exn;
|
||||
}
|
||||
in
|
||||
|
||||
|
|
@ -287,7 +270,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
let t_id = Thread.id thread in
|
||||
on_init_thread ~dom_id:dom_idx ~t_id ();
|
||||
|
||||
let run () = worker_thread_ pool runner w ~on_exn ~around_task in
|
||||
let run () = worker_thread_ pool ~runner w in
|
||||
|
||||
(* now run the main loop *)
|
||||
Fun.protect run ~finally:(fun () ->
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue