mirror of
https://github.com/c-cube/moonpool.git
synced 2026-01-27 03:34:50 -05:00
ws pool: use non atomic boolean to reduce number of wakeups; refactor
This commit is contained in:
parent
359ec0352b
commit
b4ddd82ee8
1 changed files with 125 additions and 142 deletions
267
src/ws_pool.ml
267
src/ws_pool.ml
|
|
@ -7,21 +7,24 @@ let ( let@ ) = ( @@ )
|
||||||
type worker_state = {
|
type worker_state = {
|
||||||
mutable thread: Thread.t;
|
mutable thread: Thread.t;
|
||||||
q: task WSQ.t; (** Work stealing queue *)
|
q: task WSQ.t; (** Work stealing queue *)
|
||||||
|
mutable work_steal_offset: int; (** Current offset for work stealing *)
|
||||||
}
|
}
|
||||||
(** State for a given worker. Only this worker is
|
(** State for a given worker. Only this worker is
|
||||||
allowed to push into the queue, but other workers
|
allowed to push into the queue, but other workers
|
||||||
can come and steal from it if they're idle. *)
|
can come and steal from it if they're idle. *)
|
||||||
|
|
||||||
type mut_cond = {
|
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||||
mutex: Mutex.t;
|
|
||||||
cond: Condition.t;
|
|
||||||
}
|
|
||||||
|
|
||||||
type state = {
|
type state = {
|
||||||
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
|
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
|
||||||
workers: worker_state array; (** Fixed set of workers. *)
|
workers: worker_state array; (** Fixed set of workers. *)
|
||||||
main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
|
main_q: task Queue.t; (** Main queue for tasks coming from the outside *)
|
||||||
mc: mut_cond; (** Used to block on [main_q] *)
|
mutable n_waiting: int; (* protected by mutex *)
|
||||||
|
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
|
||||||
|
mutex: Mutex.t;
|
||||||
|
cond: Condition.t;
|
||||||
|
on_exn: exn -> Printexc.raw_backtrace -> unit;
|
||||||
|
around_task: around_task;
|
||||||
}
|
}
|
||||||
(** internal state *)
|
(** internal state *)
|
||||||
|
|
||||||
|
|
@ -29,14 +32,13 @@ let[@inline] size_ (self : state) = Array.length self.workers
|
||||||
|
|
||||||
let num_tasks_ (self : state) : int =
|
let num_tasks_ (self : state) : int =
|
||||||
let n = ref 0 in
|
let n = ref 0 in
|
||||||
Mutex.lock self.mc.mutex;
|
|
||||||
n := Queue.length self.main_q;
|
n := Queue.length self.main_q;
|
||||||
Mutex.unlock self.mc.mutex;
|
|
||||||
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
||||||
!n
|
!n
|
||||||
|
|
||||||
exception Got_worker of worker_state
|
exception Got_worker of worker_state
|
||||||
|
|
||||||
|
(* FIXME: replace with TLS *)
|
||||||
let[@inline] find_current_worker_ (self : state) : worker_state option =
|
let[@inline] find_current_worker_ (self : state) : worker_state option =
|
||||||
let self_id = Thread.id @@ Thread.self () in
|
let self_id = Thread.id @@ Thread.self () in
|
||||||
try
|
try
|
||||||
|
|
@ -48,159 +50,130 @@ let[@inline] find_current_worker_ (self : state) : worker_state option =
|
||||||
None
|
None
|
||||||
with Got_worker w -> Some w
|
with Got_worker w -> Some w
|
||||||
|
|
||||||
|
(** Try to wake up a waiter, if there's any. *)
|
||||||
|
let[@inline] try_wake_someone_ (self : state) : unit =
|
||||||
|
if self.n_waiting_nonzero then (
|
||||||
|
Mutex.lock self.mutex;
|
||||||
|
Condition.broadcast self.cond;
|
||||||
|
Mutex.unlock self.mutex
|
||||||
|
)
|
||||||
|
|
||||||
(** Run [task] as is, on the pool. *)
|
(** Run [task] as is, on the pool. *)
|
||||||
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
|
||||||
|
=
|
||||||
|
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
|
||||||
match w with
|
match w with
|
||||||
| Some w ->
|
| Some w ->
|
||||||
WSQ.push w.q task;
|
WSQ.push w.q task;
|
||||||
|
try_wake_someone_ self
|
||||||
(* see if we need to wakeup other workers to come and steal from us *)
|
|
||||||
Mutex.lock self.mc.mutex;
|
|
||||||
if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond;
|
|
||||||
Mutex.unlock self.mc.mutex
|
|
||||||
| None ->
|
| None ->
|
||||||
if A.get self.active then (
|
if A.get self.active then (
|
||||||
(* push into the main queue *)
|
(* push into the main queue *)
|
||||||
Mutex.lock self.mc.mutex;
|
Mutex.lock self.mutex;
|
||||||
let was_empty = Queue.is_empty self.main_q in
|
|
||||||
Queue.push task self.main_q;
|
Queue.push task self.main_q;
|
||||||
if was_empty then Condition.broadcast self.mc.cond;
|
if self.n_waiting_nonzero then Condition.broadcast self.cond;
|
||||||
Mutex.unlock self.mc.mutex
|
Mutex.unlock self.mutex
|
||||||
) else
|
) else
|
||||||
(* notify the caller that scheduling tasks is no
|
(* notify the caller that scheduling tasks is no
|
||||||
longer permitted *)
|
longer permitted *)
|
||||||
raise Shutdown
|
raise Shutdown
|
||||||
|
|
||||||
let run_async_ (self : state) (task : task) : unit =
|
(** Run this task, now. Must be called from a worker. *)
|
||||||
(* run [task] inside a suspension handler *)
|
let run_task_now_ (self : state) ~runner task : unit =
|
||||||
let rec run_async_in_suspend_rec_ (task : task) =
|
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
|
||||||
let task_with_suspend_ () =
|
let (AT_pair (before_task, after_task)) = self.around_task in
|
||||||
(* run [f()] and handle [suspend] in it *)
|
let _ctx = before_task runner in
|
||||||
Suspend_.with_suspend task ~run:(fun ~with_handler task' ->
|
(* run the task now, catching errors *)
|
||||||
if with_handler then
|
(try
|
||||||
run_async_in_suspend_rec_ task'
|
(* run [task()] and handle [suspend] in it *)
|
||||||
else (
|
Suspend_.with_suspend task ~run:(fun task' ->
|
||||||
let w = find_current_worker_ self in
|
let w = find_current_worker_ self in
|
||||||
run_direct_ self w task'
|
schedule_task_ self w task')
|
||||||
))
|
with e ->
|
||||||
in
|
let bt = Printexc.get_raw_backtrace () in
|
||||||
|
self.on_exn e bt);
|
||||||
|
after_task runner _ctx
|
||||||
|
|
||||||
(* schedule on current worker, if run from a worker *)
|
let[@inline] run_async_ (self : state) (task : task) : unit =
|
||||||
let w = find_current_worker_ self in
|
let w = find_current_worker_ self in
|
||||||
run_direct_ self w task_with_suspend_
|
schedule_task_ self w task
|
||||||
in
|
|
||||||
run_async_in_suspend_rec_ task
|
(* TODO: function to schedule many tasks from the outside.
|
||||||
|
- build a queue
|
||||||
|
- lock
|
||||||
|
- queue transfer
|
||||||
|
- wakeup all (broadcast)
|
||||||
|
- unlock *)
|
||||||
|
|
||||||
let run = run_async
|
let run = run_async
|
||||||
|
|
||||||
exception Got_task of task
|
(** Wait on condition. Precondition: we hold the mutex. *)
|
||||||
|
let[@inline] wait_ (self : state) : unit =
|
||||||
|
self.n_waiting <- self.n_waiting + 1;
|
||||||
|
if self.n_waiting = 1 then self.n_waiting_nonzero <- true;
|
||||||
|
Condition.wait self.cond self.mutex;
|
||||||
|
self.n_waiting <- self.n_waiting - 1;
|
||||||
|
if self.n_waiting = 0 then self.n_waiting_nonzero <- false
|
||||||
|
|
||||||
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
(** Try to steal a task from the worker [w] *)
|
||||||
|
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
|
||||||
|
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers;
|
||||||
|
|
||||||
(** How many times in a row do we try to do work-stealing? *)
|
(* if we're pointing to [w], skip to the next worker as
|
||||||
let steal_attempt_max_retry = 2
|
it's useless to steal from oneself *)
|
||||||
|
if Array.unsafe_get self.workers w.work_steal_offset == w then
|
||||||
|
w.work_steal_offset <-
|
||||||
|
(w.work_steal_offset + 1) mod Array.length self.workers;
|
||||||
|
|
||||||
|
let w' = Array.unsafe_get self.workers w.work_steal_offset in
|
||||||
|
WSQ.steal w'.q
|
||||||
|
|
||||||
|
(** Try to steal work from several other workers. *)
|
||||||
|
let try_to_steal_work_loop (self : state) ~runner w : bool =
|
||||||
|
if size_ self = 1 then
|
||||||
|
(* no stealing for single thread pool *)
|
||||||
|
false
|
||||||
|
else (
|
||||||
|
let has_stolen = ref false in
|
||||||
|
let n_retries_left = ref (size_ self - 1) in
|
||||||
|
|
||||||
|
while !n_retries_left > 0 do
|
||||||
|
match try_to_steal_work_once_ self w with
|
||||||
|
| Some task ->
|
||||||
|
run_task_now_ self ~runner task;
|
||||||
|
has_stolen := true;
|
||||||
|
n_retries_left := 0
|
||||||
|
| None -> decr n_retries_left
|
||||||
|
done;
|
||||||
|
!has_stolen
|
||||||
|
)
|
||||||
|
|
||||||
|
(** Worker runs tasks from its queue until none remains *)
|
||||||
|
let worker_run_self_tasks_ (self : state) ~runner w : unit =
|
||||||
|
let continue = ref true in
|
||||||
|
while !continue && A.get self.active do
|
||||||
|
match WSQ.pop w.q with
|
||||||
|
| Some task -> run_task_now_ self ~runner task
|
||||||
|
| None -> continue := false
|
||||||
|
done
|
||||||
|
|
||||||
(** Main loop for a worker thread. *)
|
(** Main loop for a worker thread. *)
|
||||||
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
||||||
~around_task : unit =
|
let main_loop () : unit =
|
||||||
let (AT_pair (before_task, after_task)) = around_task in
|
|
||||||
|
|
||||||
(* run this task. *)
|
|
||||||
let run_task task : unit =
|
|
||||||
let _ctx = before_task runner in
|
|
||||||
(* run the task now, catching errors *)
|
|
||||||
(try task ()
|
|
||||||
with e ->
|
|
||||||
let bt = Printexc.get_raw_backtrace () in
|
|
||||||
on_exn e bt);
|
|
||||||
after_task runner _ctx
|
|
||||||
in
|
|
||||||
|
|
||||||
let run_self_tasks_ () =
|
|
||||||
let continue = ref true in
|
let continue = ref true in
|
||||||
while !continue do
|
while !continue && A.get self.active do
|
||||||
match WSQ.pop w.q with
|
worker_run_self_tasks_ self ~runner w;
|
||||||
| Some task -> run_task task
|
|
||||||
| None -> continue := false
|
|
||||||
done
|
|
||||||
in
|
|
||||||
|
|
||||||
let work_steal_offset = ref 0 in
|
let did_steal = try_to_steal_work_loop self ~runner w in
|
||||||
|
|
||||||
(* get a task from another worker *)
|
|
||||||
let try_to_steal_work () : task option =
|
|
||||||
assert (size_ self > 1);
|
|
||||||
|
|
||||||
work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers;
|
|
||||||
|
|
||||||
(* if we're pointing to [w], skip to the next worker as
|
|
||||||
it's useless to steal from oneself *)
|
|
||||||
if self.workers.(!work_steal_offset) == w then
|
|
||||||
work_steal_offset :=
|
|
||||||
(!work_steal_offset + 1) mod Array.length self.workers;
|
|
||||||
|
|
||||||
let w' = self.workers.(!work_steal_offset) in
|
|
||||||
assert (w != w');
|
|
||||||
WSQ.steal w'.q
|
|
||||||
in
|
|
||||||
|
|
||||||
(* try to steal work multiple times *)
|
|
||||||
let try_to_steal_work_loop () : bool =
|
|
||||||
if size_ self = 1 then
|
|
||||||
(* no stealing for single thread pool *)
|
|
||||||
false
|
|
||||||
else (
|
|
||||||
try
|
|
||||||
let unsuccessful_steal_attempts = ref 0 in
|
|
||||||
while !unsuccessful_steal_attempts < steal_attempt_max_retry do
|
|
||||||
match try_to_steal_work () with
|
|
||||||
| Some task ->
|
|
||||||
run_task task;
|
|
||||||
raise_notrace Exit
|
|
||||||
| None -> incr unsuccessful_steal_attempts
|
|
||||||
(* Domain_.relax () *)
|
|
||||||
done;
|
|
||||||
false
|
|
||||||
with Exit -> true
|
|
||||||
)
|
|
||||||
in
|
|
||||||
|
|
||||||
let get_task_from_main_queue_block () : task option =
|
|
||||||
try
|
|
||||||
Mutex.lock self.mc.mutex;
|
|
||||||
while true do
|
|
||||||
match Queue.pop self.main_q with
|
|
||||||
| exception Queue.Empty ->
|
|
||||||
if A.get self.active then
|
|
||||||
Condition.wait self.mc.cond self.mc.mutex
|
|
||||||
else (
|
|
||||||
(* empty queue and we're closed, time to exit *)
|
|
||||||
Mutex.unlock self.mc.mutex;
|
|
||||||
raise_notrace Exit
|
|
||||||
)
|
|
||||||
| task ->
|
|
||||||
Mutex.unlock self.mc.mutex;
|
|
||||||
raise_notrace (Got_task task)
|
|
||||||
done;
|
|
||||||
(* unreachable *)
|
|
||||||
assert false
|
|
||||||
with
|
|
||||||
| Got_task t -> Some t
|
|
||||||
| Exit -> None
|
|
||||||
in
|
|
||||||
|
|
||||||
let main_loop () =
|
|
||||||
let continue = ref true in
|
|
||||||
while !continue do
|
|
||||||
run_self_tasks_ ();
|
|
||||||
|
|
||||||
let did_steal = try_to_steal_work_loop () in
|
|
||||||
if not did_steal then (
|
if not did_steal then (
|
||||||
match get_task_from_main_queue_block () with
|
Mutex.lock self.mutex;
|
||||||
| None ->
|
match Queue.pop self.main_q with
|
||||||
(* main queue is closed *)
|
| task ->
|
||||||
continue := false
|
Mutex.unlock self.mutex;
|
||||||
| Some task -> run_task task
|
run_task_now_ self ~runner task
|
||||||
|
| exception Queue.Empty ->
|
||||||
|
wait_ self;
|
||||||
|
Mutex.unlock self.mutex
|
||||||
)
|
)
|
||||||
done;
|
done;
|
||||||
assert (WSQ.size w.q = 0)
|
assert (WSQ.size w.q = 0)
|
||||||
|
|
@ -214,9 +187,9 @@ let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
|
||||||
|
|
||||||
let shutdown_ ~wait (self : state) : unit =
|
let shutdown_ ~wait (self : state) : unit =
|
||||||
if A.exchange self.active false then (
|
if A.exchange self.active false then (
|
||||||
Mutex.lock self.mc.mutex;
|
Mutex.lock self.mutex;
|
||||||
Condition.broadcast self.mc.cond;
|
Condition.broadcast self.cond;
|
||||||
Mutex.unlock self.mc.mutex;
|
Mutex.unlock self.mutex;
|
||||||
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
|
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -251,7 +224,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
|
|
||||||
let workers : worker_state array =
|
let workers : worker_state array =
|
||||||
let dummy = Thread.self () in
|
let dummy = Thread.self () in
|
||||||
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
|
Array.init num_threads (fun i ->
|
||||||
|
{
|
||||||
|
thread = dummy;
|
||||||
|
q = WSQ.create ();
|
||||||
|
work_steal_offset = (i + 1) mod num_threads;
|
||||||
|
})
|
||||||
in
|
in
|
||||||
|
|
||||||
let pool =
|
let pool =
|
||||||
|
|
@ -259,7 +237,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
active = A.make true;
|
active = A.make true;
|
||||||
workers;
|
workers;
|
||||||
main_q = Queue.create ();
|
main_q = Queue.create ();
|
||||||
mc = { mutex = Mutex.create (); cond = Condition.create () };
|
n_waiting = 0;
|
||||||
|
n_waiting_nonzero = true;
|
||||||
|
mutex = Mutex.create ();
|
||||||
|
cond = Condition.create ();
|
||||||
|
around_task;
|
||||||
|
on_exn;
|
||||||
}
|
}
|
||||||
in
|
in
|
||||||
|
|
||||||
|
|
@ -287,7 +270,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
let t_id = Thread.id thread in
|
let t_id = Thread.id thread in
|
||||||
on_init_thread ~dom_id:dom_idx ~t_id ();
|
on_init_thread ~dom_id:dom_idx ~t_id ();
|
||||||
|
|
||||||
let run () = worker_thread_ pool runner w ~on_exn ~around_task in
|
let run () = worker_thread_ pool ~runner w in
|
||||||
|
|
||||||
(* now run the main loop *)
|
(* now run the main loop *)
|
||||||
Fun.protect run ~finally:(fun () ->
|
Fun.protect run ~finally:(fun () ->
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue