mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-11 13:38:38 -05:00
wip: better work stealing pool
This commit is contained in:
parent
e0d3a18562
commit
db33bec13f
2 changed files with 84 additions and 30 deletions
102
src/pool.ml
102
src/pool.ml
|
|
@ -1,4 +1,5 @@
|
||||||
module WSQ = Ws_deque_
|
module WSQ = Ws_deque_
|
||||||
|
module A = Atomic_
|
||||||
include Runner
|
include Runner
|
||||||
|
|
||||||
let ( let@ ) = ( @@ )
|
let ( let@ ) = ( @@ )
|
||||||
|
|
@ -11,20 +12,30 @@ type worker_state = {
|
||||||
q: task WSQ.t; (** Work stealing queue *)
|
q: task WSQ.t; (** Work stealing queue *)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mut_cond = {
|
||||||
|
mutex: Mutex.t;
|
||||||
|
cond: Condition.t;
|
||||||
|
}
|
||||||
|
|
||||||
type state = {
|
type state = {
|
||||||
|
active: bool Atomic.t;
|
||||||
workers: worker_state array;
|
workers: worker_state array;
|
||||||
main_q: task Bb_queue.t; (** Main queue to block on *)
|
main_q: task Queue.t; (** Main queue to block on *)
|
||||||
|
mc: mut_cond;
|
||||||
}
|
}
|
||||||
(** internal state *)
|
(** internal state *)
|
||||||
|
|
||||||
let[@inline] size_ (self : state) = Array.length self.workers
|
let[@inline] size_ (self : state) = Array.length self.workers
|
||||||
|
|
||||||
let num_tasks_ (self : state) : int =
|
let num_tasks_ (self : state) : int =
|
||||||
let n = ref (Bb_queue.size self.main_q) in
|
Mutex.lock self.mc.mutex;
|
||||||
|
let n = ref (Queue.length self.main_q) in
|
||||||
|
Mutex.unlock self.mc.mutex;
|
||||||
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
||||||
!n
|
!n
|
||||||
|
|
||||||
exception Got_worker of worker_state
|
exception Got_worker of worker_state
|
||||||
|
exception Closed = Bb_queue.Closed
|
||||||
|
|
||||||
let find_current_worker_ (self : state) : worker_state option =
|
let find_current_worker_ (self : state) : worker_state option =
|
||||||
let self_id = Thread.id @@ Thread.self () in
|
let self_id = Thread.id @@ Thread.self () in
|
||||||
|
|
@ -41,11 +52,22 @@ let find_current_worker_ (self : state) : worker_state option =
|
||||||
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
|
||||||
match w with
|
match w with
|
||||||
| Some w ->
|
| Some w ->
|
||||||
print_endline "push local";
|
WSQ.push w.q task;
|
||||||
WSQ.push w.q task
|
|
||||||
|
(* see if we need to wakeup other workers to come and steal from us *)
|
||||||
|
Mutex.lock self.mc.mutex;
|
||||||
|
if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond;
|
||||||
|
Mutex.unlock self.mc.mutex
|
||||||
| None ->
|
| None ->
|
||||||
print_endline "push blocking";
|
if A.get self.active then (
|
||||||
Bb_queue.push self.main_q task
|
(* push into the main queue *)
|
||||||
|
Mutex.lock self.mc.mutex;
|
||||||
|
let was_empty = Queue.is_empty self.main_q in
|
||||||
|
Queue.push task self.main_q;
|
||||||
|
if was_empty then Condition.broadcast self.mc.cond;
|
||||||
|
Mutex.unlock self.mc.mutex
|
||||||
|
) else
|
||||||
|
raise Bb_queue.Closed
|
||||||
|
|
||||||
let run_async_ (self : state) (task : task) : unit =
|
let run_async_ (self : state) (task : task) : unit =
|
||||||
(* stay on current worker if possible *)
|
(* stay on current worker if possible *)
|
||||||
|
|
@ -74,7 +96,7 @@ type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||||
let run_self_task_max_retry = 5
|
let run_self_task_max_retry = 5
|
||||||
|
|
||||||
(** How many times in a row do we try to do work-stealing? *)
|
(** How many times in a row do we try to do work-stealing? *)
|
||||||
let steal_attempt_max_retry = 5
|
let steal_attempt_max_retry = 7
|
||||||
|
|
||||||
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
~around_task : unit =
|
~around_task : unit =
|
||||||
|
|
@ -92,7 +114,6 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
in
|
in
|
||||||
|
|
||||||
let run_self_tasks_ () =
|
let run_self_tasks_ () =
|
||||||
print_endline "run self tasks";
|
|
||||||
let continue = ref true in
|
let continue = ref true in
|
||||||
let pop_retries = ref 0 in
|
let pop_retries = ref 0 in
|
||||||
while !continue do
|
while !continue do
|
||||||
|
|
@ -107,23 +128,34 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
done
|
done
|
||||||
in
|
in
|
||||||
|
|
||||||
|
let work_steal_offset = ref 0 in
|
||||||
|
|
||||||
(* get a task from another worker *)
|
(* get a task from another worker *)
|
||||||
let try_to_steal_work () : task option =
|
let rec try_to_steal_work () : task option =
|
||||||
print_endline "try to steal work";
|
let i = !work_steal_offset in
|
||||||
|
work_steal_offset := (i + 1) mod Array.length self.workers;
|
||||||
|
let w' = self.workers.(i) in
|
||||||
|
if w == w' then
|
||||||
|
try_to_steal_work ()
|
||||||
|
else
|
||||||
|
WSQ.steal w'.q
|
||||||
|
in
|
||||||
|
|
||||||
|
(*
|
||||||
try
|
try
|
||||||
for _retry = 1 to 3 do
|
for _retry = 1 to 1 do
|
||||||
Array.iter
|
for i = 0 to Array.length self.workers - 1 do
|
||||||
(fun w' ->
|
let w' = self.workers.(i) in
|
||||||
if w != w' then (
|
if w != w' then (
|
||||||
match WSQ.steal w'.q with
|
match WSQ.steal w'.q with
|
||||||
| None -> ()
|
| None -> ()
|
||||||
| Some task -> raise_notrace (Got_task task)
|
| Some task -> raise_notrace (Got_task task)
|
||||||
))
|
)
|
||||||
self.workers
|
done
|
||||||
done;
|
done;
|
||||||
None
|
None
|
||||||
with Got_task task -> Some task
|
with Got_task task -> Some task
|
||||||
in
|
*)
|
||||||
|
|
||||||
(* try to steal work multiple times *)
|
(* try to steal work multiple times *)
|
||||||
let try_to_steal_work_loop () : bool =
|
let try_to_steal_work_loop () : bool =
|
||||||
|
|
@ -142,19 +174,28 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
with Exit -> true
|
with Exit -> true
|
||||||
in
|
in
|
||||||
|
|
||||||
|
let get_task_from_main_queue_block () : task =
|
||||||
|
try
|
||||||
|
Mutex.lock self.mc.mutex;
|
||||||
|
while A.get self.active do
|
||||||
|
match Queue.pop self.main_q with
|
||||||
|
| exception Queue.Empty -> Condition.wait self.mc.cond self.mc.mutex
|
||||||
|
| task ->
|
||||||
|
Mutex.unlock self.mc.mutex;
|
||||||
|
raise_notrace (Got_task task)
|
||||||
|
done;
|
||||||
|
Mutex.unlock self.mc.mutex;
|
||||||
|
raise Bb_queue.Closed
|
||||||
|
with Got_task t -> t
|
||||||
|
in
|
||||||
|
|
||||||
let main_loop () =
|
let main_loop () =
|
||||||
(try
|
(try
|
||||||
while true do
|
while true do
|
||||||
run_self_tasks_ ();
|
run_self_tasks_ ();
|
||||||
|
|
||||||
if not (try_to_steal_work_loop ()) then (
|
if not (try_to_steal_work_loop ()) then (
|
||||||
Array.iteri
|
let task = get_task_from_main_queue_block () in
|
||||||
(fun i w -> Printf.printf "w[%d].q.size=%d\n" i (WSQ.size w.q))
|
|
||||||
self.workers;
|
|
||||||
Printf.printf "bq.size=%d\n%!" (Bb_queue.size self.main_q);
|
|
||||||
|
|
||||||
print_endline "wait block";
|
|
||||||
let task = Bb_queue.pop self.main_q in
|
|
||||||
run_task task
|
run_task task
|
||||||
)
|
)
|
||||||
done
|
done
|
||||||
|
|
@ -169,8 +210,12 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
|
||||||
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
|
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
|
||||||
|
|
||||||
let shutdown_ ~wait (self : state) : unit =
|
let shutdown_ ~wait (self : state) : unit =
|
||||||
Bb_queue.close self.main_q;
|
if A.exchange self.active false then (
|
||||||
|
Mutex.lock self.mc.mutex;
|
||||||
|
Condition.broadcast self.mc.cond;
|
||||||
|
Mutex.unlock self.mc.mutex;
|
||||||
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
|
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
|
||||||
|
)
|
||||||
|
|
||||||
type ('a, 'b) create_args =
|
type ('a, 'b) create_args =
|
||||||
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
|
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
|
||||||
|
|
@ -208,7 +253,14 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
|
Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
|
||||||
in
|
in
|
||||||
|
|
||||||
let pool = { workers; main_q = Bb_queue.create () } in
|
let pool =
|
||||||
|
{
|
||||||
|
active = A.make true;
|
||||||
|
workers;
|
||||||
|
main_q = Queue.create ();
|
||||||
|
mc = { mutex = Mutex.create (); cond = Condition.create () };
|
||||||
|
}
|
||||||
|
in
|
||||||
|
|
||||||
let runner =
|
let runner =
|
||||||
Runner.For_runner_implementors.create
|
Runner.For_runner_implementors.create
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,8 @@ type ('a, 'b) create_args =
|
||||||
'a
|
'a
|
||||||
(** Arguments used in {!create}. See {!create} for explanations. *)
|
(** Arguments used in {!create}. See {!create} for explanations. *)
|
||||||
|
|
||||||
|
exception Closed
|
||||||
|
|
||||||
val create : (unit -> t, _) create_args
|
val create : (unit -> t, _) create_args
|
||||||
(** [create ()] makes a new thread pool.
|
(** [create ()] makes a new thread pool.
|
||||||
@param on_init_thread called at the beginning of each new thread
|
@param on_init_thread called at the beginning of each new thread
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue