(* Source: moonpool/src/core/ws_pool.ml
   Snapshot: 2024-08-16 10:07:51 -04:00 (410 lines, 12 KiB, OCaml) *)

open Types_
module WSQ = Ws_deque_
module A = Atomic_
module TLS = Thread_local_storage_
include Runner
(** Binding operator: [let@ x = f in body] is [f (fun x -> body)], i.e. plain
    function application of [f] to the rest of the scope. *)
let ( let@ ) f k = f k
module Id = struct
  (** Unique identifier for a pool. An identifier is a freshly allocated
      cell; two identifiers are the same only if they are the same cell. *)
  type t = unit ref

  (** [equal a b] holds iff [a] and [b] are physically the same cell. *)
  let equal (a : t) (b : t) : bool = a == b

  (** Allocate a new identifier, distinct from all existing ones. *)
  let create () : t = Sys.opaque_identity (ref ())
end
(** Pair of hooks run around every task. In [AT_pair (before, after)],
[before runner] is called before the task runs and the value it returns is
passed to [after runner] once the task is done (see [run_task_now_]). The type
of that intermediate value is hidden (existential). *)
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
(** A unit of work as stored in the queues, together with the task-local
storage [ls] that is installed while it runs. *)
type task_full =
(* [T_start]: run the function [f] from its beginning. *)
| T_start of {
ls: Task_local_storage.t;
f: task;
}
(* [T_resume]: apply the stored function [k] to the value [x]; the type of
[x] is hidden inside the constructor. Produced by the [resume] callback
in [run_task_now_]. *)
| T_resume : {
ls: Task_local_storage.t;
k: 'a -> unit;
x: 'a;
}
-> task_full
type worker_state = {
pool_id_: Id.t; (** Unique per pool *)
(* Handle of the thread running this worker. [create] first stores a
placeholder here and overwrites it once the real thread has been started;
[shutdown_] joins on it. *)
mutable thread: Thread.t;
q: task_full WSQ.t; (** Work stealing queue *)
(* [Some ls] while a task is running on this worker, [None] otherwise
(maintained by [run_task_now_]). *)
mutable cur_ls: Task_local_storage.t option; (** Task storage *)
(* Per-worker random state, used to choose where to start looking when
stealing from other workers. *)
rng: Random.State.t;
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
type state = {
(* Identity of this pool; compared against [worker_state.pool_id_] to decide
whether the calling worker belongs to this pool. *)
id_: Id.t;
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; (** Fixed set of workers. *)
main_q: task_full Queue.t;
(** Main queue for tasks coming from the outside *)
(* Number of workers currently blocked in [wait_]. *)
mutable n_waiting: int; (* protected by mutex *)
(* Mirror of [n_waiting > 0], updated in [wait_]. It is read without holding
the mutex in [try_wake_someone_].
NOTE(review): [create] initialises this flag to [true] although
[n_waiting] starts at 0; this looks like a deliberately conservative
start value (it only causes extra signals) -- confirm. *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
(* [mutex] guards [main_q] and [n_waiting]; [cond] is what idle workers
wait on and what schedulers signal. *)
mutex: Mutex.t;
cond: Condition.t;
(* Called with the exception and backtrace when running a task raises. *)
on_exn: exn -> Printexc.raw_backtrace -> unit;
(* Hooks called before and after each task. *)
around_task: around_task;
}
(** internal state *)
(** Number of workers in the pool. *)
let[@inline] size_ ({ workers; _ } : state) = Array.length workers
(** Number of tasks currently queued: the length of the main queue plus the
    size of every worker's own queue. The queues are read without taking the
    pool mutex. *)
let num_tasks_ (self : state) : int =
  Array.fold_left
    (fun total w -> total + WSQ.size w.q)
    (Queue.length self.main_q) self.workers
(** Thread-local slot in which each worker thread stores its own
[worker_state] (it is set at the start of [worker_thread_]). Code running
inside a task reads it back to find the worker it is running on when it
schedules new sub-tasks. *)
let k_worker_state : worker_state TLS.t = TLS.create ()
(** Worker state stored in the calling thread's [k_worker_state] slot, or
    [None] if that slot holds nothing. *)
let[@inline] find_current_worker_ () : worker_state option =
  k_worker_state |> TLS.get_opt
(** Signal the pool's condition once if some worker is flagged as waiting.
    The flag is read without the mutex; the mutex is only taken when a signal
    is actually sent. *)
let[@inline] try_wake_someone_ (self : state) : unit =
  match self.n_waiting_nonzero with
  | false -> ()
  | true ->
    Mutex.lock self.mutex;
    Condition.signal self.cond;
    Mutex.unlock self.mutex
(** Enqueue [task] on the pool.

    [w] is the worker state of the calling thread, if any. When it belongs to
    this very pool the task goes to that worker's own queue, overflowing into
    the main queue if the push is refused. Otherwise the task goes to the main
    queue.
    @raise Shutdown if the caller is not a worker of this pool and the pool is
    no longer active. *)
let schedule_task_ (self : state) ~w (task : task_full) : unit =
  (* append to the shared queue under the lock, waking up a waiter if any *)
  let push_to_main_queue () : unit =
    Mutex.lock self.mutex;
    Queue.push task self.main_q;
    if self.n_waiting_nonzero then Condition.signal self.cond;
    Mutex.unlock self.mutex
  in
  match w with
  | Some worker when Id.equal self.id_ worker.pool_id_ ->
    (* The caller is one of our own workers. The identifier check matters:
       the caller might be a worker of pool A asking to schedule on pool B. *)
    if WSQ.push worker.q task then
      try_wake_someone_ self
    else
      (* local queue refused the task: overflow into the main queue *)
      push_to_main_queue ()
  | _ ->
    (* outside caller: scheduling is only permitted while the pool is active *)
    if not (A.get self.active) then raise Shutdown;
    push_to_main_queue ()
(** Run [task] right away on the calling thread. Must be called from the
    worker [w].

    The task's local storage is installed (on [w] and in thread-local storage)
    for the duration of the run, the [around_task] hooks are called before and
    after it, and any exception raised by the task is passed to [on_exn]
    together with its backtrace. *)
let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
    : unit =
  let (AT_pair (before_task, after_task)) = self.around_task in
  let storage =
    match task with
    | T_start { ls; _ } | T_resume { ls; _ } -> ls
  in
  (* install the task's storage, then run the "before" hook *)
  w.cur_ls <- Some storage;
  TLS.set k_cur_storage storage;
  let ctx = before_task runner in

  (* storage of the task running on the current worker *)
  let[@inline] on_suspend () : _ ref =
    match find_current_worker_ () with
    | Some { cur_ls = Some cur; _ } -> cur
    | _ -> assert false
  in

  (* the current worker, but only if it belongs to this pool *)
  let worker_of_this_pool () =
    match find_current_worker_ () with
    | Some cur when Id.equal cur.pool_id_ self.id_ -> Some cur
    | _ -> None
  in

  (* schedule a fresh task that starts with a copy of [ls] *)
  let run_another_task ls (task' : task) =
    let w = worker_of_this_pool () in
    let ls' = Task_local_storage.Direct.copy ls in
    schedule_task_ self ~w @@ T_start { ls = ls'; f = task' }
  in

  (* schedule the application of [k] to [x], with storage [ls] *)
  let resume ls k x =
    let w = worker_of_this_pool () in
    schedule_task_ self ~w @@ T_resume { ls; k; x }
  in

  (* run the task now, catching errors *)
  begin
    try
      match task with
      | T_start { f; _ } ->
        (* run [f ()] and handle [suspend] in it *)
        Suspend_.with_suspend
          (WSH { on_suspend; run = run_another_task; resume })
          f
      | T_resume { k; x; _ } ->
        (* this is already in an effect handler *)
        k x
    with e ->
      let bt = Printexc.get_raw_backtrace () in
      self.on_exn e bt
  end;

  (* "after" hook, then uninstall the task's storage *)
  after_task runner ctx;
  w.cur_ls <- None;
  TLS.set k_cur_storage _dummy_ls
(** Schedule [f] as a new task with local storage [ls], going through the
    calling thread's worker state when there is one. *)
let run_async_ (self : state) ~ls (f : task) : unit =
  let w = find_current_worker_ () in
  T_start { ls; f } |> schedule_task_ self ~w
(* TODO: function to schedule many tasks from the outside.
- build a queue
- lock
- queue transfer
- wakeup all (broadcast)
- unlock *)
(** Block on the pool's condition variable, keeping [n_waiting] and
    [n_waiting_nonzero] up to date around the wait.
    Precondition: the caller holds the mutex. *)
let[@inline] wait_ (self : state) : unit =
  let before = self.n_waiting in
  self.n_waiting <- before + 1;
  (* we are the first waiter: raise the flag *)
  if before = 0 then self.n_waiting_nonzero <- true;
  Condition.wait self.cond self.mutex;
  let remaining = self.n_waiting - 1 in
  self.n_waiting <- remaining;
  (* we were the last waiter: lower the flag *)
  if remaining = 0 then self.n_waiting_nonzero <- false
exception Got_task of task_full

(** Make one pass over all the other workers, starting at a random position,
    and return the first task that could be stolen from one of their queues.
    Returns [None] if the pass found nothing. *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option
    =
  let n_workers = Array.length self.workers in
  let start = Random.State.int w.rng n_workers in
  (* look at the worker [offset] positions after [start]; never at ourselves *)
  let probe offset : unit =
    let victim =
      Array.unsafe_get self.workers ((offset + start) mod n_workers)
    in
    if victim != w then (
      match WSQ.steal victim.q with
      | None -> ()
      | Some t -> raise_notrace (Got_task t)
    )
  in
  match
    for offset = 0 to n_workers - 1 do
      probe offset
    done
  with
  | () -> None
  | exception Got_task t -> Some t
(** Pop and run tasks from the worker's own queue, until the queue is empty
    or the pool is no longer active. *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
  let rec drain () : unit =
    if A.get self.active then (
      match WSQ.pop w.q with
      | None -> ()
      | Some task ->
        try_wake_someone_ self;
        run_task_now_ self ~runner ~w task;
        drain ()
    )
  in
  drain ()
(** Main loop for a worker thread. Registers [runner] and [w] in thread-local
storage, then alternates between running the tasks of [w]'s own queue,
stealing from other workers, and popping from (or waiting on) the main
queue. The loop stops when the main queue is found empty while the pool
is no longer active. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
TLS.set k_worker_state w;
(* [main]: run everything in our own queue, then look for work elsewhere *)
let rec main () : unit =
worker_run_self_tasks_ self ~runner w;
try_steal ()
(* [run_task]: run one task obtained from elsewhere, then go back to [main] *)
and run_task task : unit =
run_task_now_ self ~runner ~w task;
main ()
(* [try_steal]: one stealing pass over the other workers; fall back to the
main queue if nothing could be stolen *)
and try_steal () =
match try_to_steal_work_once_ self w with
| Some task -> run_task task
| None -> wait ()
(* [wait]: take a task from the main queue, blocking on the condition if it
is empty. The mutex is taken here and released on every branch before
running a task or returning. *)
and wait () =
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task task
| exception Queue.Empty ->
(* wait here *)
if A.get self.active then (
wait_ self;
(* see if a task became available *)
let task =
try Some (Queue.pop self.main_q) with Queue.Empty -> None
in
Mutex.unlock self.mutex;
(* woken up without a task in the main queue: try stealing again *)
match task with
| Some t -> run_task t
| None -> try_steal ()
) else
(* do nothing more: no task in main queue, and we are shutting
down so no new task should arrive.
The exception is if another task is creating subtasks
that overflow into the main queue, but we can ignore that at
the price of slightly decreased performance for the last few
tasks *)
Mutex.unlock self.mutex
in
(* handle domain-local await *)
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main
(** Default thread init/exit hook: ignores its arguments and does nothing. *)
let default_thread_init_exit_ = fun ~dom_id:_ ~t_id:_ () -> ()
(** Mark the pool as inactive and wake up every waiting worker. Only the
    call that actually flips [active] from [true] to [false] does anything;
    later calls are no-ops. If [wait] is [true], that call also joins all
    the worker threads before returning. *)
let shutdown_ ~wait (self : state) : unit =
  let was_active = A.exchange self.active false in
  if was_active then (
    Mutex.lock self.mutex;
    Condition.broadcast self.cond;
    Mutex.unlock self.mutex;
    if wait then
      Array.iter (fun { thread; _ } -> Thread.join thread) self.workers
  )
(* Shape shared by functions that take the pool-creation options: all the
optional arguments below, followed by a result of type ['a]. ['b] is the type
of the value passed from the first [around_task] hook to the second. *)
type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?num_threads:int ->
?name:string ->
'a
(** Arguments used in {!create}. See {!create} for explanations. *)
(** Placeholder task that does nothing, with dummy local storage. It is the
    [~dummy] element given to every worker's queue in [create]. *)
let dummy_task_ : task_full =
  let noop : task = ignore in
  T_start { ls = Task_local_storage.dummy; f = noop }
(** Create a new work-stealing pool and start its worker threads.
@param on_init_thread called at the start of each worker thread, before its
main loop, with the index of the domain it was placed on and its thread id.
Does nothing by default.
@param on_exit_thread called in each worker thread after its main loop has
returned. Does nothing by default.
@param on_exn called with the exception and backtrace when a task raises.
Ignores the exception by default.
@param around_task pair of hooks called before and after each task; the
result of the first is passed to the second. Both do nothing by default.
@param num_threads forwarded to [Util_pool_.num_threads], whose result is
the number of workers.
@param name if provided, worker [i] names its thread [<name>.worker.<i>]
through [Tracing_.set_thread_name]. *)
let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ?num_threads ?name () : t =
let pool_id_ = Id.create () in
(* wrapper *)
let around_task =
match around_task with
| Some (f, g) -> AT_pair (f, g)
| None -> AT_pair (ignore, fun _ _ -> ())
in
let num_domains = Domain_pool_.max_number_of_domains () in
let num_threads = Util_pool_.num_threads ?num_threads () in
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in
(* One state per worker. [thread] temporarily holds the current thread as a
placeholder; the real handles are filled in at the end of this function.
Each worker's random state is seeded with its index. *)
let workers : worker_state array =
let dummy = Thread.self () in
Array.init num_threads (fun i ->
{
pool_id_;
thread = dummy;
q = WSQ.create ~dummy:dummy_task_ ();
rng = Random.State.make [| i |];
cur_ls = None;
})
in
(* NOTE(review): [n_waiting_nonzero] starts at [true] although [n_waiting]
starts at 0. This only causes extra signals until the first call to
[wait_] brings the flag back in sync; presumably intentional -- confirm. *)
let pool =
{
id_ = pool_id_;
active = A.make true;
workers;
main_q = Queue.create ();
n_waiting = 0;
n_waiting_nonzero = true;
mutex = Mutex.create ();
cond = Condition.create ();
around_task;
on_exn;
}
in
(* the runner handed to users: every operation delegates to [pool] *)
let runner =
Runner.For_runner_implementors.create
~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
~run_async:(fun ~ls f -> run_async_ pool ~ls f)
~size:(fun () -> size_ pool)
~num_tasks:(fun () -> num_tasks_ pool)
()
in
(* temporary queue used to obtain thread handles from domains
on which the thread are started. *)
let receive_threads = Bb_queue.create () in
(* start the thread with index [i] *)
let start_thread_with_idx i =
let w = pool.workers.(i) in
let dom_idx = (offset + i) mod num_domains in
(* function run in the thread itself *)
let main_thread_fun () : unit =
let thread = Thread.self () in
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.set k_cur_storage _dummy_ls;
(* set thread name *)
Option.iter
(fun name ->
Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i))
name;
let run () = worker_thread_ pool ~runner w in
(* now run the main loop *)
Fun.protect run ~finally:(fun () ->
(* on termination, decrease refcount of underlying domain *)
Domain_pool_.decr_on dom_idx);
(* NOTE(review): this call is sequenced after [Fun.protect], so it is
skipped if the worker loop raises; only the domain refcount decrement
is guaranteed. Confirm this is intended. *)
on_exit_thread ~dom_id:dom_idx ~t_id ()
in
(* function called in domain with index [i], to
create the thread and push it into [receive_threads] *)
let create_thread_in_domain () =
let thread = Thread.create main_thread_fun () in
(* send the thread from the domain back to us *)
Bb_queue.push receive_threads (i, thread)
in
Domain_pool_.run_on dom_idx create_thread_in_domain
in
(* start all threads, placing them on the domains
according to their index and [offset] in a round-robin fashion. *)
for i = 0 to num_threads - 1 do
start_thread_with_idx i
done;
(* receive the newly created threads back from domains *)
(* handles may arrive in any order: each one carries its worker index *)
for _j = 1 to num_threads do
let i, th = Bb_queue.pop receive_threads in
let worker_state = pool.workers.(i) in
worker_state.thread <- th
done;
runner
(** [with_ () f] creates a pool with the given options, calls [f] on it, and
    shuts the pool down when [f] returns or raises. The result of [f] is
    returned. The optional arguments are passed on to [create]. *)
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
    ?name () f =
  let runner =
    create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
      ?name ()
  in
  let finally () = shutdown runner in
  let@ () = Fun.protect ~finally in
  f runner