refactor core: use picos for schedulers; add Worker_loop_

We factor most of the thread workers' logic into `Worker_loop_`,
which is now shared between Ws_pool and Fifo_pool.
Simon Cruanes 2024-08-28 12:39:15 -04:00
parent c73395635b
commit 9fb23bed4c
15 changed files with 497 additions and 589 deletions
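
In outline, each pool now packages its scheduling hooks into a `Worker_loop_.ops` record and runs one shared loop on every worker thread. Roughly (signatures condensed from the new src/core/worker_loop_.ml shown further down; the comments are ours):

(* condensed sketch of the new shared worker-loop interface *)
type 'st ops = {
  schedule: 'st -> task_full -> unit;
  get_next_task: 'st -> task_full;   (* raises No_more_tasks when the pool is done *)
  get_thread_state: unit -> 'st;
  around_task: 'st -> around_task;
  on_exn: 'st -> Exn_bt.t -> unit;
  runner: 'st -> Runner.t;
  before_start: 'st -> unit;
  cleanup: 'st -> unit;
}

val worker_loop : ops:'st ops -> 'st -> unit
(* each worker thread runs: Thread.create (Worker_loop_.worker_loop ~ops) state *)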


@ -1,87 +1,39 @@
open Types_
include Runner
module WL = Worker_loop_
type fiber = Picos.Fiber.t
type task_full = WL.task_full
let ( let@ ) = ( @@ )
type task_full =
| T_start of {
ls: Task_local_storage.t;
f: task;
}
| T_resume : {
ls: Task_local_storage.t;
k: 'a -> unit;
x: 'a;
}
-> task_full
type state = {
threads: Thread.t array;
q: task_full Bb_queue.t; (** Queue for tasks. *)
around_task: WL.around_task;
as_runner: t lazy_t;
(* init options *)
name: string option;
on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
on_exit_thread: dom_id:int -> t_id:int -> unit -> unit;
on_exn: exn -> Printexc.raw_backtrace -> unit;
}
(** internal state *)
type worker_state = {
idx: int;
dom_idx: int;
st: state;
mutable current: fiber;
}
let[@inline] size_ (self : state) = Array.length self.threads
let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q
(** Run [task] as is, on the pool. *)
let schedule_ (self : state) (task : task_full) : unit =
try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type worker_state = { mutable cur_ls: Task_local_storage.t option }
let k_worker_state : worker_state TLS.t = TLS.create ()
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let w = { cur_ls = None } in
TLS.set k_worker_state w;
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
let (AT_pair (before_task, after_task)) = around_task in
let on_suspend () =
match TLS.get_opt k_worker_state with
| Some { cur_ls = Some ls; _ } -> ls
| _ -> assert false
in
let run_another_task ls task' = schedule_ self @@ T_start { f = task'; ls } in
let resume ls k res = schedule_ self @@ T_resume { ls; k; x = res } in
let run_task (task : task_full) : unit =
let ls =
match task with
| T_start { ls; _ } | T_resume { ls; _ } -> ls
in
w.cur_ls <- Some ls;
TLS.set k_cur_storage ls;
let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *)
(try
match task with
| T_start { f = task; _ } ->
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend
(WSH { on_suspend; run = run_another_task; resume })
task
| T_resume { k; x; _ } ->
(* this is already in an effect handler *)
k x
with e ->
let bt = Printexc.get_raw_backtrace () in
on_exn e bt);
after_task runner _ctx;
w.cur_ls <- None;
TLS.set k_cur_storage _dummy_ls
in
let continue = ref true in
while !continue do
match Bb_queue.pop self.q with
| task -> run_task task
| exception Bb_queue.Closed -> continue := false
done
(*
get_thread_state = TLS.get_opt k_worker_state
*)
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
@ -98,10 +50,14 @@ type ('a, 'b) create_args =
?name:string ->
'a
let default_around_task_ : around_task = AT_pair (ignore, fun _ _ -> ())
let default_around_task_ : WL.around_task = AT_pair (ignore, fun _ _ -> ())
(** Run [task] as is, on the pool. *)
let schedule_ (self : state) (task : task_full) : unit =
try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
let runner_of_state (pool : state) : t =
let run_async ~ls f = schedule_ pool @@ T_start { f; ls } in
let run_async ~fiber f = schedule_ pool @@ T_start { f; fiber } in
Runner.For_runner_implementors.create
~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
~run_async
@ -109,13 +65,59 @@ let runner_of_state (pool : state) : t =
~num_tasks:(fun () -> num_tasks_ pool)
()
(** Run [task] as is, on the pool. *)
let schedule_w (self : worker_state) (task : task_full) : unit =
try Bb_queue.push self.st.q task with Bb_queue.Closed -> raise Shutdown
let get_next_task (self : worker_state) =
try Bb_queue.pop self.st.q with Bb_queue.Closed -> raise WL.No_more_tasks
let get_thread_state () =
match TLS.get_exn k_worker_state with
| st -> st
| exception TLS.Not_set ->
failwith "Moonpool: get_thread_state called from outside a runner."
let before_start (self : worker_state) =
let t_id = Thread.id @@ Thread.self () in
self.st.on_init_thread ~dom_id:self.dom_idx ~t_id ();
(* set thread name *)
Option.iter
(fun name ->
Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx))
self.st.name
let cleanup (self : worker_state) : unit =
(* on termination, decrease refcount of underlying domain *)
Domain_pool_.decr_on self.dom_idx;
let t_id = Thread.id @@ Thread.self () in
self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id ()
let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = Lazy.force st.st.as_runner in
let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn ebt.exn ebt.bt
in
{
WL.schedule = schedule_w;
runner;
get_next_task;
get_thread_state;
around_task;
on_exn;
before_start;
cleanup;
}
let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ?num_threads ?name () : t =
(* wrapper *)
let around_task =
match around_task with
| Some (f, g) -> AT_pair (f, g)
| Some (f, g) -> WL.AT_pair (f, g)
| None -> default_around_task_
in
@ -127,9 +129,18 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in
let pool =
let rec pool =
let dummy_thread = Thread.self () in
{ threads = Array.make num_threads dummy_thread; q = Bb_queue.create () }
{
threads = Array.make num_threads dummy_thread;
q = Bb_queue.create ();
around_task;
as_runner = lazy (runner_of_state pool);
name;
on_init_thread;
on_exit_thread;
on_exn;
}
in
let runner = runner_of_state pool in
@ -142,31 +153,11 @@ let create ?(on_init_thread = default_thread_init_exit_)
let start_thread_with_idx i =
let dom_idx = (offset + i) mod num_domains in
(* function run in the thread itself *)
let main_thread_fun () : unit =
let thread = Thread.self () in
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
(* set thread name *)
Option.iter
(fun name ->
Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i))
name;
let run () = worker_thread_ pool runner ~on_exn ~around_task in
(* now run the main loop *)
Fun.protect run ~finally:(fun () ->
(* on termination, decrease refcount of underlying domain *)
Domain_pool_.decr_on dom_idx);
on_exit_thread ~dom_id:dom_idx ~t_id ()
in
(* function called in domain with index [i], to
create the thread and push it into [receive_threads] *)
let create_thread_in_domain () =
let thread = Thread.create main_thread_fun () in
let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in
let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
(* send the thread from the domain back to us *)
Bb_queue.push receive_threads (i, thread)
in
@ -196,13 +187,3 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
in
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
f pool
module Private_ = struct
type nonrec state = state
let create_state ~threads () : state = { threads; q = Bb_queue.create () }
let runner_of_state = runner_of_state
let run_thread (st : state) (self : t) ~on_exn : unit =
worker_thread_ st self ~on_exn ~around_task:default_around_task_
end


@ -44,17 +44,3 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args
When [f pool] returns or fails, [pool] is shutdown and its resources
are released.
Most parameters are the same as in {!create}. *)
(**/**)
module Private_ : sig
type state
val create_state : threads:Thread.t array -> unit -> state
val runner_of_state : state -> Runner.t
val run_thread :
state -> t -> on_exn:(exn -> Printexc.raw_backtrace -> unit) -> unit
end
(**/**)


@ -36,7 +36,7 @@ module Ws_pool = Ws_pool
module Private = struct
module Ws_deque_ = Ws_deque_
module Suspend_ = Suspend_
module Worker_loop_ = Worker_loop_
module Domain_ = Domain_
module Tracing_ = Tracing_


@ -33,13 +33,13 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t
to run the thread. This ensures that we don't always pick the same domain
to run all the various threads needed in an application (timers, event loops, etc.) *)
val run_async : ?ls:Task_local_storage.t -> Runner.t -> (unit -> unit) -> unit
val run_async : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> unit) -> unit
(** [run_async runner task] schedules the task to run
on the given runner. This means [task()] will be executed
at some point in the future, possibly in another thread.
@since 0.5 *)
val run_wait_block : ?ls:Task_local_storage.t -> Runner.t -> (unit -> 'a) -> 'a
val run_wait_block : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> 'a) -> 'a
(** [run_wait_block runner f] schedules [f] for later execution
on the runner, like {!run_async}.
It then blocks the current thread until [f()] is done executing,
@ -212,16 +212,10 @@ module Private : sig
module Ws_deque_ = Ws_deque_
(** A deque for work stealing, fixed size. *)
(** {2 Suspensions} *)
module Suspend_ = Suspend_
[@@alert
unstable "this module is an implementation detail of moonpool for now"]
(** Suspensions.
This is only going to work on OCaml 5.x.
{b NOTE}: this is not stable for now. *)
module Worker_loop_ = Worker_loop_
(** Worker loop. This is useful to implement custom runners, it
should run on each thread of the runner.
@since NEXT_RELEASE *)
module Domain_ = Domain_
(** Utils for domains *)


@ -1,9 +1,10 @@
open Types_
type fiber = Picos.Fiber.t
type task = unit -> unit
type t = runner = {
run_async: ls:local_storage -> task -> unit;
run_async: fiber:fiber -> task -> unit;
shutdown: wait:bool -> unit -> unit;
size: unit -> int;
num_tasks: unit -> int;
@ -11,8 +12,15 @@ type t = runner = {
exception Shutdown
let[@inline] run_async ?(ls = create_local_storage ()) (self : t) f : unit =
self.run_async ~ls f
let[@inline] run_async ?fiber (self : t) f : unit =
let fiber =
match fiber with
| Some f -> f
| None ->
let comp = Picos.Computation.create () in
Picos.Fiber.create ~forbid:false comp
in
self.run_async ~fiber f
let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true ()
@ -22,9 +30,9 @@ let[@inline] shutdown_without_waiting (self : t) : unit =
let[@inline] num_tasks (self : t) : int = self.num_tasks ()
let[@inline] size (self : t) : int = self.size ()
let run_wait_block ?ls self (f : unit -> 'a) : 'a =
let run_wait_block ?fiber self (f : unit -> 'a) : 'a =
let q = Bb_queue.create () in
run_async ?ls self (fun () ->
run_async ?fiber self (fun () ->
try
let x = f () in
Bb_queue.push q (Ok x)
@ -47,9 +55,9 @@ let dummy : t =
~size:(fun () -> 0)
~num_tasks:(fun () -> 0)
~shutdown:(fun ~wait:_ () -> ())
~run_async:(fun ~ls:_ _ ->
~run_async:(fun ~fiber:_ _ ->
failwith "Runner.dummy: cannot actually run tasks")
()
let get_current_runner = get_current_runner
let get_current_storage = get_current_storage
let get_current_fiber = get_current_fiber


@ -5,6 +5,7 @@
@since 0.3
*)
type fiber = Picos.Fiber.t
type task = unit -> unit
type t
@ -33,14 +34,14 @@ val shutdown_without_waiting : t -> unit
exception Shutdown
val run_async : ?ls:Task_local_storage.t -> t -> task -> unit
val run_async : ?fiber:fiber -> t -> task -> unit
(** [run_async pool f] schedules [f] for later execution on the runner
in one of the threads. [f()] will run on one of the runner's
worker threads/domains.
@param ls if provided, run the task with this initial local storage
@raise Shutdown if the runner was shut down before [run_async] was called. *)
val run_wait_block : ?ls:Task_local_storage.t -> t -> (unit -> 'a) -> 'a
val run_wait_block : ?fiber:fiber -> t -> (unit -> 'a) -> 'a
(** [run_wait_block pool f] schedules [f] for later execution
on the pool, like {!run_async}.
It then blocks the current thread until [f()] is done executing,
@ -65,7 +66,7 @@ module For_runner_implementors : sig
size:(unit -> int) ->
num_tasks:(unit -> int) ->
shutdown:(wait:bool -> unit -> unit) ->
run_async:(ls:Task_local_storage.t -> task -> unit) ->
run_async:(fiber:fiber -> task -> unit) ->
unit ->
t
(** Create a new runner.
@ -85,6 +86,6 @@ val get_current_runner : unit -> t option
happens on a thread that belongs in a runner.
@since 0.5 *)
val get_current_storage : unit -> Task_local_storage.t option
val get_current_fiber : unit -> fiber option
(** [get_current_storage runner] gets the local storage
for the currently running task. *)
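
For orientation, a usage sketch of the reworked user-facing API (our illustration, not part of the diff; `demo` and `pool` are hypothetical names, and `pool` stands for any runner created elsewhere, e.g. by Ws_pool or Fifo_pool):

(* Sketch: the [?ls] parameter is replaced by [?fiber]. If no fiber is
   passed, [run_async] creates a fresh Picos computation and fiber, as the
   runner.ml hunk above shows. *)
let demo (pool : Runner.t) =
  (* fire-and-forget: a fresh fiber is created for the task *)
  Runner.run_async pool (fun () -> print_endline "hello from a worker");
  (* or run under an explicit, caller-created Picos fiber *)
  let comp = Picos.Computation.create () in
  let fiber = Picos.Fiber.create ~forbid:false comp in
  Runner.run_async ~fiber pool (fun () -> print_endline "explicit fiber");
  (* block the calling thread until the result is available *)
  let n = Runner.run_wait_block pool (fun () -> 21 * 2) in
  assert (n = 42)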


@ -1,70 +0,0 @@
type suspension = unit Exn_bt.result -> unit
type task = unit -> unit
type suspension_handler = {
handle:
run:(task -> unit) ->
resume:(suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
}
[@@unboxed]
type with_suspend_handler =
| WSH : {
on_suspend: unit -> 'state;
(** on_suspend called when [f()] suspends itself. *)
run: 'state -> task -> unit; (** run used to schedule new tasks *)
resume: 'state -> suspension -> unit Exn_bt.result -> unit;
(** resume run the suspension. Must be called exactly once. *)
}
-> with_suspend_handler
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]
module A = Atomic_
type _ Effect.t +=
| Suspend : suspension_handler -> unit Effect.t
| Yield : unit Effect.t
let[@inline] yield () = Effect.perform Yield
let[@inline] suspend h = Effect.perform (Suspend h)
let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit =
let module E = Effect.Deep in
(* effect handler *)
let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
function
| Suspend h ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
let state = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error ebt -> Exn_bt.discontinue k ebt
in
h.handle ~run:(run state) ~resume:(resume state) k')
| Yield ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
let state = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error ebt -> Exn_bt.discontinue k ebt
in
resume state k' @@ Ok ())
| _ -> None
in
E.try_with f () { E.effc }
[@@@ocaml.alert "+unstable"]
[@@@else_]
let[@inline] with_suspend (WSH _) f = f ()
[@@@endif]


@ -1,86 +0,0 @@
(** (Private) suspending tasks using Effects.
This module is an implementation detail of Moonpool and should
not be used outside of it, except by experts to implement {!Runner}. *)
type suspension = unit Exn_bt.result -> unit
(** A suspended computation *)
type task = unit -> unit
type suspension_handler = {
handle:
run:(task -> unit) ->
resume:(suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
}
[@@unboxed]
(** The handler that knows what to do with the suspended computation.
The handler is given a few things:
- the suspended computation (which can be resumed with a result
eventually);
- a [run] function that can be used to start tasks to perform some
computation.
- a [resume] function to resume the suspended computation. This
must be called exactly once, in all situations.
This means that a fork-join primitive, for example, can use a single call
to {!suspend} to:
- suspend the caller until the fork-join is done
- use [run] to start all the tasks. Typically [run] is called multiple times,
which is where the "fork" part comes from. Each call to [run] potentially
runs in parallel with the other calls. The calls must coordinate so
that, once they are all done, the suspended caller is resumed with the
aggregated result of the computation.
- use [resume] exactly
*)
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]
type _ Effect.t +=
| Suspend : suspension_handler -> unit Effect.t
(** The effect used to suspend the current thread and pass it, suspended,
to the handler. The handler will ensure that the suspension is resumed later
once some computation has been done. *)
| Yield : unit Effect.t
(** The effect used to interrupt the current computation and immediately re-schedule
it on the same runner. *)
[@@@ocaml.alert "+unstable"]
val yield : unit -> unit
(** Interrupt current computation, and re-schedule it at the end of the
runner's job queue. *)
val suspend : suspension_handler -> unit
(** [suspend h] jumps back to the nearest {!with_suspend}
and calls [h.handle] with the current continuation [k]
and a task runner function.
*)
[@@@endif]
type with_suspend_handler =
| WSH : {
on_suspend: unit -> 'state;
(** on_suspend called when [f()] suspends itself. *)
run: 'state -> task -> unit; (** run used to schedule new tasks *)
resume: 'state -> suspension -> unit Exn_bt.result -> unit;
(** resume run the suspension. Must be called exactly once. *)
}
-> with_suspend_handler
val with_suspend : with_suspend_handler -> (unit -> unit) -> unit
(** [with_suspend wsh f]
runs [f()] in an environment where [suspend] will work.
If [f()] suspends with suspension handler [h],
this calls [wsh.on_suspend()] to capture the current state [st].
Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
*)


@ -1,4 +1,5 @@
open Types_
(*
module A = Atomic_
type 'a key = 'a ls_key
@ -79,3 +80,4 @@ let with_value key x f =
Fun.protect ~finally:(fun () -> set key old) f
let get_current = get_current_storage
*)


@ -8,6 +8,7 @@
@since 0.6
*)
(*
type t = Types_.local_storage
(** Underlying storage for a task. This is mutable and
not thread-safe. *)
@ -65,3 +66,4 @@ module Direct : sig
val create : unit -> t
val copy : t -> t
end
*)


@ -1,36 +1,37 @@
module TLS = Thread_local_storage
module Domain_pool_ = Moonpool_dpool
(* TODO: replace with Picos.Fiber.FLS *)
type ls_value = ..
(** Key for task local storage *)
module type LS_KEY = sig
type t
type ls_value += V of t
val offset : int
(** Unique offset *)
val init : unit -> t
end
type 'a ls_key = (module LS_KEY with type t = 'a)
(** A LS key (task local storage) *)
type task = unit -> unit
type local_storage = ls_value array ref
type fiber = Picos.Fiber.t
type runner = {
run_async: ls:local_storage -> task -> unit;
run_async: fiber:fiber -> task -> unit;
shutdown: wait:bool -> unit -> unit;
size: unit -> int;
num_tasks: unit -> int;
}
let k_cur_runner : runner TLS.t = TLS.create ()
let k_cur_storage : local_storage TLS.t = TLS.create ()
let _dummy_ls : local_storage = ref [||]
[@@alert todo "remove me asap, done via picos now"]
let k_cur_fiber : fiber TLS.t = TLS.create ()
[@@alert todo "remove me asap, done via picos now"]
let _dummy_computation : Picos.Computation.packed =
let c = Picos.Computation.create () in
Picos.Computation.cancel c
{ exn = Failure "dummy fiber"; bt = Printexc.get_callstack 0 };
Picos.Computation.Packed c
let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation
let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage
let[@inline] create_local_storage () = ref [||]
let[@inline] get_current_fiber () : fiber option =
match TLS.get_exn k_cur_fiber with
| f when f != _dummy_fiber -> Some f
| _ -> None
let[@inline] get_current_fiber_exn () : fiber =
match TLS.get_exn k_cur_fiber with
| f when f != _dummy_fiber -> f
| _ -> failwith "Moonpool: get_current_fiber was called outside of a fiber."
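
A brief usage note (our sketch, not part of the diff): from inside a task, the current Picos fiber is now reachable through the runner API that replaces `get_current_storage`.

(* Sketch: observing the current fiber from within a task. *)
let show_fiber (pool : Runner.t) =
  Runner.run_async pool (fun () ->
      match Runner.get_current_fiber () with
      | Some _fiber -> print_endline "running inside a Picos fiber"
      | None -> print_endline "not inside a runner")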

src/core/worker_loop_.ml (new file, 153 lines)

@ -0,0 +1,153 @@
open Types_
type fiber = Picos.Fiber.t
type task_full =
| T_start of {
fiber: fiber;
f: unit -> unit;
}
| T_resume : {
fiber: fiber;
k: unit -> unit;
}
-> task_full
type around_task =
| AT_pair : (Runner.t -> 'a) * (Runner.t -> 'a -> unit) -> around_task
exception No_more_tasks
type 'st ops = {
schedule: 'st -> task_full -> unit;
get_next_task: 'st -> task_full; (** @raise No_more_tasks *)
get_thread_state: unit -> 'st;
(** Access current thread's worker state from any worker *)
around_task: 'st -> around_task;
on_exn: 'st -> Exn_bt.t -> unit;
runner: 'st -> Runner.t;
before_start: 'st -> unit;
cleanup: 'st -> unit;
}
(** A dummy task. *)
let _dummy_task : task_full = T_start { f = ignore; fiber = _dummy_fiber }
[@@@ifge 5.0]
let[@inline] discontinue k exn =
let bt = Printexc.get_raw_backtrace () in
Effect.Deep.discontinue_with_backtrace k exn bt
let with_handler (type st arg) ~(ops : st ops) (self : st) :
(unit -> unit) -> unit =
let current =
Some
(fun k ->
match get_current_fiber_exn () with
| fiber -> Effect.Deep.continue k fiber
| exception exn -> discontinue k exn)
and yield =
Some
(fun k ->
let fiber = get_current_fiber_exn () in
match
let k () = Effect.Deep.continue k () in
ops.schedule self @@ T_resume { fiber; k }
with
| () -> ()
| exception exn -> discontinue k exn)
and reschedule trigger fiber k : unit =
ignore (Picos.Fiber.unsuspend fiber trigger : bool);
let k () = Picos.Fiber.resume fiber k in
let task = T_resume { fiber; k } in
ops.schedule self task
in
let effc (type a) :
a Effect.t -> ((a, _) Effect.Deep.continuation -> _) option = function
| Picos.Fiber.Current -> current
| Picos.Fiber.Yield -> yield
| Picos.Fiber.Spawn r ->
Some
(fun k ->
match
let f () = r.main r.fiber in
let task = T_start { fiber = r.fiber; f } in
ops.schedule self task
with
| unit -> Effect.Deep.continue k unit
| exception exn -> discontinue k exn)
| Picos.Trigger.Await trigger ->
Some
(fun k ->
let fiber = get_current_fiber_exn () in
(* when triggers is signaled, reschedule task *)
if not (Picos.Fiber.try_suspend fiber trigger fiber k reschedule) then
(* trigger was already signaled, run task now *)
Picos.Fiber.resume fiber k)
| Picos.Computation.Cancel_after _r ->
Some
(fun k ->
(* not implemented *)
let exn = Failure "Moonpool: cancel_after is not implemented" in
discontinue k exn)
| _ -> None
in
let handler = Effect.Deep.{ retc = Fun.id; exnc = raise; effc } in
fun f -> Effect.Deep.match_with f () handler
[@@@else_]
let with_handler ~ops:_ self f = f ()
[@@@endif]
let worker_loop (type st) ~(ops : st ops) (self : st) : unit =
let cur_fiber : fiber ref = ref _dummy_fiber in
let runner = ops.runner self in
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
let (AT_pair (before_task, after_task)) = ops.around_task self in
let run_task (task : task_full) : unit =
let fiber =
match task with
| T_start { fiber; _ } | T_resume { fiber; _ } -> fiber
in
cur_fiber := fiber;
TLS.set k_cur_fiber fiber;
let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *)
assert (task != _dummy_task);
(try
match task with
| T_start { fiber = _; f } -> with_handler ~ops self f
| T_resume { fiber = _; k } ->
(* this is already in an effect handler *)
k ()
with e ->
let ebt = Exn_bt.get e in
ops.on_exn self ebt);
after_task runner _ctx;
cur_fiber := _dummy_fiber;
TLS.set k_cur_fiber _dummy_fiber
in
ops.before_start self;
let continue = ref true in
try
while !continue do
match ops.get_next_task self with
| task -> run_task task
| exception No_more_tasks -> continue := false
done;
ops.cleanup self
with exn ->
let bt = Printexc.get_raw_backtrace () in
ops.cleanup self;
Printexc.raise_with_backtrace exn bt
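
To make the contract of [ops] concrete, here is a minimal single-threaded runner sketched on top of this loop (our illustration, not part of the commit). It mirrors the `worker_ops` records in Fifo_pool and Ws_pool; the names `toy`, `k_toy` and `create_toy` are hypothetical, and `Bb_queue.close` is assumed to exist alongside the `Bb_queue` operations used above.

(* Hedged sketch: the smallest runner one could build on [Worker_loop_].
   Each worker thread just calls [WL.worker_loop ~ops] on its own state,
   exactly as Fifo_pool and Ws_pool now do. *)
module WL = Worker_loop_
module TLS = Thread_local_storage

type toy = {
  q: WL.task_full Bb_queue.t; (* single FIFO of tasks *)
  as_runner: Runner.t lazy_t;
}

let k_toy : toy TLS.t = TLS.create ()

let toy_ops : toy WL.ops =
  {
    WL.schedule = (fun st task -> Bb_queue.push st.q task);
    get_next_task =
      (fun st ->
        try Bb_queue.pop st.q with Bb_queue.Closed -> raise WL.No_more_tasks);
    get_thread_state = (fun () -> TLS.get_exn k_toy);
    around_task = (fun _ -> WL.AT_pair (ignore, fun _ _ -> ()));
    on_exn =
      (fun _ (ebt : Exn_bt.t) -> Printexc.raise_with_backtrace ebt.exn ebt.bt);
    runner = (fun st -> Lazy.force st.as_runner);
    before_start = (fun st -> TLS.set k_toy st);
    cleanup = ignore;
  }

let create_toy () : Runner.t =
  let rec st =
    {
      q = Bb_queue.create ();
      as_runner =
        lazy
          (Runner.For_runner_implementors.create
             ~shutdown:(fun ~wait:_ () -> Bb_queue.close st.q (* no join in this toy *))
             ~run_async:(fun ~fiber f ->
               Bb_queue.push st.q (WL.T_start { fiber; f }))
             ~size:(fun () -> 1)
             ~num_tasks:(fun () -> Bb_queue.size st.q)
             ())
    }
  in
  (* a real pool would place this thread on a domain via Domain_pool_ *)
  ignore (Thread.create (WL.worker_loop ~ops:toy_ops) st : Thread.t);
  Lazy.force st.as_runner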


@ -1,6 +1,7 @@
open Types_
module WSQ = Ws_deque_
module A = Atomic_
module WSQ = Ws_deque_
module WL = Worker_loop_
include Runner
let ( let@ ) = ( @@ )
@ -13,46 +14,39 @@ module Id = struct
let equal : t -> t -> bool = ( == )
end
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type task_full =
| T_start of {
ls: Task_local_storage.t;
f: task;
}
| T_resume : {
ls: Task_local_storage.t;
k: 'a -> unit;
x: 'a;
}
-> task_full
type worker_state = {
pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t;
q: task_full WSQ.t; (** Work stealing queue *)
mutable cur_ls: Task_local_storage.t option; (** Task storage *)
rng: Random.State.t;
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
type state = {
id_: Id.t;
(** Unique per pool. Used to make sure tasks stay within the same pool. *)
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; (** Fixed set of workers. *)
main_q: task_full Queue.t;
mutable workers: worker_state array; (** Fixed set of workers. *)
main_q: WL.task_full Queue.t;
(** Main queue for tasks coming from the outside *)
mutable n_waiting: int; (* protected by mutex *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
mutex: Mutex.t;
cond: Condition.t;
as_runner: t lazy_t;
(* init options *)
around_task: WL.around_task;
name: string option;
on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
on_exit_thread: dom_id:int -> t_id:int -> unit -> unit;
on_exn: exn -> Printexc.raw_backtrace -> unit;
around_task: around_task;
}
(** internal state *)
and worker_state = {
mutable thread: Thread.t;
idx: int;
dom_id: int;
st: state;
q: WL.task_full WSQ.t; (** Work stealing queue *)
rng: Random.State.t;
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
let[@inline] size_ (self : state) = Array.length self.workers
let num_tasks_ (self : state) : int =
@ -66,9 +60,15 @@ let num_tasks_ (self : state) : int =
sub-tasks. *)
let k_worker_state : worker_state TLS.t = TLS.create ()
let[@inline] find_current_worker_ () : worker_state option =
let[@inline] get_current_worker_ () : worker_state option =
TLS.get_opt k_worker_state
let[@inline] get_current_worker_exn () : worker_state =
match TLS.get_exn k_worker_state with
| w -> w
| exception TLS.Not_set ->
failwith "Moonpool: get_current_runner was called from outside a pool."
(** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit =
if self.n_waiting_nonzero then (
@ -77,194 +77,144 @@ let[@inline] try_wake_someone_ (self : state) : unit =
Mutex.unlock self.mutex
)
(** Run [task] as is, on the pool. *)
let schedule_task_ (self : state) ~w (task : task_full) : unit =
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
match w with
| Some w when Id.equal self.id_ w.pool_id_ ->
(* we're on this same pool, schedule in the worker's state. Otherwise
we might also be on pool A but asking to schedule on pool B,
so we have to check that identifiers match. *)
let pushed = WSQ.push w.q task in
if pushed then
try_wake_someone_ self
else (
(* overflow into main queue *)
Mutex.lock self.mutex;
Queue.push task self.main_q;
if self.n_waiting_nonzero then Condition.signal self.cond;
Mutex.unlock self.mutex
let schedule_on_w (self : worker_state) task : unit =
(* we're on this same pool, schedule in the worker's state. Otherwise
we might also be on pool A but asking to schedule on pool B,
so we have to check that identifiers match. *)
let pushed = WSQ.push self.q task in
if pushed then
try_wake_someone_ self.st
else (
(* overflow into main queue *)
Mutex.lock self.st.mutex;
Queue.push task self.st.main_q;
if self.st.n_waiting_nonzero then Condition.signal self.st.cond;
Mutex.unlock self.st.mutex
)
let schedule_on_main (self : state) task : unit =
if A.get self.active then (
(* push into the main queue *)
Mutex.lock self.mutex;
Queue.push task self.main_q;
if self.n_waiting_nonzero then Condition.signal self.cond;
Mutex.unlock self.mutex
) else
(* notify the caller that scheduling tasks is no
longer permitted *)
raise Shutdown
let schedule_from_w (self : worker_state) (task : WL.task_full) : unit =
match get_current_worker_ () with
| Some w when Id.equal self.st.id_ w.st.id_ ->
(* use worker from the same pool *)
schedule_on_w w task
| _ -> schedule_on_main self.st task
exception Got_task of WL.task_full
(** Try to steal a task.
@raise Got_task if it finds one. *)
let try_to_steal_work_once_ (self : worker_state) : unit =
let init = Random.State.int self.rng (Array.length self.st.workers) in
for i = 0 to Array.length self.st.workers - 1 do
let w' =
Array.unsafe_get self.st.workers
((i + init) mod Array.length self.st.workers)
in
if self != w' then (
match WSQ.steal w'.q with
| Some t -> raise_notrace (Got_task t)
| None -> ()
)
| _ ->
if A.get self.active then (
(* push into the main queue *)
Mutex.lock self.mutex;
Queue.push task self.main_q;
if self.n_waiting_nonzero then Condition.signal self.cond;
Mutex.unlock self.mutex
) else
(* notify the caller that scheduling tasks is no
longer permitted *)
raise Shutdown
(** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
: unit =
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let (AT_pair (before_task, after_task)) = self.around_task in
let ls =
match task with
| T_start { ls; _ } | T_resume { ls; _ } -> ls
in
w.cur_ls <- Some ls;
TLS.set k_cur_storage ls;
let _ctx = before_task runner in
let[@inline] on_suspend () : _ ref =
match find_current_worker_ () with
| Some { cur_ls = Some w; _ } -> w
| _ -> assert false
in
let run_another_task ls (task' : task) =
let w =
match find_current_worker_ () with
| Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None
in
let ls' = Task_local_storage.Direct.copy ls in
schedule_task_ self ~w @@ T_start { ls = ls'; f = task' }
in
let resume ls k x =
let w =
match find_current_worker_ () with
| Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None
in
schedule_task_ self ~w @@ T_resume { ls; k; x }
in
(* run the task now, catching errors *)
(try
match task with
| T_start { f = task; _ } ->
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend
(WSH { on_suspend; run = run_another_task; resume })
task
| T_resume { k; x; _ } ->
(* this is already in an effect handler *)
k x
with e ->
let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt);
after_task runner _ctx;
w.cur_ls <- None;
TLS.set k_cur_storage _dummy_ls
let run_async_ (self : state) ~ls (f : task) : unit =
let w = find_current_worker_ () in
schedule_task_ self ~w @@ T_start { f; ls }
(* TODO: function to schedule many tasks from the outside.
- build a queue
- lock
- queue transfer
- wakeup all (broadcast)
- unlock *)
done
(** Wait on condition. Precondition: we hold the mutex. *)
let[@inline] wait_ (self : state) : unit =
let[@inline] wait_for_condition_ (self : state) : unit =
self.n_waiting <- self.n_waiting + 1;
if self.n_waiting = 1 then self.n_waiting_nonzero <- true;
Condition.wait self.cond self.mutex;
self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false
exception Got_task of task_full
let rec get_next_task (self : worker_state) : WL.task_full =
if not (A.get self.st.active) then raise WL.No_more_tasks;
match WSQ.pop_exn self.q with
| task ->
try_wake_someone_ self.st;
task
| exception WSQ.Empty -> try_steal_from_other_workers_ self
(** Try to steal a task *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option
=
let init = Random.State.int w.rng (Array.length self.workers) in
and try_steal_from_other_workers_ (self : worker_state) =
match try_to_steal_work_once_ self with
| exception Got_task task -> task
| () -> wait_on_worker self
try
for i = 0 to Array.length self.workers - 1 do
let w' =
Array.unsafe_get self.workers ((i + init) mod Array.length self.workers)
in
and wait_on_worker (self : worker_state) : WL.task_full =
Mutex.lock self.st.mutex;
match Queue.pop self.st.main_q with
| task ->
Mutex.unlock self.st.mutex;
task
| exception Queue.Empty ->
(* wait here *)
if A.get self.st.active then (
wait_for_condition_ self.st;
if w != w' then (
match WSQ.steal w'.q with
| Some t -> raise_notrace (Got_task t)
| None -> ()
)
done;
None
with Got_task t -> Some t
(* see if a task became available *)
match Queue.pop self.st.main_q with
| task ->
Mutex.unlock self.st.mutex;
task
| exception Queue.Empty -> try_steal_from_other_workers_ self
) else (
(* do nothing more: no task in main queue, and we are shutting
down so no new task should arrive.
The exception is if another task is creating subtasks
that overflow into the main queue, but we can ignore that at
the price of slightly decreased performance for the last few
tasks *)
Mutex.unlock self.st.mutex;
raise WL.No_more_tasks
)
(** Worker runs tasks from its queue until none remains *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
let continue = ref true in
while !continue && A.get self.active do
match WSQ.pop w.q with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner ~w task
| None -> continue := false
done
let before_start (self : worker_state) : unit =
let t_id = Thread.id @@ Thread.self () in
self.st.on_init_thread ~dom_id:self.dom_id ~t_id ();
TLS.set k_cur_fiber _dummy_fiber;
TLS.set Runner.For_runner_implementors.k_cur_runner
(Lazy.force self.st.as_runner);
TLS.set k_worker_state self;
(** Main loop for a worker thread. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
TLS.set k_worker_state w;
(* set thread name *)
Option.iter
(fun name ->
Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx))
self.st.name
let rec main () : unit =
worker_run_self_tasks_ self ~runner w;
try_steal ()
and run_task task : unit =
run_task_now_ self ~runner ~w task;
main ()
and try_steal () =
match try_to_steal_work_once_ self w with
| Some task -> run_task task
| None -> wait ()
and wait () =
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task task
| exception Queue.Empty ->
(* wait here *)
if A.get self.active then (
wait_ self;
let cleanup (self : worker_state) : unit =
(* on termination, decrease refcount of underlying domain *)
Domain_pool_.decr_on self.dom_id;
let t_id = Thread.id @@ Thread.self () in
self.st.on_exit_thread ~dom_id:self.dom_id ~t_id ()
(* see if a task became available *)
let task =
try Some (Queue.pop self.main_q) with Queue.Empty -> None
in
Mutex.unlock self.mutex;
match task with
| Some t -> run_task t
| None -> try_steal ()
) else
(* do nothing more: no task in main queue, and we are shutting
down so no new task should arrive.
The exception is if another task is creating subtasks
that overflow into the main queue, but we can ignore that at
the price of slightly decreased performance for the last few
tasks *)
Mutex.unlock self.mutex
let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = Lazy.force st.st.as_runner in
let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn ebt.exn ebt.bt
in
(* handle domain-local await *)
main ()
{
WL.schedule = schedule_from_w;
runner;
get_next_task;
get_thread_state = get_current_worker_exn;
around_task;
on_exn;
before_start;
cleanup;
}
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
@ -276,6 +226,14 @@ let shutdown_ ~wait (self : state) : unit =
if wait then Array.iter (fun w -> Thread.join w.thread) self.workers
)
let as_runner_ (self : state) : t =
Runner.For_runner_implementors.create
~shutdown:(fun ~wait () -> shutdown_ self ~wait)
~run_async:(fun ~fiber f -> schedule_on_main self @@ T_start { fiber; f })
~size:(fun () -> size_ self)
~num_tasks:(fun () -> num_tasks_ self)
()
type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
@ -286,9 +244,6 @@ type ('a, 'b) create_args =
'a
(** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ : task_full =
T_start { f = ignore; ls = Task_local_storage.dummy }
let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ?num_threads ?name () : t =
@ -296,8 +251,8 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* wrapper *)
let around_task =
match around_task with
| Some (f, g) -> AT_pair (f, g)
| None -> AT_pair (ignore, fun _ _ -> ())
| Some (f, g) -> WL.AT_pair (f, g)
| None -> WL.AT_pair (ignore, fun _ _ -> ())
in
let num_domains = Domain_pool_.max_number_of_domains () in
@ -306,23 +261,11 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in
let workers : worker_state array =
let dummy = Thread.self () in
Array.init num_threads (fun i ->
{
pool_id_;
thread = dummy;
q = WSQ.create ~dummy:dummy_task_ ();
rng = Random.State.make [| i |];
cur_ls = None;
})
in
let pool =
let rec pool =
{
id_ = pool_id_;
active = A.make true;
workers;
workers = [||];
main_q = Queue.create ();
n_waiting = 0;
n_waiting_nonzero = true;
@ -330,65 +273,47 @@ let create ?(on_init_thread = default_thread_init_exit_)
cond = Condition.create ();
around_task;
on_exn;
on_init_thread;
on_exit_thread;
name;
as_runner = lazy (as_runner_ pool);
}
in
let runner =
Runner.For_runner_implementors.create
~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
~run_async:(fun ~ls f -> run_async_ pool ~ls f)
~size:(fun () -> size_ pool)
~num_tasks:(fun () -> num_tasks_ pool)
()
in
(* temporary queue used to obtain thread handles from domains
on which the thread are started. *)
let receive_threads = Bb_queue.create () in
(* start the thread with index [i] *)
let start_thread_with_idx i =
let w = pool.workers.(i) in
let dom_idx = (offset + i) mod num_domains in
(* function run in the thread itself *)
let main_thread_fun () : unit =
let thread = Thread.self () in
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.set k_cur_storage _dummy_ls;
(* set thread name *)
Option.iter
(fun name ->
Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i))
name;
let run () = worker_thread_ pool ~runner w in
(* now run the main loop *)
Fun.protect run ~finally:(fun () ->
(* on termination, decrease refcount of underlying domain *)
Domain_pool_.decr_on dom_idx);
on_exit_thread ~dom_id:dom_idx ~t_id ()
let start_thread_with_idx idx =
let dom_id = (offset + idx) mod num_domains in
let st =
{
st = pool;
thread = (* dummy *) Thread.self ();
q = WSQ.create ~dummy:WL._dummy_task ();
rng = Random.State.make [| idx |];
dom_id;
idx;
}
in
(* function called in domain with index [i], to
create the thread and push it into [receive_threads] *)
let create_thread_in_domain () =
let thread = Thread.create main_thread_fun () in
let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
(* send the thread from the domain back to us *)
Bb_queue.push receive_threads (i, thread)
Bb_queue.push receive_threads (idx, thread)
in
Domain_pool_.run_on dom_idx create_thread_in_domain
Domain_pool_.run_on dom_id create_thread_in_domain;
st
in
(* start all threads, placing them on the domains
(* start all worker threads, placing them on the domains
according to their index and [offset] in a round-robin fashion. *)
for i = 0 to num_threads - 1 do
start_thread_with_idx i
done;
pool.workers <- Array.init num_threads start_thread_with_idx;
(* receive the newly created threads back from domains *)
for _j = 1 to num_threads do
@ -397,7 +322,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
worker_state.thread <- th
done;
runner
Lazy.force pool.as_runner
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
?name () f =


@ -72,7 +72,9 @@ let push (self : 'a t) (x : 'a) : bool =
true
with Full -> false
let pop (self : 'a t) : 'a option =
exception Empty
let pop_exn (self : 'a t) : 'a =
let b = A.get self.bottom in
let b = b - 1 in
A.set self.bottom b;
@ -84,11 +86,11 @@ let pop (self : 'a t) : 'a option =
if size < 0 then (
(* reset to basic empty state *)
A.set self.bottom t;
None
raise_notrace Empty
) else if size > 0 then (
(* can pop without modifying [top] *)
let x = CA.get self.arr b in
Some x
x
) else (
assert (size = 0);
(* there was exactly one slot, so we might be racing against stealers
@ -96,13 +98,18 @@ let pop (self : 'a t) : 'a option =
if A.compare_and_set self.top t (t + 1) then (
let x = CA.get self.arr b in
A.set self.bottom (t + 1);
Some x
x
) else (
A.set self.bottom (t + 1);
None
raise_notrace Empty
)
)
let[@inline] pop self : _ option =
match pop_exn self with
| exception Empty -> None
| t -> Some t
let steal (self : 'a t) : 'a option =
(* read [top], but do not update [top_cached]
as we're in another thread *)


@ -21,6 +21,10 @@ val pop : 'a t -> 'a option
(** Pop value from the bottom of deque.
This must be called only by the owner thread. *)
exception Empty
val pop_exn : 'a t -> 'a
val steal : 'a t -> 'a option
(** Try to steal from the top of deque. This is thread-safe. *)
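
Finally, a small usage sketch of the new exception-based deque API (our illustration; `WSQ` is the `Ws_deque_` alias used in Ws_pool above, and `drain_own_queue` and `run` are hypothetical names for any task-running caller):

(* Sketch: draining the worker's own deque with [pop_exn],
   mirroring [get_next_task] in Ws_pool. *)
let drain_own_queue (q : (unit -> unit) WSQ.t) ~(run : (unit -> unit) -> unit) :
    unit =
  let continue = ref true in
  while !continue do
    match WSQ.pop_exn q with
    | task -> run task
    | exception WSQ.Empty -> continue := false
  done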