mirror of https://github.com/c-cube/moonpool.git
refactor core: use picos for schedulers; add Worker_loop_

We factor most of the thread workers' logic into `Worker_loop_`, which is now shared between Ws_pool and Fifo_pool.
parent c73395635b
commit 9fb23bed4c
15 changed files with 497 additions and 589 deletions
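The new module in a nutshell: each pool now describes itself to the shared loop through a `Worker_loop_.ops` record, and every worker thread just runs `Worker_loop_.worker_loop`. The following is a rough sketch of that wiring, not code from this commit: the `toy_state` type and all bindings are invented for illustration, written as if inside moonpool's core where `Worker_loop_`, `Bb_queue`, `Runner` and `TLS` are in scope.

(* Hedged sketch of the new pattern: a toy pool state, an [ops] record
   describing it, and worker threads running the shared loop. *)
module WL = Worker_loop_

type toy_state = {
  q: WL.task_full Bb_queue.t;  (* single shared FIFO, as in Fifo_pool *)
  as_runner: Runner.t lazy_t;
}

let k_state : toy_state TLS.t = TLS.create ()

let toy_ops : toy_state WL.ops =
  {
    WL.schedule = (fun st task -> Bb_queue.push st.q task);
    get_next_task =
      (fun st ->
        try Bb_queue.pop st.q with Bb_queue.Closed -> raise WL.No_more_tasks);
    get_thread_state = (fun () -> TLS.get_exn k_state);
    around_task = (fun _ -> WL.AT_pair (ignore, fun _ _ -> ()));
    on_exn = (fun _ _ebt -> () (* a real pool would report the Exn_bt.t *));
    runner = (fun st -> Lazy.force st.as_runner);
    before_start = (fun st -> TLS.set k_state st);
    cleanup = ignore;
  }

(* each worker thread of the pool then simply runs: *)
let start_worker (st : toy_state) : Thread.t =
  Thread.create (WL.worker_loop ~ops:toy_ops) st

Both pools in this commit follow exactly this shape, differing only in how `schedule` and `get_next_task` are implemented (a blocking queue for Fifo_pool, work-stealing deques plus a main queue for Ws_pool).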
@@ -1,87 +1,39 @@
open Types_
include Runner
+module WL = Worker_loop_

+type fiber = Picos.Fiber.t
+type task_full = WL.task_full

let ( let@ ) = ( @@ )

-type task_full =
-  | T_start of {
-      ls: Task_local_storage.t;
-      f: task;
-    }
-  | T_resume : {
-      ls: Task_local_storage.t;
-      k: 'a -> unit;
-      x: 'a;
-    }
-      -> task_full

type state = {
  threads: Thread.t array;
  q: task_full Bb_queue.t;  (** Queue for tasks. *)
+  around_task: WL.around_task;
+  as_runner: t lazy_t;
+  (* init options *)
+  name: string option;
+  on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
+  on_exit_thread: dom_id:int -> t_id:int -> unit -> unit;
+  on_exn: exn -> Printexc.raw_backtrace -> unit;
}
(** internal state *)

+type worker_state = {
+  idx: int;
+  dom_idx: int;
+  st: state;
+  mutable current: fiber;
+}

let[@inline] size_ (self : state) = Array.length self.threads
let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q

-(** Run [task] as is, on the pool. *)
-let schedule_ (self : state) (task : task_full) : unit =
-  try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
-
-type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
-type worker_state = { mutable cur_ls: Task_local_storage.t option }

let k_worker_state : worker_state TLS.t = TLS.create ()

-let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
-  let w = { cur_ls = None } in
-  TLS.set k_worker_state w;
-  TLS.set Runner.For_runner_implementors.k_cur_runner runner;
-
-  let (AT_pair (before_task, after_task)) = around_task in
-
-  let on_suspend () =
-    match TLS.get_opt k_worker_state with
-    | Some { cur_ls = Some ls; _ } -> ls
-    | _ -> assert false
-  in
-  let run_another_task ls task' = schedule_ self @@ T_start { f = task'; ls } in
-  let resume ls k res = schedule_ self @@ T_resume { ls; k; x = res } in
-
-  let run_task (task : task_full) : unit =
-    let ls =
-      match task with
-      | T_start { ls; _ } | T_resume { ls; _ } -> ls
-    in
-    w.cur_ls <- Some ls;
-    TLS.set k_cur_storage ls;
-    let _ctx = before_task runner in
-
-    (* run the task now, catching errors, handling effects *)
-    (try
-       match task with
-       | T_start { f = task; _ } ->
-         (* run [task()] and handle [suspend] in it *)
-         Suspend_.with_suspend
-           (WSH { on_suspend; run = run_another_task; resume })
-           task
-       | T_resume { k; x; _ } ->
-         (* this is already in an effect handler *)
-         k x
-     with e ->
-       let bt = Printexc.get_raw_backtrace () in
-       on_exn e bt);
-    after_task runner _ctx;
-    w.cur_ls <- None;
-    TLS.set k_cur_storage _dummy_ls
-  in
-
-  let continue = ref true in
-  while !continue do
-    match Bb_queue.pop self.q with
-    | task -> run_task task
-    | exception Bb_queue.Closed -> continue := false
-  done
(*
  get_thread_state = TLS.get_opt k_worker_state
*)

let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
@@ -98,10 +50,14 @@ type ('a, 'b) create_args =
  ?name:string ->
  'a

-let default_around_task_ : around_task = AT_pair (ignore, fun _ _ -> ())
+let default_around_task_ : WL.around_task = AT_pair (ignore, fun _ _ -> ())

+(** Run [task] as is, on the pool. *)
+let schedule_ (self : state) (task : task_full) : unit =
+  try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown

let runner_of_state (pool : state) : t =
-  let run_async ~ls f = schedule_ pool @@ T_start { f; ls } in
+  let run_async ~fiber f = schedule_ pool @@ T_start { f; fiber } in
  Runner.For_runner_implementors.create
    ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
    ~run_async
@@ -109,13 +65,59 @@ let runner_of_state (pool : state) : t =
    ~num_tasks:(fun () -> num_tasks_ pool)
    ()

+(** Run [task] as is, on the pool. *)
+let schedule_w (self : worker_state) (task : task_full) : unit =
+  try Bb_queue.push self.st.q task with Bb_queue.Closed -> raise Shutdown

+let get_next_task (self : worker_state) =
+  try Bb_queue.pop self.st.q with Bb_queue.Closed -> raise WL.No_more_tasks

+let get_thread_state () =
+  match TLS.get_exn k_worker_state with
+  | st -> st
+  | exception TLS.Not_set ->
+    failwith "Moonpool: get_thread_state called from outside a runner."

+let before_start (self : worker_state) =
+  let t_id = Thread.id @@ Thread.self () in
+  self.st.on_init_thread ~dom_id:self.dom_idx ~t_id ();
+
+  (* set thread name *)
+  Option.iter
+    (fun name ->
+      Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx))
+    self.st.name

+let cleanup (self : worker_state) : unit =
+  (* on termination, decrease refcount of underlying domain *)
+  Domain_pool_.decr_on self.dom_idx;
+  let t_id = Thread.id @@ Thread.self () in
+  self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id ()

+let worker_ops : worker_state WL.ops =
+  let runner (st : worker_state) = Lazy.force st.st.as_runner in
+  let around_task st = st.st.around_task in
+  let on_exn (st : worker_state) (ebt : Exn_bt.t) =
+    st.st.on_exn ebt.exn ebt.bt
+  in
+  {
+    WL.schedule = schedule_w;
+    runner;
+    get_next_task;
+    get_thread_state;
+    around_task;
+    on_exn;
+    before_start;
+    cleanup;
+  }

let create ?(on_init_thread = default_thread_init_exit_)
    ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
    ?around_task ?num_threads ?name () : t =
  (* wrapper *)
  let around_task =
    match around_task with
-    | Some (f, g) -> AT_pair (f, g)
+    | Some (f, g) -> WL.AT_pair (f, g)
    | None -> default_around_task_
  in
@@ -127,9 +129,18 @@ let create ?(on_init_thread = default_thread_init_exit_)
  (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
  let offset = Random.int num_domains in

-  let pool =
+  let rec pool =
    let dummy_thread = Thread.self () in
-    { threads = Array.make num_threads dummy_thread; q = Bb_queue.create () }
+    {
+      threads = Array.make num_threads dummy_thread;
+      q = Bb_queue.create ();
+      around_task;
+      as_runner = lazy (runner_of_state pool);
+      name;
+      on_init_thread;
+      on_exit_thread;
+      on_exn;
+    }
  in

-  let runner = runner_of_state pool in
@@ -142,31 +153,11 @@ let create ?(on_init_thread = default_thread_init_exit_)
  let start_thread_with_idx i =
    let dom_idx = (offset + i) mod num_domains in

-    (* function run in the thread itself *)
-    let main_thread_fun () : unit =
-      let thread = Thread.self () in
-      let t_id = Thread.id thread in
-      on_init_thread ~dom_id:dom_idx ~t_id ();
-
-      (* set thread name *)
-      Option.iter
-        (fun name ->
-          Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i))
-        name;
-
-      let run () = worker_thread_ pool runner ~on_exn ~around_task in
-
-      (* now run the main loop *)
-      Fun.protect run ~finally:(fun () ->
-          (* on termination, decrease refcount of underlying domain *)
-          Domain_pool_.decr_on dom_idx);
-      on_exit_thread ~dom_id:dom_idx ~t_id ()
-    in
-
    (* function called in domain with index [i], to
       create the thread and push it into [receive_threads] *)
    let create_thread_in_domain () =
-      let thread = Thread.create main_thread_fun () in
+      let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in
+      let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
      (* send the thread from the domain back to us *)
      Bb_queue.push receive_threads (i, thread)
    in
@@ -196,13 +187,3 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
  in
  let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
  f pool

-module Private_ = struct
-  type nonrec state = state
-
-  let create_state ~threads () : state = { threads; q = Bb_queue.create () }
-  let runner_of_state = runner_of_state
-
-  let run_thread (st : state) (self : t) ~on_exn : unit =
-    worker_thread_ st self ~on_exn ~around_task:default_around_task_
-end
@@ -44,17 +44,3 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args
    When [f pool] returns or fails, [pool] is shutdown and its resources
    are released.
    Most parameters are the same as in {!create}. *)

-(**/**)
-
-module Private_ : sig
-  type state
-
-  val create_state : threads:Thread.t array -> unit -> state
-  val runner_of_state : state -> Runner.t
-
-  val run_thread :
-    state -> t -> on_exn:(exn -> Printexc.raw_backtrace -> unit) -> unit
-end
-
-(**/**)
@@ -36,7 +36,7 @@ module Ws_pool = Ws_pool

module Private = struct
  module Ws_deque_ = Ws_deque_
-  module Suspend_ = Suspend_
+  module Worker_loop_ = Worker_loop_
  module Domain_ = Domain_
  module Tracing_ = Tracing_
@@ -33,13 +33,13 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t
    to run the thread. This ensures that we don't always pick the same domain
    to run all the various threads needed in an application (timers, event loops, etc.) *)

-val run_async : ?ls:Task_local_storage.t -> Runner.t -> (unit -> unit) -> unit
+val run_async : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> unit) -> unit
(** [run_async runner task] schedules the task to run
    on the given runner. This means [task()] will be executed
    at some point in the future, possibly in another thread.
    @since 0.5 *)

-val run_wait_block : ?ls:Task_local_storage.t -> Runner.t -> (unit -> 'a) -> 'a
+val run_wait_block : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> 'a) -> 'a
(** [run_wait_block runner f] schedules [f] for later execution
    on the runner, like {!run_async}.
    It then blocks the current thread until [f()] is done executing,
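As a hypothetical usage sketch of the two functions documented above (the pool-creation arguments are the ones shown later in this diff; `Ws_pool` is re-exported by `Moonpool` as seen in the previous hunk):

(* Illustrative only: schedule a fire-and-forget task, then block on a result. *)
let () =
  let pool = Moonpool.Ws_pool.create ~num_threads:4 () in
  Moonpool.run_async pool (fun () -> print_endline "hello from a worker");
  (* schedule [f] and block the calling thread until it finishes *)
  let sum = Moonpool.run_wait_block pool (fun () -> 1 + 2 + 3) in
  assert (sum = 6);
  Moonpool.Ws_pool.shutdown pool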
@@ -212,16 +212,10 @@ module Private : sig
  module Ws_deque_ = Ws_deque_
  (** A deque for work stealing, fixed size. *)

-  (** {2 Suspensions} *)
-
-  module Suspend_ = Suspend_
-  [@@alert
-    unstable "this module is an implementation detail of moonpool for now"]
-  (** Suspensions.
-
-      This is only going to work on OCaml 5.x.
-
-      {b NOTE}: this is not stable for now. *)
+  module Worker_loop_ = Worker_loop_
+  (** Worker loop. This is useful to implement custom runners, it
+      should run on each thread of the runner.
+      @since NEXT_RELEASE *)

  module Domain_ = Domain_
  (** Utils for domains *)
@@ -1,9 +1,10 @@
open Types_

+type fiber = Picos.Fiber.t
type task = unit -> unit

type t = runner = {
-  run_async: ls:local_storage -> task -> unit;
+  run_async: fiber:fiber -> task -> unit;
  shutdown: wait:bool -> unit -> unit;
  size: unit -> int;
  num_tasks: unit -> int;
@@ -11,8 +12,15 @@ type t = runner = {

exception Shutdown

-let[@inline] run_async ?(ls = create_local_storage ()) (self : t) f : unit =
-  self.run_async ~ls f
+let[@inline] run_async ?fiber (self : t) f : unit =
+  let fiber =
+    match fiber with
+    | Some f -> f
+    | None ->
+      let comp = Picos.Computation.create () in
+      Picos.Fiber.create ~forbid:false comp
+  in
+  self.run_async ~fiber f

let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true ()
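In other words, when no `?fiber` is passed, each task now gets a fresh Picos fiber backed by its own computation. A caller can make that explicit; this hedged sketch simply spells out the default using the same Picos calls as the hunk above:

(* Hypothetical illustration: passing a fiber explicitly is equivalent to
   the default that [run_async] now builds internally. *)
let run_with_explicit_fiber (runner : Runner.t) (task : unit -> unit) : unit =
  let comp = Picos.Computation.create () in
  (* [~forbid:false]: cancelation propagation is not forbidden in this fiber *)
  let fiber = Picos.Fiber.create ~forbid:false comp in
  Runner.run_async ~fiber runner task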
@@ -22,9 +30,9 @@ let[@inline] shutdown_without_waiting (self : t) : unit =
let[@inline] num_tasks (self : t) : int = self.num_tasks ()
let[@inline] size (self : t) : int = self.size ()

-let run_wait_block ?ls self (f : unit -> 'a) : 'a =
+let run_wait_block ?fiber self (f : unit -> 'a) : 'a =
  let q = Bb_queue.create () in
-  run_async ?ls self (fun () ->
+  run_async ?fiber self (fun () ->
      try
        let x = f () in
        Bb_queue.push q (Ok x)
@@ -47,9 +55,9 @@ let dummy : t =
    ~size:(fun () -> 0)
    ~num_tasks:(fun () -> 0)
    ~shutdown:(fun ~wait:_ () -> ())
-    ~run_async:(fun ~ls:_ _ ->
+    ~run_async:(fun ~fiber:_ _ ->
      failwith "Runner.dummy: cannot actually run tasks")
    ()

let get_current_runner = get_current_runner
-let get_current_storage = get_current_storage
+let get_current_fiber = get_current_fiber
@@ -5,6 +5,7 @@
    @since 0.3
*)

+type fiber = Picos.Fiber.t
type task = unit -> unit

type t
@@ -33,14 +34,14 @@ val shutdown_without_waiting : t -> unit

exception Shutdown

-val run_async : ?ls:Task_local_storage.t -> t -> task -> unit
+val run_async : ?fiber:fiber -> t -> task -> unit
(** [run_async pool f] schedules [f] for later execution on the runner
    in one of the threads. [f()] will run on one of the runner's
    worker threads/domains.
    @param ls if provided, run the task with this initial local storage
    @raise Shutdown if the runner was shut down before [run_async] was called. *)

-val run_wait_block : ?ls:Task_local_storage.t -> t -> (unit -> 'a) -> 'a
+val run_wait_block : ?fiber:fiber -> t -> (unit -> 'a) -> 'a
(** [run_wait_block pool f] schedules [f] for later execution
    on the pool, like {!run_async}.
    It then blocks the current thread until [f()] is done executing,
@@ -65,7 +66,7 @@ module For_runner_implementors : sig
    size:(unit -> int) ->
    num_tasks:(unit -> int) ->
    shutdown:(wait:bool -> unit -> unit) ->
-    run_async:(ls:Task_local_storage.t -> task -> unit) ->
+    run_async:(fiber:fiber -> task -> unit) ->
    unit ->
    t
  (** Create a new runner.
@@ -85,6 +86,6 @@ val get_current_runner : unit -> t option
    happens on a thread that belongs in a runner.
    @since 0.5 *)

-val get_current_storage : unit -> Task_local_storage.t option
+val get_current_fiber : unit -> fiber option
(** [get_current_storage runner] gets the local storage
    for the currently running task. *)
@@ -1,70 +0,0 @@ (file deleted)
type suspension = unit Exn_bt.result -> unit
type task = unit -> unit

type suspension_handler = {
  handle:
    run:(task -> unit) ->
    resume:(suspension -> unit Exn_bt.result -> unit) ->
    suspension ->
    unit;
}
[@@unboxed]

type with_suspend_handler =
  | WSH : {
      on_suspend: unit -> 'state;
          (** on_suspend called when [f()] suspends itself. *)
      run: 'state -> task -> unit;  (** run used to schedule new tasks *)
      resume: 'state -> suspension -> unit Exn_bt.result -> unit;
          (** resume run the suspension. Must be called exactly once. *)
    }
      -> with_suspend_handler

[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]

module A = Atomic_

type _ Effect.t +=
  | Suspend : suspension_handler -> unit Effect.t
  | Yield : unit Effect.t

let[@inline] yield () = Effect.perform Yield
let[@inline] suspend h = Effect.perform (Suspend h)

let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit =
  let module E = Effect.Deep in
  (* effect handler *)
  let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
    function
    | Suspend h ->
      (* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
      Some
        (fun k ->
          let state = on_suspend () in
          let k' : suspension = function
            | Ok () -> E.continue k ()
            | Error ebt -> Exn_bt.discontinue k ebt
          in
          h.handle ~run:(run state) ~resume:(resume state) k')
    | Yield ->
      (* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
      Some
        (fun k ->
          let state = on_suspend () in
          let k' : suspension = function
            | Ok () -> E.continue k ()
            | Error ebt -> Exn_bt.discontinue k ebt
          in
          resume state k' @@ Ok ())
    | _ -> None
  in
  E.try_with f () { E.effc }

[@@@ocaml.alert "+unstable"]
[@@@else_]

let[@inline] with_suspend (WSH _) f = f ()

[@@@endif]
@@ -1,86 +0,0 @@ (file deleted)
(** (Private) suspending tasks using Effects.

    This module is an implementation detail of Moonpool and should
    not be used outside of it, except by experts to implement {!Runner}. *)

type suspension = unit Exn_bt.result -> unit
(** A suspended computation *)

type task = unit -> unit

type suspension_handler = {
  handle:
    run:(task -> unit) ->
    resume:(suspension -> unit Exn_bt.result -> unit) ->
    suspension ->
    unit;
}
[@@unboxed]
(** The handler that knows what to do with the suspended computation.

    The handler is given a few things:

    - the suspended computation (which can be resumed with a result
      eventually);
    - a [run] function that can be used to start tasks to perform some
      computation.
    - a [resume] function to resume the suspended computation. This
      must be called exactly once, in all situations.

    This means that a fork-join primitive, for example, can use a single call
    to {!suspend} to:
    - suspend the caller until the fork-join is done
    - use [run] to start all the tasks. Typically [run] is called multiple times,
      which is where the "fork" part comes from. Each call to [run] potentially
      runs in parallel with the other calls. The calls must coordinate so
      that, once they are all done, the suspended caller is resumed with the
      aggregated result of the computation.
    - use [resume] exactly once
*)

[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]

type _ Effect.t +=
  | Suspend : suspension_handler -> unit Effect.t
      (** The effect used to suspend the current thread and pass it, suspended,
          to the handler. The handler will ensure that the suspension is resumed later
          once some computation has been done. *)
  | Yield : unit Effect.t
      (** The effect used to interrupt the current computation and immediately re-schedule
          it on the same runner. *)

[@@@ocaml.alert "+unstable"]

val yield : unit -> unit
(** Interrupt current computation, and re-schedule it at the end of the
    runner's job queue. *)

val suspend : suspension_handler -> unit
(** [suspend h] jumps back to the nearest {!with_suspend}
    and calls [h.handle] with the current continuation [k]
    and a task runner function. *)

[@@@endif]

type with_suspend_handler =
  | WSH : {
      on_suspend: unit -> 'state;
          (** on_suspend called when [f()] suspends itself. *)
      run: 'state -> task -> unit;  (** run used to schedule new tasks *)
      resume: 'state -> suspension -> unit Exn_bt.result -> unit;
          (** resume run the suspension. Must be called exactly once. *)
    }
      -> with_suspend_handler

val with_suspend : with_suspend_handler -> (unit -> unit) -> unit
(** [with_suspend wsh f]
    runs [f()] in an environment where [suspend] will work.

    If [f()] suspends with suspension handler [h],
    this calls [wsh.on_suspend()] to capture the current state [st].
    Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension.
    The suspension should always be passed exactly once to
    [resume]. [run] should be used to start other tasks. *)
@@ -1,4 +1,5 @@
open Types_
+(*
module A = Atomic_

type 'a key = 'a ls_key

@@ -79,3 +80,4 @@ let with_value key x f =
  Fun.protect ~finally:(fun () -> set key old) f

let get_current = get_current_storage
+*)
@@ -8,6 +8,7 @@
    @since 0.6
*)

+(*
type t = Types_.local_storage
(** Underlying storage for a task. This is mutable and
    not thread-safe. *)

@@ -65,3 +66,4 @@ module Direct : sig
  val create : unit -> t
  val copy : t -> t
end
+*)
@@ -1,36 +1,37 @@
module TLS = Thread_local_storage
module Domain_pool_ = Moonpool_dpool

+(* TODO: replace with Picos.Fiber.FLS *)
type ls_value = ..

(** Key for task local storage *)
module type LS_KEY = sig
  type t
  type ls_value += V of t

  val offset : int
  (** Unique offset *)

  val init : unit -> t
end

type 'a ls_key = (module LS_KEY with type t = 'a)
(** A LS key (task local storage) *)

type task = unit -> unit
type local_storage = ls_value array ref
+type fiber = Picos.Fiber.t

type runner = {
-  run_async: ls:local_storage -> task -> unit;
+  run_async: fiber:fiber -> task -> unit;
  shutdown: wait:bool -> unit -> unit;
  size: unit -> int;
  num_tasks: unit -> int;
}

let k_cur_runner : runner TLS.t = TLS.create ()
let k_cur_storage : local_storage TLS.t = TLS.create ()

let _dummy_ls : local_storage = ref [||]
[@@alert todo "remove me asap, done via picos now"]

+let k_cur_fiber : fiber TLS.t = TLS.create ()
+[@@alert todo "remove me asap, done via picos now"]

+let _dummy_computation : Picos.Computation.packed =
+  let c = Picos.Computation.create () in
+  Picos.Computation.cancel c
+    { exn = Failure "dummy fiber"; bt = Printexc.get_callstack 0 };
+  Picos.Computation.Packed c

+let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation

let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage
let[@inline] create_local_storage () = ref [||]

+let[@inline] get_current_fiber () : fiber option =
+  match TLS.get_exn k_cur_fiber with
+  | f when f != _dummy_fiber -> Some f
+  | _ -> None

+let[@inline] get_current_fiber_exn () : fiber =
+  match TLS.get_exn k_cur_fiber with
+  | f when f != _dummy_fiber -> f
+  | _ -> failwith "Moonpool: get_current_fiber was called outside of a fiber."
src/core/worker_loop_.ml (new file, 153 lines)
@@ -0,0 +1,153 @@
open Types_

type fiber = Picos.Fiber.t

type task_full =
  | T_start of {
      fiber: fiber;
      f: unit -> unit;
    }
  | T_resume : {
      fiber: fiber;
      k: unit -> unit;
    }
      -> task_full

type around_task =
  | AT_pair : (Runner.t -> 'a) * (Runner.t -> 'a -> unit) -> around_task

exception No_more_tasks

type 'st ops = {
  schedule: 'st -> task_full -> unit;
  get_next_task: 'st -> task_full;  (** @raise No_more_tasks *)
  get_thread_state: unit -> 'st;
      (** Access current thread's worker state from any worker *)
  around_task: 'st -> around_task;
  on_exn: 'st -> Exn_bt.t -> unit;
  runner: 'st -> Runner.t;
  before_start: 'st -> unit;
  cleanup: 'st -> unit;
}

(** A dummy task. *)
let _dummy_task : task_full = T_start { f = ignore; fiber = _dummy_fiber }

[@@@ifge 5.0]

let[@inline] discontinue k exn =
  let bt = Printexc.get_raw_backtrace () in
  Effect.Deep.discontinue_with_backtrace k exn bt

let with_handler (type st arg) ~(ops : st ops) (self : st) :
    (unit -> unit) -> unit =
  let current =
    Some
      (fun k ->
        match get_current_fiber_exn () with
        | fiber -> Effect.Deep.continue k fiber
        | exception exn -> discontinue k exn)
  and yield =
    Some
      (fun k ->
        let fiber = get_current_fiber_exn () in
        match
          let k () = Effect.Deep.continue k () in
          ops.schedule self @@ T_resume { fiber; k }
        with
        | () -> ()
        | exception exn -> discontinue k exn)
  and reschedule trigger fiber k : unit =
    ignore (Picos.Fiber.unsuspend fiber trigger : bool);
    let k () = Picos.Fiber.resume fiber k in
    let task = T_resume { fiber; k } in
    ops.schedule self task
  in
  let effc (type a) :
      a Effect.t -> ((a, _) Effect.Deep.continuation -> _) option = function
    | Picos.Fiber.Current -> current
    | Picos.Fiber.Yield -> yield
    | Picos.Fiber.Spawn r ->
      Some
        (fun k ->
          match
            let f () = r.main r.fiber in
            let task = T_start { fiber = r.fiber; f } in
            ops.schedule self task
          with
          | unit -> Effect.Deep.continue k unit
          | exception exn -> discontinue k exn)
    | Picos.Trigger.Await trigger ->
      Some
        (fun k ->
          let fiber = get_current_fiber_exn () in
          (* when triggers is signaled, reschedule task *)
          if not (Picos.Fiber.try_suspend fiber trigger fiber k reschedule) then
            (* trigger was already signaled, run task now *)
            Picos.Fiber.resume fiber k)
    | Picos.Computation.Cancel_after _r ->
      Some
        (fun k ->
          (* not implemented *)
          let exn = Failure "Moonpool: cancel_after is not implemented" in
          discontinue k exn)
    | _ -> None
  in
  let handler = Effect.Deep.{ retc = Fun.id; exnc = raise; effc } in
  fun f -> Effect.Deep.match_with f () handler

[@@@else_]

let with_handler ~ops:_ self f = f ()

[@@@endif]

let worker_loop (type st) ~(ops : st ops) (self : st) : unit =
  let cur_fiber : fiber ref = ref _dummy_fiber in
  let runner = ops.runner self in
  TLS.set Runner.For_runner_implementors.k_cur_runner runner;

  let (AT_pair (before_task, after_task)) = ops.around_task self in

  let run_task (task : task_full) : unit =
    let fiber =
      match task with
      | T_start { fiber; _ } | T_resume { fiber; _ } -> fiber
    in

    cur_fiber := fiber;
    TLS.set k_cur_fiber fiber;
    let _ctx = before_task runner in

    (* run the task now, catching errors, handling effects *)
    assert (task != _dummy_task);
    (try
       match task with
       | T_start { fiber = _; f } -> with_handler ~ops self f
       | T_resume { fiber = _; k } ->
         (* this is already in an effect handler *)
         k ()
     with e ->
       let ebt = Exn_bt.get e in
       ops.on_exn self ebt);

    after_task runner _ctx;

    cur_fiber := _dummy_fiber;
    TLS.set k_cur_fiber _dummy_fiber
  in

  ops.before_start self;

  let continue = ref true in
  try
    while !continue do
      match ops.get_next_task self with
      | task -> run_task task
      | exception No_more_tasks -> continue := false
    done;
    ops.cleanup self
  with exn ->
    let bt = Printexc.get_raw_backtrace () in
    ops.cleanup self;
    Printexc.raise_with_backtrace exn bt
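For orientation, here is a hypothetical user-level sketch of how the `Picos.Trigger.Await` branch above is reached; `pool` stands for any runner whose workers run this loop, and the exact payload of `await`'s `Some` case depends on the picos version in use:

(* Illustrative only: a task awaits a trigger, which suspends its fiber via
   the effect handler above; a later [signal] makes [reschedule] push a
   [T_resume] task back onto the pool. *)
let demo (pool : Runner.t) : unit =
  let trigger = Picos.Trigger.create () in
  Runner.run_async pool (fun () ->
      (* performs [Picos.Trigger.Await], handled by [with_handler] *)
      match Picos.Trigger.await trigger with
      | None -> print_endline "resumed after signal"
      | Some _ebt -> print_endline "canceled");
  Runner.run_async pool (fun () -> Picos.Trigger.signal trigger)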
@@ -1,6 +1,7 @@
open Types_
-module WSQ = Ws_deque_
module A = Atomic_
+module WSQ = Ws_deque_
+module WL = Worker_loop_
include Runner

let ( let@ ) = ( @@ )
@@ -13,46 +14,39 @@ module Id = struct
  let equal : t -> t -> bool = ( == )
end

-type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
-
-type task_full =
-  | T_start of {
-      ls: Task_local_storage.t;
-      f: task;
-    }
-  | T_resume : {
-      ls: Task_local_storage.t;
-      k: 'a -> unit;
-      x: 'a;
-    }
-      -> task_full
-
-type worker_state = {
-  pool_id_: Id.t;  (** Unique per pool *)
-  mutable thread: Thread.t;
-  q: task_full WSQ.t;  (** Work stealing queue *)
-  mutable cur_ls: Task_local_storage.t option;  (** Task storage *)
-  rng: Random.State.t;
-}
-(** State for a given worker. Only this worker is
-    allowed to push into the queue, but other workers
-    can come and steal from it if they're idle. *)

type state = {
  id_: Id.t;
      (** Unique per pool. Used to make sure tasks stay within the same pool. *)
  active: bool A.t;  (** Becomes [false] when the pool is shutdown. *)
-  workers: worker_state array;  (** Fixed set of workers. *)
-  main_q: task_full Queue.t;
+  mutable workers: worker_state array;  (** Fixed set of workers. *)
+  main_q: WL.task_full Queue.t;
      (** Main queue for tasks coming from the outside *)
  mutable n_waiting: int; (* protected by mutex *)
  mutable n_waiting_nonzero: bool;  (** [n_waiting > 0] *)
  mutex: Mutex.t;
  cond: Condition.t;
+  as_runner: t lazy_t;
  (* init options *)
+  around_task: WL.around_task;
+  name: string option;
+  on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
+  on_exit_thread: dom_id:int -> t_id:int -> unit -> unit;
  on_exn: exn -> Printexc.raw_backtrace -> unit;
-  around_task: around_task;
}
(** internal state *)

+and worker_state = {
+  mutable thread: Thread.t;
+  idx: int;
+  dom_id: int;
+  st: state;
+  q: WL.task_full WSQ.t;  (** Work stealing queue *)
+  rng: Random.State.t;
+}
+(** State for a given worker. Only this worker is
+    allowed to push into the queue, but other workers
+    can come and steal from it if they're idle. *)

let[@inline] size_ (self : state) = Array.length self.workers

let num_tasks_ (self : state) : int =
@@ -66,9 +60,15 @@ let num_tasks_ (self : state) : int =
    sub-tasks. *)
let k_worker_state : worker_state TLS.t = TLS.create ()

-let[@inline] find_current_worker_ () : worker_state option =
+let[@inline] get_current_worker_ () : worker_state option =
  TLS.get_opt k_worker_state

+let[@inline] get_current_worker_exn () : worker_state =
+  match TLS.get_exn k_worker_state with
+  | w -> w
+  | exception TLS.Not_set ->
+    failwith "Moonpool: get_current_runner was called from outside a pool."

(** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit =
  if self.n_waiting_nonzero then (
@@ -77,194 +77,144 @@ let[@inline] try_wake_someone_ (self : state) : unit =
    Mutex.unlock self.mutex
  )

-(** Run [task] as is, on the pool. *)
-let schedule_task_ (self : state) ~w (task : task_full) : unit =
-  (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
-  match w with
-  | Some w when Id.equal self.id_ w.pool_id_ ->
-    (* we're on this same pool, schedule in the worker's state. Otherwise
-       we might also be on pool A but asking to schedule on pool B,
-       so we have to check that identifiers match. *)
-    let pushed = WSQ.push w.q task in
-    if pushed then
-      try_wake_someone_ self
-    else (
-      (* overflow into main queue *)
-      Mutex.lock self.mutex;
-      Queue.push task self.main_q;
-      if self.n_waiting_nonzero then Condition.signal self.cond;
-      Mutex.unlock self.mutex
+let schedule_on_w (self : worker_state) task : unit =
+  (* we're on this same pool, schedule in the worker's state. Otherwise
+     we might also be on pool A but asking to schedule on pool B,
+     so we have to check that identifiers match. *)
+  let pushed = WSQ.push self.q task in
+  if pushed then
+    try_wake_someone_ self.st
+  else (
+    (* overflow into main queue *)
+    Mutex.lock self.st.mutex;
+    Queue.push task self.st.main_q;
+    if self.st.n_waiting_nonzero then Condition.signal self.st.cond;
+    Mutex.unlock self.st.mutex
  )

+let schedule_on_main (self : state) task : unit =
+  if A.get self.active then (
+    (* push into the main queue *)
+    Mutex.lock self.mutex;
+    Queue.push task self.main_q;
+    if self.n_waiting_nonzero then Condition.signal self.cond;
+    Mutex.unlock self.mutex
+  ) else
+    (* notify the caller that scheduling tasks is no
+       longer permitted *)
+    raise Shutdown

+let schedule_from_w (self : worker_state) (task : WL.task_full) : unit =
+  match get_current_worker_ () with
+  | Some w when Id.equal self.st.id_ w.st.id_ ->
+    (* use worker from the same pool *)
+    schedule_on_w w task
+  | _ -> schedule_on_main self.st task

+exception Got_task of WL.task_full

+(** Try to steal a task.
+    @raise Got_task if it finds one. *)
+let try_to_steal_work_once_ (self : worker_state) : unit =
+  let init = Random.State.int self.rng (Array.length self.st.workers) in
+  for i = 0 to Array.length self.st.workers - 1 do
+    let w' =
+      Array.unsafe_get self.st.workers
+        ((i + init) mod Array.length self.st.workers)
+    in
+    if self != w' then (
+      match WSQ.steal w'.q with
+      | Some t -> raise_notrace (Got_task t)
+      | None -> ()
+    )
-  | _ ->
-    if A.get self.active then (
-      (* push into the main queue *)
-      Mutex.lock self.mutex;
-      Queue.push task self.main_q;
-      if self.n_waiting_nonzero then Condition.signal self.cond;
-      Mutex.unlock self.mutex
-    ) else
-      (* notify the caller that scheduling tasks is no
-         longer permitted *)
-      raise Shutdown

-(** Run this task, now. Must be called from a worker. *)
-let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
-    : unit =
-  (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
-  let (AT_pair (before_task, after_task)) = self.around_task in
-
-  let ls =
-    match task with
-    | T_start { ls; _ } | T_resume { ls; _ } -> ls
-  in
-
-  w.cur_ls <- Some ls;
-  TLS.set k_cur_storage ls;
-  let _ctx = before_task runner in
-
-  let[@inline] on_suspend () : _ ref =
-    match find_current_worker_ () with
-    | Some { cur_ls = Some w; _ } -> w
-    | _ -> assert false
-  in
-
-  let run_another_task ls (task' : task) =
-    let w =
-      match find_current_worker_ () with
-      | Some w when Id.equal w.pool_id_ self.id_ -> Some w
-      | _ -> None
-    in
-    let ls' = Task_local_storage.Direct.copy ls in
-    schedule_task_ self ~w @@ T_start { ls = ls'; f = task' }
-  in
-
-  let resume ls k x =
-    let w =
-      match find_current_worker_ () with
-      | Some w when Id.equal w.pool_id_ self.id_ -> Some w
-      | _ -> None
-    in
-    schedule_task_ self ~w @@ T_resume { ls; k; x }
-  in
-
-  (* run the task now, catching errors *)
-  (try
-     match task with
-     | T_start { f = task; _ } ->
-       (* run [task()] and handle [suspend] in it *)
-       Suspend_.with_suspend
-         (WSH { on_suspend; run = run_another_task; resume })
-         task
-     | T_resume { k; x; _ } ->
-       (* this is already in an effect handler *)
-       k x
-   with e ->
-     let bt = Printexc.get_raw_backtrace () in
-     self.on_exn e bt);
-
-  after_task runner _ctx;
-  w.cur_ls <- None;
-  TLS.set k_cur_storage _dummy_ls

-let run_async_ (self : state) ~ls (f : task) : unit =
-  let w = find_current_worker_ () in
-  schedule_task_ self ~w @@ T_start { f; ls }

(* TODO: function to schedule many tasks from the outside.
   - build a queue
   - lock
   - queue transfer
   - wakeup all (broadcast)
   - unlock *)
+  done

(** Wait on condition. Precondition: we hold the mutex. *)
-let[@inline] wait_ (self : state) : unit =
+let[@inline] wait_for_condition_ (self : state) : unit =
  self.n_waiting <- self.n_waiting + 1;
  if self.n_waiting = 1 then self.n_waiting_nonzero <- true;
  Condition.wait self.cond self.mutex;
  self.n_waiting <- self.n_waiting - 1;
  if self.n_waiting = 0 then self.n_waiting_nonzero <- false

-exception Got_task of task_full
+let rec get_next_task (self : worker_state) : WL.task_full =
+  if not (A.get self.st.active) then raise WL.No_more_tasks;
+  match WSQ.pop_exn self.q with
+  | task ->
+    try_wake_someone_ self.st;
+    task
+  | exception WSQ.Empty -> try_steal_from_other_workers_ self

-(** Try to steal a task *)
-let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option
-    =
-  let init = Random.State.int w.rng (Array.length self.workers) in
+and try_steal_from_other_workers_ (self : worker_state) =
+  match try_to_steal_work_once_ self with
+  | exception Got_task task -> task
+  | () -> wait_on_worker self

-  try
-    for i = 0 to Array.length self.workers - 1 do
-      let w' =
-        Array.unsafe_get self.workers ((i + init) mod Array.length self.workers)
-      in
+and wait_on_worker (self : worker_state) : WL.task_full =
+  Mutex.lock self.st.mutex;
+  match Queue.pop self.st.main_q with
+  | task ->
+    Mutex.unlock self.st.mutex;
+    task
+  | exception Queue.Empty ->
+    (* wait here *)
+    if A.get self.st.active then (
+      wait_for_condition_ self.st;

-      if w != w' then (
-        match WSQ.steal w'.q with
-        | Some t -> raise_notrace (Got_task t)
-        | None -> ()
-      )
-    done;
-    None
-  with Got_task t -> Some t
+      (* see if a task became available *)
+      match Queue.pop self.st.main_q with
+      | task ->
+        Mutex.unlock self.st.mutex;
+        task
+      | exception Queue.Empty -> try_steal_from_other_workers_ self
+    ) else (
+      (* do nothing more: no task in main queue, and we are shutting
+         down so no new task should arrive.
+         The exception is if another task is creating subtasks
+         that overflow into the main queue, but we can ignore that at
+         the price of slightly decreased performance for the last few
+         tasks *)
+      Mutex.unlock self.st.mutex;
+      raise WL.No_more_tasks
+    )

-(** Worker runs tasks from its queue until none remains *)
-let worker_run_self_tasks_ (self : state) ~runner w : unit =
-  let continue = ref true in
-  while !continue && A.get self.active do
-    match WSQ.pop w.q with
-    | Some task ->
-      try_wake_someone_ self;
-      run_task_now_ self ~runner ~w task
-    | None -> continue := false
-  done
+let before_start (self : worker_state) : unit =
+  let t_id = Thread.id @@ Thread.self () in
+  self.st.on_init_thread ~dom_id:self.dom_id ~t_id ();
+  TLS.set k_cur_fiber _dummy_fiber;
+  TLS.set Runner.For_runner_implementors.k_cur_runner
+    (Lazy.force self.st.as_runner);
+  TLS.set k_worker_state self;

-(** Main loop for a worker thread. *)
-let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
-  TLS.set Runner.For_runner_implementors.k_cur_runner runner;
-  TLS.set k_worker_state w;
+  (* set thread name *)
+  Option.iter
+    (fun name ->
+      Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx))
+    self.st.name

-  let rec main () : unit =
-    worker_run_self_tasks_ self ~runner w;
-    try_steal ()
-  and run_task task : unit =
-    run_task_now_ self ~runner ~w task;
-    main ()
-  and try_steal () =
-    match try_to_steal_work_once_ self w with
-    | Some task -> run_task task
-    | None -> wait ()
-  and wait () =
-    Mutex.lock self.mutex;
-    match Queue.pop self.main_q with
-    | task ->
-      Mutex.unlock self.mutex;
-      run_task task
-    | exception Queue.Empty ->
-      (* wait here *)
-      if A.get self.active then (
-        wait_ self;
+let cleanup (self : worker_state) : unit =
+  (* on termination, decrease refcount of underlying domain *)
+  Domain_pool_.decr_on self.dom_id;
+  let t_id = Thread.id @@ Thread.self () in
+  self.st.on_exit_thread ~dom_id:self.dom_id ~t_id ()

-        (* see if a task became available *)
-        let task =
-          try Some (Queue.pop self.main_q) with Queue.Empty -> None
-        in
-        Mutex.unlock self.mutex;
-
-        match task with
-        | Some t -> run_task t
-        | None -> try_steal ()
-      ) else
-        (* do nothing more: no task in main queue, and we are shutting
-           down so no new task should arrive.
-           The exception is if another task is creating subtasks
-           that overflow into the main queue, but we can ignore that at
-           the price of slightly decreased performance for the last few
-           tasks *)
-        Mutex.unlock self.mutex

+let worker_ops : worker_state WL.ops =
+  let runner (st : worker_state) = Lazy.force st.st.as_runner in
+  let around_task st = st.st.around_task in
+  let on_exn (st : worker_state) (ebt : Exn_bt.t) =
+    st.st.on_exn ebt.exn ebt.bt
+  in

-  (* handle domain-local await *)
-  main ()
+  {
+    WL.schedule = schedule_from_w;
+    runner;
+    get_next_task;
+    get_thread_state = get_current_worker_exn;
+    around_task;
+    on_exn;
+    before_start;
+    cleanup;
+  }

let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
@@ -276,6 +226,14 @@ let shutdown_ ~wait (self : state) : unit =
    if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
  )

+let as_runner_ (self : state) : t =
+  Runner.For_runner_implementors.create
+    ~shutdown:(fun ~wait () -> shutdown_ self ~wait)
+    ~run_async:(fun ~fiber f -> schedule_on_main self @@ T_start { fiber; f })
+    ~size:(fun () -> size_ self)
+    ~num_tasks:(fun () -> num_tasks_ self)
+    ()

type ('a, 'b) create_args =
  ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
@@ -286,9 +244,6 @@ type ('a, 'b) create_args =
  'a
(** Arguments used in {!create}. See {!create} for explanations. *)

-let dummy_task_ : task_full =
-  T_start { f = ignore; ls = Task_local_storage.dummy }

let create ?(on_init_thread = default_thread_init_exit_)
    ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
    ?around_task ?num_threads ?name () : t =
@@ -296,8 +251,8 @@ let create ?(on_init_thread = default_thread_init_exit_)
  (* wrapper *)
  let around_task =
    match around_task with
-    | Some (f, g) -> AT_pair (f, g)
-    | None -> AT_pair (ignore, fun _ _ -> ())
+    | Some (f, g) -> WL.AT_pair (f, g)
+    | None -> WL.AT_pair (ignore, fun _ _ -> ())
  in

  let num_domains = Domain_pool_.max_number_of_domains () in
@@ -306,23 +261,11 @@ let create ?(on_init_thread = default_thread_init_exit_)
  (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
  let offset = Random.int num_domains in

-  let workers : worker_state array =
-    let dummy = Thread.self () in
-    Array.init num_threads (fun i ->
-        {
-          pool_id_;
-          thread = dummy;
-          q = WSQ.create ~dummy:dummy_task_ ();
-          rng = Random.State.make [| i |];
-          cur_ls = None;
-        })
-  in

-  let pool =
+  let rec pool =
    {
      id_ = pool_id_;
      active = A.make true;
-      workers;
+      workers = [||];
      main_q = Queue.create ();
      n_waiting = 0;
      n_waiting_nonzero = true;
@@ -330,65 +273,47 @@ let create ?(on_init_thread = default_thread_init_exit_)
      cond = Condition.create ();
      around_task;
      on_exn;
+      on_init_thread;
+      on_exit_thread;
+      name;
+      as_runner = lazy (as_runner_ pool);
    }
  in

-  let runner =
-    Runner.For_runner_implementors.create
-      ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
-      ~run_async:(fun ~ls f -> run_async_ pool ~ls f)
-      ~size:(fun () -> size_ pool)
-      ~num_tasks:(fun () -> num_tasks_ pool)
-      ()
-  in

  (* temporary queue used to obtain thread handles from domains
     on which the thread are started. *)
  let receive_threads = Bb_queue.create () in

  (* start the thread with index [i] *)
-  let start_thread_with_idx i =
-    let w = pool.workers.(i) in
-    let dom_idx = (offset + i) mod num_domains in
-
-    (* function run in the thread itself *)
-    let main_thread_fun () : unit =
-      let thread = Thread.self () in
-      let t_id = Thread.id thread in
-      on_init_thread ~dom_id:dom_idx ~t_id ();
-      TLS.set k_cur_storage _dummy_ls;
-
-      (* set thread name *)
-      Option.iter
-        (fun name ->
-          Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i))
-        name;
-
-      let run () = worker_thread_ pool ~runner w in
-
-      (* now run the main loop *)
-      Fun.protect run ~finally:(fun () ->
-          (* on termination, decrease refcount of underlying domain *)
-          Domain_pool_.decr_on dom_idx);
-      on_exit_thread ~dom_id:dom_idx ~t_id ()
+  let start_thread_with_idx idx =
+    let dom_id = (offset + idx) mod num_domains in
+    let st =
+      {
+        st = pool;
+        thread = (* dummy *) Thread.self ();
+        q = WSQ.create ~dummy:WL._dummy_task ();
+        rng = Random.State.make [| idx |];
+        dom_id;
+        idx;
+      }
    in

    (* function called in domain with index [i], to
       create the thread and push it into [receive_threads] *)
    let create_thread_in_domain () =
-      let thread = Thread.create main_thread_fun () in
+      let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
      (* send the thread from the domain back to us *)
-      Bb_queue.push receive_threads (i, thread)
+      Bb_queue.push receive_threads (idx, thread)
    in

-    Domain_pool_.run_on dom_idx create_thread_in_domain
+    Domain_pool_.run_on dom_id create_thread_in_domain;
+
+    st
  in

-  (* start all threads, placing them on the domains
+  (* start all worker threads, placing them on the domains
     according to their index and [offset] in a round-robin fashion. *)
-  for i = 0 to num_threads - 1 do
-    start_thread_with_idx i
-  done;
+  pool.workers <- Array.init num_threads start_thread_with_idx;

  (* receive the newly created threads back from domains *)
  for _j = 1 to num_threads do
@@ -397,7 +322,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
    worker_state.thread <- th
  done;

-  runner
+  Lazy.force pool.as_runner

let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
    ?name () f =
@@ -72,7 +72,9 @@ let push (self : 'a t) (x : 'a) : bool =
    true
  with Full -> false

-let pop (self : 'a t) : 'a option =
+exception Empty
+
+let pop_exn (self : 'a t) : 'a =
  let b = A.get self.bottom in
  let b = b - 1 in
  A.set self.bottom b;
@@ -84,11 +86,11 @@ let pop (self : 'a t) : 'a option =
  if size < 0 then (
    (* reset to basic empty state *)
    A.set self.bottom t;
-    None
+    raise_notrace Empty
  ) else if size > 0 then (
    (* can pop without modifying [top] *)
    let x = CA.get self.arr b in
-    Some x
+    x
  ) else (
    assert (size = 0);
    (* there was exactly one slot, so we might be racing against stealers
@@ -96,13 +98,18 @@ let pop (self : 'a t) : 'a option =
    if A.compare_and_set self.top t (t + 1) then (
      let x = CA.get self.arr b in
      A.set self.bottom (t + 1);
-      Some x
+      x
    ) else (
      A.set self.bottom (t + 1);
-      None
+      raise_notrace Empty
    )
  )

+let[@inline] pop self : _ option =
+  match pop_exn self with
+  | exception Empty -> None
+  | t -> Some t

let steal (self : 'a t) : 'a option =
  (* read [top], but do not update [top_cached]
     as we're in another thread *)
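The option-returning `pop` becomes a thin wrapper over the new exception-based `pop_exn`, presumably so the worker's hot path can avoid allocating `Some` for every task it pops. A hedged usage sketch, with `dq` a deque of thunks owned by the current thread (both `drain` functions are invented for illustration):

(* exception-based hot path *)
let drain dq =
  try
    while true do
      let task = Ws_deque_.pop_exn dq in
      task ()
    done
  with Ws_deque_.Empty -> ()

(* or, with the option-returning wrapper *)
let drain_opt dq =
  let continue = ref true in
  while !continue do
    match Ws_deque_.pop dq with
    | Some task -> task ()
    | None -> continue := false
  done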
@@ -21,6 +21,10 @@ val pop : 'a t -> 'a option
(** Pop value from the bottom of deque.
    This must be called only by the owner thread. *)

+exception Empty
+
+val pop_exn : 'a t -> 'a

val steal : 'a t -> 'a option
(** Try to steal from the top of deque. This is thread-safe. *)