mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-18 00:26:43 -05:00
With a single worker there would be an infinite loop: the pool would try to find the index of another worker to steal from, but there is no other worker, so it would loop forever.
371 lines
11 KiB
OCaml
371 lines
11 KiB
OCaml
(* Short local names for the work-stealing deque and atomics modules used
   throughout this file. *)
module WSQ = Ws_deque_
module A = Atomic_
|
|
|
|
(** Hash tables with [int] keys; used below to index workers by the
    [Thread.id] of their thread. *)
module Int_tbl = Hashtbl.Make (struct
  type t = int

  let equal : t -> t -> bool = Int.equal
  let hash (key : t) : int = Hashtbl.hash key
end)
|
|
|
|
include Runner

(** [let@ x = f in body] means [f (fun x -> body)]; convenient for
    continuation-taking helpers such as [Fun.protect]. *)
let ( let@ ) f x = f x
|
|
|
|
(** Wraps a worker's main loop. It is given the worker thread, the pool and
    the loop to wrap, and returns the wrapped loop. *)
type thread_loop_wrapper =
  thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit

(** State for a given worker. Only this worker is allowed to push into its
    queue, but other workers can come and steal from it if they're idle. *)
type worker_state = {
  mutable thread: Thread.t;
      (** The worker's thread; filled in once the thread has been created. *)
  q: task WSQ.t;  (** Work stealing queue *)
}
|
|
|
|
(** A mutex together with the condition variable that is waited on while
    holding it. *)
type mut_cond = {
  mutex: Mutex.t;  (** Protects the data the condition is about. *)
  cond: Condition.t;  (** Signalled when that data changes. *)
}
|
|
|
|
(** Internal state of a pool. *)
type state = {
  active: bool Atomic.t;  (** Becomes [false] when the pool is shutdown. *)
  workers: worker_state array;  (** Fixed set of workers. *)
  worker_by_id: worker_state Int_tbl.t;
      (** Maps the [Thread.id] of a worker's thread to that worker. *)
  main_q: task Queue.t;  (** Main queue for tasks coming from the outside *)
  mc: mut_cond;  (** Used to block on [main_q] *)
}
|
|
|
|
(** Number of worker threads in the pool. *)
let[@inline] size_ (self : state) : int =
  let { workers; _ } = self in
  Array.length workers
|
|
|
|
(** Number of pending tasks: the length of the main queue (read under its
    lock) plus the size of every worker's local queue. The local queues are
    read without locking, so the result may already be stale when returned. *)
let num_tasks_ (self : state) : int =
  Mutex.lock self.mc.mutex;
  let in_main_queue = Queue.length self.main_q in
  Mutex.unlock self.mc.mutex;
  Array.fold_left
    (fun total worker -> total + WSQ.size worker.q)
    in_main_queue self.workers
|
|
|
|
(** The worker of this pool that runs on the calling thread, or [None] if the
    calling thread is not one of this pool's workers.
    NOTE(review): [worker_by_id] is read without taking [self.mc.mutex];
    this assumes the table is no longer modified once the pool is running. *)
let[@inline] find_current_worker_ (self : state) : worker_state option =
  Thread.self () |> Thread.id |> Int_tbl.find_opt self.worker_by_id
|
|
|
|
(** Run [task] as is, on the pool.
    - With [w = Some worker], [task] is pushed onto [worker]'s local queue.
      If the main queue is empty, sleeping workers are woken up so they can
      come and steal it.
    - With [w = None], [task] is pushed onto the main queue. Workers are woken
      up if that queue was empty.
    @raise Shutdown if [w = None] and the pool is no longer active. *)
let run_direct_ (self : state) (w : worker_state option) (task : task) : unit =
  let { mc = { mutex; cond }; main_q; active; _ } = self in
  match w with
  | Some worker ->
    WSQ.push worker.q task;
    (* see if we need to wakeup other workers to come and steal from us *)
    Mutex.lock mutex;
    if Queue.is_empty main_q then Condition.broadcast cond;
    Mutex.unlock mutex
  | None ->
    if not (A.get active) then
      (* notify the caller that scheduling tasks is no longer permitted *)
      raise Shutdown;
    (* push into the main queue *)
    Mutex.lock mutex;
    let was_empty = Queue.is_empty main_q in
    Queue.push task main_q;
    if was_empty then Condition.broadcast cond;
    Mutex.unlock mutex
|
|
|
|
(** Schedule [task] on the pool, wrapped so that it can suspend.

    The worker to push onto is looked up at each push, not once at submission
    time. The wrapped task runs on whichever worker pops or steals it, and the
    [~run] callback given to [Suspend_.with_suspend] may be invoked later on
    yet another thread. Only the owning worker may push into a worker's local
    queue. Reusing the worker captured when [run_async_] was first called
    would therefore let a foreign thread push into that worker's queue,
    possibly after that worker has already exited.

    @raise Shutdown if the task has to go through the main queue (the calling
    thread is not one of this pool's workers) and the pool has been shut
    down. *)
let run_async_ (self : state) (task : task) : unit =
  let rec run_async_rec_ (task : task) =
    let task_with_suspend_ () =
      (* run [task ()] and handle [suspend] in it *)
      Suspend_.with_suspend task ~run:(fun ~with_handler task' ->
          if with_handler then
            run_async_rec_ task'
          else
            (* look up the worker of the thread we're on *now* *)
            run_direct_ self (find_current_worker_ self) task')
    in
    (* stay on the current worker if possible *)
    run_direct_ self (find_current_worker_ self) task_with_suspend_
  in
  run_async_rec_ task
|
|
|
|
(** Same as [run_async]. *)
let run = run_async

(** Used internally to carry a task out of a blocking loop. *)
exception Got_task of task

(** A pair of hooks run before and after each task. The value returned by
    the first hook is passed to the second; its type is hidden. *)
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
|
|
|
(** Number of consecutive failed pops from its own local queue after which a
    worker stops polling it and moves on. *)
let run_self_task_max_retry : int = 5

(** Maximum number of consecutive unsuccessful work-stealing attempts. *)
let steal_attempt_max_retry : int = 7
|
|
|
|
(** Main loop of worker [w] of pool [self].

    Each iteration:
    - runs tasks from [w]'s local queue, until the queue has been found empty
      more than [run_self_task_max_retry] times in a row;
    - then tries to steal and run one task from another worker, giving up
      after [steal_attempt_max_retry] unsuccessful attempts (no stealing in a
      single-worker pool);
    - if nothing was stolen, blocks on the main queue and runs the task it
      gets, or exits once the main queue is empty and the pool is shut down.

    Every task runs between [before_task runner] and [after_task runner].
    Exceptions raised by a task are given to [on_exn] with their backtrace.
    The loop runs under [Dla_.using] to handle domain-local await. *)
let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn
    ~around_task : unit =
  let (AT_pair (before_task, after_task)) = around_task in

  (* run this task, catching its exceptions *)
  let run_task task : unit =
    let ctx = before_task runner in
    (try task ()
     with e ->
       let bt = Printexc.get_raw_backtrace () in
       on_exn e bt);
    after_task runner ctx
  in

  (* run tasks from our own queue until it has looked empty several times *)
  let run_self_tasks_ () =
    let continue = ref true in
    let pop_retries = ref 0 in
    while !continue do
      match WSQ.pop w.q with
      | Some task ->
        pop_retries := 0;
        run_task task
      | None ->
        Domain_.relax ();
        incr pop_retries;
        if !pop_retries > run_self_task_max_retry then continue := false
    done
  in

  (* index of the next worker to steal from (round robin) *)
  let work_steal_offset = ref 0 in

  (* get a task from another worker. Needs at least two workers, otherwise
     there is no other worker to steal from. *)
  let try_to_steal_work () : task option =
    assert (size_ self > 1);
    let n_workers = Array.length self.workers in
    work_steal_offset := (!work_steal_offset + 1) mod n_workers;

    (* if we're pointing to [w], skip to the next worker as
       it's useless to steal from oneself *)
    if self.workers.(!work_steal_offset) == w then
      work_steal_offset := (!work_steal_offset + 1) mod n_workers;

    let w' = self.workers.(!work_steal_offset) in
    assert (w != w');
    WSQ.steal w'.q
  in

  (* try to steal work multiple times; returns [true] iff a stolen task was
     run. A plain recursive loop is used rather than an [Exit] exception:
     [run_task] runs user hooks, and an [Exit] escaping from one of them must
     not be mistaken for a successful steal. *)
  let try_to_steal_work_loop () : bool =
    let rec loop failed_attempts =
      if failed_attempts >= steal_attempt_max_retry then
        false
      else (
        match try_to_steal_work () with
        | Some task ->
          run_task task;
          true
        | None ->
          Domain_.relax ();
          loop (failed_attempts + 1)
      )
    in
    if size_ self = 1 then
      (* no stealing for single thread pool *)
      false
    else
      loop 0
  in

  (* block until a task is available in the main queue; [None] once the main
     queue is empty and the pool has been shut down. *)
  let get_task_from_main_queue_block () : task option =
    try
      Mutex.lock self.mc.mutex;
      while true do
        match Queue.pop self.main_q with
        | exception Queue.Empty ->
          if A.get self.active then
            Condition.wait self.mc.cond self.mc.mutex
          else (
            (* empty queue and we're closed, time to exit *)
            Mutex.unlock self.mc.mutex;
            raise_notrace Exit
          )
        | task ->
          Mutex.unlock self.mc.mutex;
          raise_notrace (Got_task task)
      done;
      (* unreachable *)
      assert false
    with
    | Got_task t -> Some t
    | Exit -> None
  in

  let main_loop () =
    let continue = ref true in
    while !continue do
      run_self_tasks_ ();

      let did_steal = try_to_steal_work_loop () in
      if not did_steal then (
        match get_task_from_main_queue_block () with
        | None ->
          (* main queue is closed *)
          continue := false
        | Some task -> run_task task
      )
    done;
    (* we only exit right after our own queue was found empty, with no task
       run since; this relies on only the owning worker pushing into [w.q] *)
    assert (WSQ.size w.q = 0)
  in

  (* handle domain-local await *)
  Dla_.using ~prepare_for_await:Suspend_.prepare_for_await
    ~while_running:main_loop
|
|
|
|
(** Default [on_init_thread] / [on_exit_thread] hook: ignores its arguments
    and does nothing. *)
let default_thread_init_exit_ = fun ~dom_id:_ ~t_id:_ () -> ()
|
|
|
|
(** Shut the pool down. The pool is marked inactive and every worker blocked
    on the main queue is woken up so it can notice. If [wait] is true, all
    worker threads are also joined. Only the call that flips [active] from
    [true] to [false] does anything; later calls are no-ops. *)
let shutdown_ ~wait (self : state) : unit =
  let was_active = A.exchange self.active false in
  if was_active then begin
    Mutex.lock self.mc.mutex;
    Condition.broadcast self.mc.cond;
    Mutex.unlock self.mc.mutex;
    if wait then
      Array.iter (fun { thread; _ } -> Thread.join thread) self.workers
  end
|
|
|
|
(** Optional arguments accepted by {!create} (and {!with_}), followed by the
    rest of the function type ['a]. See {!create} for explanations. *)
type ('a, 'b) create_args =
  ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?thread_wrappers:thread_loop_wrapper list ->
  ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
  ?around_task:(t -> 'b) * (t -> 'b -> unit) ->
  ?min:int ->
  ?per_domain:int ->
  'a
|
|
|
|
(** Create a work-stealing pool and return it as a runner.

    The pool has [max (max 1 min) (n_domains * per_domain)] worker threads,
    where [n_domains = D_pool_.n_domains ()]. Worker [i] is started on domain
    [(offset + i) mod n_domains], for a random [offset], so that pools do not
    all favor the first domain(s).

    @param on_init_thread called in each worker thread before its main loop
    @param on_exit_thread called in each worker thread after its main loop,
      only if the loop returned normally
    @param thread_wrappers wrappers applied around each worker's main loop;
      the last element of the list is the outermost wrapper
    @param on_exn called with exceptions (and backtraces) raised by tasks;
      ignores them by default
    @param around_task [(before, after)] hooks run around every task
    @param min minimum number of worker threads (default 1; values below 1
      are treated as 1)
    @param per_domain number of worker threads per domain (default 0) *)
let create ?(on_init_thread = default_thread_init_exit_)
    ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = [])
    ?(on_exn = fun _ _ -> ()) ?around_task ?min:(min_threads = 1)
    ?(per_domain = 0) () : t =
  (* pack the hooks, hiding the type of the value they exchange *)
  let around_task =
    match around_task with
    | Some (f, g) -> AT_pair (f, g)
    | None -> AT_pair (ignore, fun _ _ -> ())
  in

  (* number of threads to run *)
  let min_threads = max 1 min_threads in
  let num_domains = D_pool_.n_domains () in
  assert (num_domains >= 1);
  let num_threads = max min_threads (num_domains * per_domain) in

  (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
  let offset = Random.int num_domains in

  let workers : worker_state array =
    (* placeholder thread, replaced once the real thread is received below *)
    let dummy = Thread.self () in
    Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () })
  in

  let pool =
    {
      active = A.make true;
      workers;
      worker_by_id = Int_tbl.create 8;
      main_q = Queue.create ();
      mc = { mutex = Mutex.create (); cond = Condition.create () };
    }
  in

  let runner =
    Runner.For_runner_implementors.create
      ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
      ~run_async:(fun f -> run_async_ pool f)
      ~size:(fun () -> size_ pool)
      ~num_tasks:(fun () -> num_tasks_ pool)
      ()
  in

  (* temporary queue used to obtain thread handles from the domains
     on which the threads are started. *)
  let receive_threads = Bb_queue.create () in

  (* start the thread with index [i] *)
  let start_thread_with_idx i =
    let w = pool.workers.(i) in
    let dom_idx = (offset + i) mod num_domains in

    (* function run in the thread itself *)
    let main_thread_fun () : unit =
      let thread = Thread.self () in
      let t_id = Thread.id thread in
      on_init_thread ~dom_id:dom_idx ~t_id ();

      let run () = worker_thread_ pool runner w ~on_exn ~around_task in
      (* the actual worker loop is [worker_thread_], wrapped by every element
         of [thread_wrappers] (the last element being the outermost) *)
      let run' =
        List.fold_left
          (fun run f -> f ~thread ~pool:runner run)
          run thread_wrappers
      in

      (* now run the main loop *)
      Fun.protect run' ~finally:(fun () ->
          (* on termination, decrease refcount of underlying domain *)
          D_pool_.decr_on dom_idx);
      on_exit_thread ~dom_id:dom_idx ~t_id ()
    in

    (* function called in the domain with index [dom_idx], to
       create the thread and push it into [receive_threads] *)
    let create_thread_in_domain () =
      let thread = Thread.create main_thread_fun () in
      (* send the thread from the domain back to us *)
      Bb_queue.push receive_threads (i, thread)
    in

    D_pool_.run_on dom_idx create_thread_in_domain
  in

  (* start all threads, placing them on the domains
     according to their index and [offset] in a round-robin fashion. *)
  for i = 0 to num_threads - 1 do
    start_thread_with_idx i
  done;

  (* receive the newly created threads back from domains *)
  for _j = 1 to num_threads do
    let i, th = Bb_queue.pop receive_threads in
    let worker_state = pool.workers.(i) in
    worker_state.thread <- th;

    Mutex.lock pool.mc.mutex;
    (* thread IDs are unique, so [replace] never drops a binding *)
    Int_tbl.replace pool.worker_by_id (Thread.id th) worker_state;
    Mutex.unlock pool.mc.mutex
  done;

  runner
|
|
|
|
(** [with_ ... () f] creates a pool with the given arguments (see {!create})
    and calls [f] on it. The pool is shut down once [f] returns or raises. *)
let with_ ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn ?around_task
    ?min ?per_domain () f =
  let pool =
    create ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn
      ?around_task ?min ?per_domain ()
  in
  Fun.protect ~finally:(fun () -> shutdown pool) (fun () -> f pool)
|