add task_local_storage to core, modify how suspend works

This commit is contained in:
Simon Cruanes 2024-02-02 23:18:59 -05:00
parent f84414a412
commit c05a38d617
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
13 changed files with 286 additions and 56 deletions

View file

@ -3,7 +3,7 @@
(name moonpool)
(libraries moonpool.private)
(flags :standard -open Moonpool_private)
(private_modules domain_pool_ util_pool_)
(private_modules types_ domain_pool_ util_pool_)
(preprocess
(action
(run %{project_root}/src/cpp/cpp.exe %{input-file}))))

View file

@ -1,16 +1,17 @@
module TLS = Thread_local_storage_
open Types_
include Runner
let ( let@ ) = ( @@ )
type task_with_name = {
type task_full = {
f: unit -> unit;
name: string;
ls: task_ls;
}
type state = {
threads: Thread.t array;
q: task_with_name Bb_queue.t; (** Queue for tasks. *)
q: task_full Bb_queue.t; (** Queue for tasks. *)
}
(** internal state *)
@ -18,13 +19,16 @@ let[@inline] size_ (self : state) = Array.length self.threads
let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q
(** Run [task] as is, on the pool. *)
let schedule_ (self : state) (task : task_with_name) : unit =
let schedule_ (self : state) (task : task_full) : unit =
try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let cur_ls : task_ls ref = ref [||] in
TLS.set Types_.k_ls_values (Some cur_ls);
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
let (AT_pair (before_task, after_task)) = around_task in
let cur_span = ref Tracing_.dummy_span in
@ -34,20 +38,32 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
cur_span := Tracing_.dummy_span
in
let run_another_task ~name task' = schedule_ self { f = task'; name } in
let on_suspend () =
exit_span_ ();
!cur_ls
in
let run_task (task : task_with_name) : unit =
let run_another_task ~name task' =
schedule_ self { f = task'; name; ls = [||] }
in
let run_task (task : task_full) : unit =
cur_ls := task.ls;
let _ctx = before_task runner in
cur_span := Tracing_.enter_span task.name;
(* run the task now, catching errors *)
(try
Suspend_.with_suspend task.f ~name:task.name ~run:run_another_task
~on_suspend:exit_span_
let resume ~ls k res =
schedule_ self { f = (fun () -> k res); name = task.name; ls }
in
(* run the task now, catching errors, handling effects *)
(try Suspend_.with_suspend task.f ~run:run_another_task ~resume ~on_suspend
with e ->
let bt = Printexc.get_raw_backtrace () in
on_exn e bt);
exit_span_ ();
after_task runner _ctx
after_task runner _ctx;
cur_ls := [||]
in
let main_loop () =
@ -100,7 +116,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
{ threads = Array.make num_threads dummy; q = Bb_queue.create () }
in
let run_async ~name f = schedule_ pool { f; name } in
let run_async ~name f = schedule_ pool { f; name; ls = [||] } in
let runner =
Runner.For_runner_implementors.create

View file

@ -42,6 +42,16 @@ let[@inline] is_done self : bool =
| Done _ -> true
| Waiting _ -> false
(* [is_success fut] is [true] iff [fut] is already resolved with [Ok _].
   A future that is still waiting, or that failed, yields [false]. *)
let[@inline] is_success self =
  match A.get self.st with
  | Done (Ok _) -> true
  | Done (Error _) | Waiting _ -> false
(* [is_failed fut] is [true] iff [fut] is already resolved with [Error _].
   A future that is still waiting, or that succeeded, yields [false]. *)
let[@inline] is_failed self =
  match A.get self.st with
  | Done (Error _) -> true
  | Done (Ok _) | Waiting _ -> false
exception Not_ready
let[@inline] get_or_fail self =
@ -427,14 +437,14 @@ let await (fut : 'a t) : 'a =
Suspend_.suspend
{
Suspend_.handle =
(fun ~name ~run k ->
(fun ~ls ~run:_ ~resume k ->
on_result fut (function
| Ok _ ->
(* schedule continuation with the same name *)
run ~name (fun () -> k (Ok ()))
resume ~ls k (Ok ())
| Error (exn, bt) ->
(* fail continuation immediately *)
k (Error (exn, bt))));
resume ~ls k (Error (exn, bt))));
};
(* un-suspended: we should have a result! *)
get_or_fail_exn fut
@ -452,3 +462,7 @@ end
include Infix
module Infix_local = Infix [@@deprecated "use Infix"]
(* Internal escape hatches; not meant for regular users of the library. *)
module Private_ = struct
(* Identity function at runtime: returns its argument unchanged.
   The interface exposes it with type ['a t -> 'a promise], i.e. it lets
   the caller view a future as the promise that resolves it. *)
let[@inline] unsafe_promise_of_fut x = x
end

View file

@ -84,6 +84,14 @@ val is_done : _ t -> bool
(** Is the future resolved? This is the same as [peek fut |> Option.is_some].
@since 0.2 *)
val is_success : _ t -> bool
(** Checks if the future is resolved with [Ok _] as a result.
@since NEXT_RELEASE *)
val is_failed : _ t -> bool
(** Checks if the future is resolved with [Error _] as a result.
@since NEXT_RELEASE *)
(** {2 Combinators} *)
val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a t
@ -268,3 +276,12 @@ include module type of Infix
module Infix_local = Infix
[@@deprecated "Use Infix"]
(** @deprecated use Infix instead *)
(**/**)
module Private_ : sig
val unsafe_promise_of_fut : 'a t -> 'a promise
(** please do not use *)
end
(**/**)

View file

@ -23,6 +23,7 @@ module Fut = Fut
module Lock = Lock
module Immediate_runner = Immediate_runner
module Runner = Runner
module Task_local_storage = Task_local_storage
module Thread_local_storage = Thread_local_storage_
module Ws_pool = Ws_pool

View file

@ -59,6 +59,7 @@ val await : 'a Fut.t -> 'a
module Lock = Lock
module Fut = Fut
module Chan = Chan
module Task_local_storage = Task_local_storage
module Thread_local_storage = Thread_local_storage_
(** A simple blocking queue.
@ -187,8 +188,10 @@ module Atomic = Atomic_
(**/**)
(** Private internals, with no stability guarantees *)
module Private : sig
module Ws_deque_ = Ws_deque_
(** A deque for work stealing, fixed size. *)
(** {2 Suspensions} *)

View file

@ -1,33 +1,55 @@
type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit
open Types_
type suspension = unit Exn_bt.result -> unit
type task = unit -> unit
type suspension_handler = {
handle: name:string -> run:(name:string -> task -> unit) -> suspension -> unit;
handle:
ls:task_ls ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
}
[@@unboxed]
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]
type _ Effect.t += Suspend : suspension_handler -> unit Effect.t
type _ Effect.t +=
| Suspend : suspension_handler -> unit Effect.t
| Yield : unit Effect.t
let[@inline] yield () = Effect.perform Yield
let[@inline] suspend h = Effect.perform (Suspend h)
let with_suspend ~name ~on_suspend ~(run : name:string -> task -> unit)
let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
~(resume : ls:task_ls -> suspension -> unit Exn_bt.result -> unit)
(f : unit -> unit) : unit =
let module E = Effect.Deep in
(* effect handler *)
let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
function
| Suspend h ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
on_suspend ();
let ls = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in
h.handle ~name ~run k')
h.handle ~ls ~run ~resume k')
| Yield ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
let ls = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in
resume ~ls k' (Ok ()))
| _ -> None
in

View file

@ -3,13 +3,20 @@
This module is an implementation detail of Moonpool and should
not be used outside of it, except by experts to implement {!Runner}. *)
type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit
open Types_
type suspension = unit Exn_bt.result -> unit
(** A suspended computation *)
type task = unit -> unit
type suspension_handler = {
handle: name:string -> run:(name:string -> task -> unit) -> suspension -> unit;
handle:
ls:task_ls ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
}
[@@unboxed]
(** The handler that knows what to do with the suspended computation.
@ -40,9 +47,16 @@ type _ Effect.t +=
(** The effect used to suspend the current thread and pass it, suspended,
to the handler. The handler will ensure that the suspension is resumed later
once some computation has been done. *)
| Yield : unit Effect.t
(** The effect used to interrupt the current computation and immediately re-schedule
it on the same runner. *)
[@@@ocaml.alert "+unstable"]
val yield : unit -> unit
(** Interrupt current computation, and re-schedule it at the end of the
runner's job queue. *)
val suspend : suspension_handler -> unit
(** [suspend h] jumps back to the nearest {!with_suspend}
and calls [h.handle] with the current continuation [k]
@ -52,17 +66,24 @@ val suspend : suspension_handler -> unit
[@@@endif]
val with_suspend :
name:string ->
on_suspend:(unit -> unit) ->
on_suspend:(unit -> task_ls) ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
(unit -> unit) ->
unit
(** [with_suspend ~run f] runs [f()] in an environment where [suspend]
will work. If [f()] suspends with suspension handler [h],
this calls [h ~run k] where [k] is the suspension.
The suspension should always run in a new task, via [run].
(** [with_suspend ~on_suspend ~run ~resume f]
runs [f()] in an environment where [suspend]
will work (on OCaml 5) or do nothing (on OCaml 4.xx).
If [f()] suspends with suspension handler [h],
this calls [h ~run ~resume k] where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
@param on_suspend called when [f()] suspends itself.
@param name used for tracing, if not [""].
@param run used to schedule new tasks
@param resume run the suspension. Must be called exactly once.
This will not do anything on OCaml 4.x.
*)

View file

@ -0,0 +1,53 @@
open Types_
module A = Atomic
type 'a key = 'a ls_key
let key_count_ = A.make 0
(* [new_key ~init ()] allocates a fresh key.
   The key reserves the next slot index from the global counter
   [key_count_] and carries its own constructor [V] of the extensible
   type [ls_value], so that values stored under this key can be told apart
   from values of every other key. *)
let new_key (type t) ~init () : t key =
  let offset = A.fetch_and_add key_count_ 1 in
  let module K = struct
    type nonrec t = t
    type ls_value += V of t

    let offset = offset
    let init = init
  end in
  (module K : LS_KEY with type t = t)
type ls_value += Dummy
(** [resize_ cur n] replaces the array behind [cur] with a larger one in
    which index [n] is valid. Existing values are copied over and the new
    slots are filled with [Dummy].

    [n] is the offset the caller is about to access, so at least [n + 1]
    slots are required. The capacity at least doubles to amortize repeated
    growth. *)
let[@inline never] resize_ (cur : ls_value array ref) n =
  let len = Array.length !cur in
  (* Allocating only [n] slots would leave index [n] out of bounds
     whenever [n >= 2 * len]; in particular a task starting with [[||]]
     and a key of offset 0 would get an empty array again. *)
  let new_ls = Array.make (max (n + 1) (len * 2)) Dummy in
  Array.blit !cur 0 new_ls 0 len;
  cur := new_ls
(* Fetch the reference to the current task's storage from the thread-local
   slot [k_ls_values].
   Raises [Failure] when the calling thread has no storage installed. *)
let[@inline] get_cur_ () : ls_value array ref =
  match TLS.get k_ls_values with
  | None -> failwith "Task local storage must be accessed from within a runner."
  | Some storage -> storage
(* [get k] reads the slot of key [k] in the current task's storage.
   If the slot index is beyond the current array, the array is first grown
   through [resize_]. A slot still holding [Dummy] is initialized with
   [K.init ()], and that value is stored and returned. *)
let get (type a) ((module K) : a key) : a =
  let storage = get_cur_ () in
  let idx = K.offset in
  if Array.length !storage <= idx then resize_ storage idx;
  match (!storage).(idx) with
  | K.V x ->
    (* common case first *)
    x
  | Dummy ->
    (* first access to this key in this task *)
    let initial = K.init () in
    (* [storage] is dereferenced again on purpose: the write must land in
       whatever array is current once [K.init] has returned. *)
    (!storage).(idx) <- K.V initial;
    initial
  | _ -> assert false
(* [set k v] stores [v] in the slot of key [k] in the current task's
   storage, growing the array through [resize_] first when the slot index
   is beyond its current length. The key's [init] is not called. *)
let set (type a) ((module K) : a key) (v : a) : unit =
  let storage = get_cur_ () in
  let idx = K.offset in
  if Array.length !storage <= idx then resize_ storage idx;
  (!storage).(idx) <- K.V v
(* [with_value key x f] runs [f ()] with [key] bound to [x], then puts back
   the value [key] had before the call, whether [f] returns or raises.
   The previous value is obtained with [get], so an unset key is
   initialized first. *)
let with_value key x f =
  let previous = get key in
  let restore () = set key previous in
  set key x;
  Fun.protect f ~finally:restore

View file

@ -0,0 +1,43 @@
(** Task-local storage.
This storage is associated to the current task,
just like thread-local storage is associated with
the current thread. The storage is carried along in case
the current task is suspended.
@since NEXT_RELEASE
*)
type 'a key
(** A key used to access a particular (typed) storage slot on every task. *)
val new_key : init:(unit -> 'a) -> unit -> 'a key
(** [new_key ~init ()] makes a new key. Keys are expensive and
    should never be allocated dynamically or in a loop.
    [init] computes the value of the key for a task that reads it
    before having set it.

    The correct pattern is, at toplevel:
    {[
      let k_foo : foo Task_local_storage.key =
        Task_local_storage.new_key ~init:(fun () -> make_foo ()) ()

      (* ... *)

      (* use it: *)
      let foo = Task_local_storage.get k_foo
    ]}
*)
val get : 'a key -> 'a
(** [get k] gets the value for the current task for key [k].
    If [k] has no value yet in the current task, the key's [init]
    function is called and its result is stored and returned.
    Must be run from inside a task running on a runner.
    @raise Failure otherwise *)
val set : 'a key -> 'a -> unit
(** [set k v] sets the storage for [k] to [v] in the current task.
    The key's [init] function is not called.
    Must be run from inside a task running on a runner.
    @raise Failure otherwise *)
val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b
(** [with_value k v f] sets [k] to [v] for the duration of the call
    to [f()]. When [f()] returns (or fails), [k] is restored
    to its old value. The old value is read with {!get}, so a key
    that was not set yet is initialized first.
    Must be run from inside a task running on a runner. *)

26
src/core/types_.ml Normal file
View file

@ -0,0 +1,26 @@
module TLS = Thread_local_storage_
type ls_value = ..
(** Extensible type of the values held in task local storage.
    Every key extends it with its own constructor (see {!LS_KEY.V}). *)

(** Key for task local storage *)
module type LS_KEY = sig
type t
(** Type of the values stored under this key *)

type ls_value += V of t
(** Constructor wrapping the values of this key *)

val offset : int
(** Unique offset *)

val init : unit -> t
(** Produces the value of the key for a task that has none yet *)
end
type 'a ls_key = (module LS_KEY with type t = 'a)
(** A LS key (task local storage) *)

type task_ls = ls_value array
(** Storage of a single task: an array of values indexed by the
    [offset] of the keys. *)
(** Thread-local slot holding a reference to the LS values of the task
    currently running on this thread.
    A worker thread is going to cycle through many tasks, each of which
    has its own storage. This key allows tasks running on the worker
    to access their own storage.
    The slot is [None] until a value is installed with [TLS.set]. *)
let k_ls_values : task_ls ref option TLS.key = TLS.new_key (fun () -> None)

View file

@ -1,3 +1,4 @@
open Types_
module WSQ = Ws_deque_
module A = Atomic_
module TLS = Thread_local_storage_
@ -13,16 +14,18 @@ module Id = struct
let equal : t -> t -> bool = ( == )
end
type task_with_name = {
type task_full = {
f: task;
name: string;
ls: task_ls;
}
type worker_state = {
pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t;
q: task_with_name WSQ.t; (** Work stealing queue *)
q: task_full WSQ.t; (** Work stealing queue *)
mutable cur_span: int64;
cur_ls: task_ls ref; (** Task storage *)
rng: Random.State.t;
}
(** State for a given worker. Only this worker is
@ -35,7 +38,7 @@ type state = {
id_: Id.t;
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; (** Fixed set of workers. *)
main_q: task_with_name Queue.t;
main_q: task_full Queue.t;
(** Main queue for tasks coming from the outside *)
mutable n_waiting: int; (* protected by mutex *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
@ -72,10 +75,10 @@ let[@inline] try_wake_someone_ (self : state) : unit =
)
(** Run [task] as is, on the pool. *)
let schedule_task_ (self : state) ~name (w : worker_state option) (f : task) :
unit =
let schedule_task_ (self : state) ~name ~ls (w : worker_state option) (f : task)
: unit =
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let task = { f; name } in
let task = { f; name; ls } in
match w with
| Some w when Id.equal self.id_ w.pool_id_ ->
(* we're on this same pool, schedule in the worker's state. Otherwise
@ -104,9 +107,11 @@ let schedule_task_ (self : state) ~name (w : worker_state option) (f : task) :
raise Shutdown
(** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner (w : worker_state) ~name task : unit =
let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
unit =
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let (AT_pair (before_task, after_task)) = self.around_task in
w.cur_ls := ls;
let _ctx = before_task runner in
w.cur_span <- Tracing_.enter_span name;
@ -115,25 +120,32 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name task : unit =
w.cur_span <- Tracing_.dummy_span
in
let on_suspend () =
exit_span_ ();
!(w.cur_ls)
in
let run_another_task ~name task' =
let w = find_current_worker_ () in
schedule_task_ self w ~name task'
schedule_task_ self w ~name ~ls:[||] task'
in
(* Re-schedule the suspended task [k] with result [r], carrying its local
   storage [ls] along.
   The worker is looked up at the time [resume] is called, exactly like
   [run_another_task] does, instead of reusing the worker that started the
   task: [resume] is handed to suspension handlers and may presumably be
   invoked from another thread (e.g. from a callback run when a future is
   resolved), in which case the task must not be pushed into this worker's
   own queue. NOTE(review): confirm [find_current_worker_] returns [None]
   when called from outside the pool. *)
let resume ~ls k r =
  let w = find_current_worker_ () in
  schedule_task_ self w ~name ~ls (fun () -> k r)
in
(* run the task now, catching errors *)
(try
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~name ~run:run_another_task
~on_suspend:exit_span_
Suspend_.with_suspend task ~run:run_another_task ~resume ~on_suspend
with e ->
let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt);
exit_span_ ();
after_task runner _ctx
after_task runner _ctx;
w.cur_ls := [||]
let[@inline] run_async_ (self : state) ~name (f : task) : unit =
let w = find_current_worker_ () in
schedule_task_ self w ~name f
schedule_task_ self w ~name ~ls:[||] f
(* TODO: function to schedule many tasks from the outside.
- build a queue
@ -150,11 +162,11 @@ let[@inline] wait_ (self : state) : unit =
self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false
exception Got_task of task_with_name
exception Got_task of task_full
(** Try to steal a task *)
let try_to_steal_work_once_ (self : state) (w : worker_state) :
task_with_name option =
let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option
=
let init = Random.State.int w.rng (Array.length self.workers) in
try
@ -179,7 +191,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
match WSQ.pop w.q with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner w ~name:task.name task.f
run_task_now_ self ~runner w ~name:task.name ~ls:task.ls task.f
| None -> continue := false
done
@ -192,7 +204,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
worker_run_self_tasks_ self ~runner w;
try_steal ()
and run_task task : unit =
run_task_now_ self ~runner w ~name:task.name task.f;
run_task_now_ self ~runner w ~name:task.name ~ls:task.ls task.f;
main ()
and try_steal () =
match try_to_steal_work_once_ self w with
@ -249,7 +261,7 @@ type ('a, 'b) create_args =
'a
(** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ = { f = ignore; name = "DUMMY_TASK" }
let dummy_task_ = { f = ignore; ls = [||]; name = "DUMMY_TASK" }
let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
@ -277,6 +289,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
cur_span = Tracing_.dummy_span;
q = WSQ.create ~dummy:dummy_task_ ();
rng = Random.State.make [| i |];
cur_ls = ref [||];
})
in
@ -318,6 +331,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let thread = Thread.self () in
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.set k_ls_values (Some w.cur_ls);
(* set thread name *)
Option.iter

View file

@ -48,7 +48,7 @@ module State_ = struct
Suspend_.suspend
{
Suspend_.handle =
(fun ~name:_ ~run:_ suspension ->
(fun ~ls ~run:_ ~resume suspension ->
while
let old_st = A.get self in
match old_st with
@ -59,7 +59,7 @@ module State_ = struct
| Left_solved left ->
(* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right));
suspension (Ok ());
resume ~ls suspension (Ok ());
false
| Right_solved _ | Both_solved _ -> assert false
do
@ -113,19 +113,19 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / Moonpool.Private.num_domains ()))
in
let start_tasks ~name ~run (suspension : Suspend_.suspension) =
let start_tasks ~ls ~run ~resume (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with
| () ->
if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *)
run ~name (fun () -> suspension (Ok ()))
resume ~ls suspension (Ok ())
| exception exn ->
let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then
(* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *)
run ~name (fun () -> suspension (Error (exn, bt)))
resume ~ls suspension (Error (exn, bt))
in
let i = ref 0 in
@ -135,7 +135,7 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n);
run ~name (fun () -> task_for ~offset ~len_range);
run ~name:"" (fun () -> task_for ~offset ~len_range);
i := !i + len_range
done
in
@ -143,9 +143,9 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
Suspend_.suspend
{
Suspend_.handle =
(fun ~name ~run suspension ->
(fun ~ls ~run ~resume suspension ->
(* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~name suspension);
start_tasks ~run ~ls ~resume suspension);
}
)