feat: pass task local storage in run_async

the idea is that we could use this to pass storage
around in `Fut` combinators, but I'm not sure that's actually
a good idea.
This commit is contained in:
Simon Cruanes 2024-02-12 12:02:42 -05:00
parent e8e61f6b30
commit 2a42f15e37
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
11 changed files with 121 additions and 34 deletions

View file

@ -2,11 +2,12 @@ open Types_
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
let k_storage = Task_local_storage.Private_.Storage.k_storage
type task_full = { type task_full = {
f: unit -> unit; f: unit -> unit;
name: string; name: string;
ls: task_ls; ls: Task_local_storage.storage;
} }
type state = { type state = {
@ -25,8 +26,8 @@ let schedule_ (self : state) (task : task_full) : unit =
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let cur_ls : task_ls ref = ref [||] in let cur_ls : Task_local_storage.storage ref = ref Task_local_storage.Private_.Storage.dummy in
TLS.set Types_.k_ls_values (Some cur_ls); TLS.set k_storage (Some cur_ls);
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
let (AT_pair (before_task, after_task)) = around_task in let (AT_pair (before_task, after_task)) = around_task in
@ -44,7 +45,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
in in
let run_another_task ls ~name task' = let run_another_task ls ~name task' =
let ls' = Array.copy ls in let ls' = Task_local_storage.Private_.Storage.copy ls in
schedule_ self { f = task'; name; ls = ls' } schedule_ self { f = task'; name; ls = ls' }
in in
@ -73,7 +74,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
on_exn e bt); on_exn e bt);
exit_span_ (); exit_span_ ();
after_task runner _ctx; after_task runner _ctx;
cur_ls := [||] cur_ls := Task_local_storage.Private_.Storage.dummy
in in
let main_loop () = let main_loop () =
@ -130,7 +131,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
{ threads = Array.make num_threads dummy; q = Bb_queue.create () } { threads = Array.make num_threads dummy; q = Bb_queue.create () }
in in
let run_async ~name f = schedule_ pool { f; name; ls = [||] } in let run_async ~name ~ls f = schedule_ pool { f; name; ls } in
let runner = let runner =
Runner.For_runner_implementors.create Runner.For_runner_implementors.create

View file

@ -105,7 +105,7 @@ let[@inline] fulfill_idempotent self r =
(* ### combinators ### *) (* ### combinators ### *)
let spawn ?name ~on f : _ t = let spawn ?name ?ls ~on f : _ t =
let fut, promise = make () in let fut, promise = make () in
let task () = let task () =
@ -118,13 +118,13 @@ let spawn ?name ~on f : _ t =
fulfill promise res fulfill promise res
in in
Runner.run_async ?name on task; Runner.run_async ?name ?ls on task;
fut fut
let spawn_on_current_runner ?name f : _ t = let spawn_on_current_runner ?name ?ls f : _ t =
match Runner.get_current_runner () with match Runner.get_current_runner () with
| None -> failwith "Fut.spawn_on_current_runner: not running on a runner" | None -> failwith "Fut.spawn_on_current_runner: not running on a runner"
| Some on -> spawn ?name ~on f | Some on -> spawn ?name ?ls ~on f
let reify_error (f : 'a t) : 'a or_error t = let reify_error (f : 'a t) : 'a or_error t =
match peek f with match peek f with

View file

@ -94,11 +94,17 @@ val is_failed : _ t -> bool
(** {2 Combinators} *) (** {2 Combinators} *)
val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a t val spawn :
?name:string ->
?ls:Task_local_storage.storage ->
on:Runner.t ->
(unit -> 'a) ->
'a t
(** [spawn ~on f] runs [f()] on the given runner [on], and returns a future that will (** [spawn ~on f] runs [f()] on the given runner [on], and returns a future that will
hold its result. *) hold its result. *)
val spawn_on_current_runner : ?name:string -> (unit -> 'a) -> 'a t val spawn_on_current_runner :
?name:string -> ?ls:Task_local_storage.storage -> (unit -> 'a) -> 'a t
(** This must be run from inside a runner, and schedules (** This must be run from inside a runner, and schedules
the new task on it as well. the new task on it as well.

View file

@ -1,14 +1,21 @@
open Types_
include Runner include Runner
let run_async_ ~name f = let k_ls = Task_local_storage.Private_.Storage.k_storage
let run_async_ ~name ~ls f =
let cur_ls = ref ls in
TLS.set k_ls (Some cur_ls);
let sp = Tracing_.enter_span name in let sp = Tracing_.enter_span name in
try try
let x = f () in let x = f () in
Tracing_.exit_span sp; Tracing_.exit_span sp;
TLS.set k_ls None;
x x
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
Tracing_.exit_span sp; Tracing_.exit_span sp;
TLS.set k_ls None;
Printexc.raise_with_backtrace e bt Printexc.raise_with_backtrace e bt
let runner : t = let runner : t =

View file

@ -1,8 +1,11 @@
exception Shutdown = Runner.Shutdown
let start_thread_on_some_domain f x = let start_thread_on_some_domain f x =
let did = Random.int (Domain_pool_.n_domains ()) in let did = Random.int (Domain_pool_.n_domains ()) in
Domain_pool_.run_on_and_wait did (fun () -> Thread.create f x) Domain_pool_.run_on_and_wait did (fun () -> Thread.create f x)
let run_async = Runner.run_async let run_async = Runner.run_async
let run_wait_block = Runner.run_wait_block
let recommended_thread_count () = Domain_.recommended_number () let recommended_thread_count () = Domain_.recommended_number ()
let spawn = Fut.spawn let spawn = Fut.spawn
let spawn_on_current_runner = Fut.spawn_on_current_runner let spawn_on_current_runner = Fut.spawn_on_current_runner

View file

@ -15,12 +15,22 @@ module Runner = Runner
module Immediate_runner = Immediate_runner module Immediate_runner = Immediate_runner
module Exn_bt = Exn_bt module Exn_bt = Exn_bt
exception Shutdown
(** Exception raised when trying to run tasks on
runners that have been shut down.
@since NEXT_RELEASE *)
val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t
(** Similar to {!Thread.create}, but it picks a background domain at random (** Similar to {!Thread.create}, but it picks a background domain at random
to run the thread. This ensures that we don't always pick the same domain to run the thread. This ensures that we don't always pick the same domain
to run all the various threads needed in an application (timers, event loops, etc.) *) to run all the various threads needed in an application (timers, event loops, etc.) *)
val run_async : ?name:string -> Runner.t -> (unit -> unit) -> unit val run_async :
?name:string ->
?ls:Task_local_storage.storage ->
Runner.t ->
(unit -> unit) ->
unit
(** [run_async runner task] schedules the task to run (** [run_async runner task] schedules the task to run
on the given runner. This means [task()] will be executed on the given runner. This means [task()] will be executed
at some point in the future, possibly in another thread. at some point in the future, possibly in another thread.
@ -29,20 +39,43 @@ val run_async : ?name:string -> Runner.t -> (unit -> unit) -> unit
(since NEXT_RELEASE) (since NEXT_RELEASE)
@since 0.5 *) @since 0.5 *)
val run_wait_block :
?name:string ->
?ls:Task_local_storage.storage ->
Runner.t ->
(unit -> 'a) ->
'a
(** [run_wait_block runner f] schedules [f] for later execution
on the runner, like {!run_async}.
It then blocks the current thread until [f()] is done executing,
and returns its result. If [f()] raises an exception, then [run_wait_block runner f]
will raise it as well.
{b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block}
about the required discipline to avoid deadlocks).
@raise Shutdown if the runner was already shut down
@since NEXT_RELEASE *)
val recommended_thread_count : unit -> int val recommended_thread_count : unit -> int
(** Number of threads recommended to saturate the CPU. (** Number of threads recommended to saturate the CPU.
For IO pools this makes little sense (you might want more threads than For IO pools this makes little sense (you might want more threads than
this because many of them will be blocked most of the time). this because many of them will be blocked most of the time).
@since 0.5 *) @since 0.5 *)
val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a Fut.t val spawn :
?name:string ->
?ls:Task_local_storage.storage ->
on:Runner.t ->
(unit -> 'a) ->
'a Fut.t
(** [spawn ~on f] runs [f()] on the runner (a thread pool typically) (** [spawn ~on f] runs [f()] on the runner (a thread pool typically)
and returns a future result for it. See {!Fut.spawn}. and returns a future result for it. See {!Fut.spawn}.
@param name if provided and [Trace] is present in dependencies, @param name if provided and [Trace] is present in dependencies,
a span will be created for the future. (since 0.6) a span will be created for the future. (since 0.6)
@since 0.5 *) @since 0.5 *)
val spawn_on_current_runner : ?name:string -> (unit -> 'a) -> 'a Fut.t val spawn_on_current_runner :
?name:string -> ?ls:Task_local_storage.storage -> (unit -> 'a) -> 'a Fut.t
(** See {!Fut.spawn_on_current_runner}. (** See {!Fut.spawn_on_current_runner}.
@param name see {!spawn}. (since 0.6) @param name see {!spawn}. (since 0.6)
@since 0.5 *) @since 0.5 *)

View file

@ -3,7 +3,7 @@ module TLS = Thread_local_storage_
type task = unit -> unit type task = unit -> unit
type t = { type t = {
run_async: name:string -> task -> unit; run_async: name:string -> ls:Task_local_storage.storage -> task -> unit;
shutdown: wait:bool -> unit -> unit; shutdown: wait:bool -> unit -> unit;
size: unit -> int; size: unit -> int;
num_tasks: unit -> int; num_tasks: unit -> int;
@ -11,7 +11,10 @@ type t = {
exception Shutdown exception Shutdown
let[@inline] run_async ?(name = "") (self : t) f : unit = self.run_async ~name f let[@inline] run_async ?(name = "")
?(ls = Task_local_storage.Private_.Storage.create ()) (self : t) f : unit =
self.run_async ~name ~ls f
let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true () let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true ()
let[@inline] shutdown_without_waiting (self : t) : unit = let[@inline] shutdown_without_waiting (self : t) : unit =
@ -20,9 +23,9 @@ let[@inline] shutdown_without_waiting (self : t) : unit =
let[@inline] num_tasks (self : t) : int = self.num_tasks () let[@inline] num_tasks (self : t) : int = self.num_tasks ()
let[@inline] size (self : t) : int = self.size () let[@inline] size (self : t) : int = self.size ()
let run_wait_block ?name self (f : unit -> 'a) : 'a = let run_wait_block ?name ?ls self (f : unit -> 'a) : 'a =
let q = Bb_queue.create () in let q = Bb_queue.create () in
run_async ?name self (fun () -> run_async ?name ?ls self (fun () ->
try try
let x = f () in let x = f () in
Bb_queue.push q (Ok x) Bb_queue.push q (Ok x)

View file

@ -33,16 +33,19 @@ val shutdown_without_waiting : t -> unit
exception Shutdown exception Shutdown
val run_async : ?name:string -> t -> task -> unit val run_async :
?name:string -> ?ls:Task_local_storage.storage -> t -> task -> unit
(** [run_async pool f] schedules [f] for later execution on the runner (** [run_async pool f] schedules [f] for later execution on the runner
in one of the threads. [f()] will run on one of the runner's in one of the threads. [f()] will run on one of the runner's
worker threads/domains. worker threads/domains.
@param name if provided and [Trace] is present in dependencies, a span @param name if provided and [Trace] is present in dependencies, a span
will be created when the task starts, and will stop when the task is over. will be created when the task starts, and will stop when the task is over.
(since NEXT_RELEASE) (since NEXT_RELEASE)
@param ls if provided, run the task with this initial local storage
@raise Shutdown if the runner was shut down before [run_async] was called. *) @raise Shutdown if the runner was shut down before [run_async] was called. *)
val run_wait_block : ?name:string -> t -> (unit -> 'a) -> 'a val run_wait_block :
?name:string -> ?ls:Task_local_storage.storage -> t -> (unit -> 'a) -> 'a
(** [run_wait_block pool f] schedules [f] for later execution (** [run_wait_block pool f] schedules [f] for later execution
on the pool, like {!run_async}. on the pool, like {!run_async}.
It then blocks the current thread until [f()] is done executing, It then blocks the current thread until [f()] is done executing,
@ -62,7 +65,7 @@ module For_runner_implementors : sig
size:(unit -> int) -> size:(unit -> int) ->
num_tasks:(unit -> int) -> num_tasks:(unit -> int) ->
shutdown:(wait:bool -> unit -> unit) -> shutdown:(wait:bool -> unit -> unit) ->
run_async:(name:string -> task -> unit) -> run_async:(name:string -> ls:Task_local_storage.storage -> task -> unit) ->
unit -> unit ->
t t
(** Create a new runner. (** Create a new runner.

View file

@ -5,6 +5,8 @@ type 'a key = 'a ls_key
let key_count_ = A.make 0 let key_count_ = A.make 0
type storage = task_ls
let new_key (type t) ~init () : t key = let new_key (type t) ~init () : t key =
let offset = A.fetch_and_add key_count_ 1 in let offset = A.fetch_and_add key_count_ 1 in
(module struct (module struct
@ -55,3 +57,14 @@ let with_value key x f =
let old = get key in let old = get key in
set key x; set key x;
Fun.protect ~finally:(fun () -> set key old) f Fun.protect ~finally:(fun () -> set key old) f
(* Low-level access to the storage representation, used by the runner
   implementations to install and reset the storage of the running task. *)
module Private_ = struct
(* Operations on a raw storage value ([storage], i.e. an array). *)
module Storage = struct
type t = storage
(* Thread-local key under which a reference to the current storage is
   published; this is an alias of [k_ls_values]. *)
let k_storage = k_ls_values
(* Storage with no slots: the empty array. *)
let[@inline] create () = [||]
(* Copy of the underlying array, as done by [Array.copy]. *)
let copy = Array.copy
(* Placeholder storage value; also the empty array. *)
let dummy = [||]
end
end

View file

@ -8,6 +8,9 @@
@since NEXT_RELEASE @since NEXT_RELEASE
*) *)
type storage
(** Underlying storage for a task *)
type 'a key type 'a key
(** A key used to access a particular (typed) storage slot on every task. *) (** A key used to access a particular (typed) storage slot on every task. *)
@ -41,3 +44,18 @@ val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b
(** [with_value k v f] sets [k] to [v] for the duration of the call (** [with_value k v f] sets [k] to [v] for the duration of the call
to [f()]. When [f()] returns (or fails), [k] is restored to [f()]. When [f()] returns (or fails), [k] is restored
to its old value. *) to its old value. *)
(**/**)
module Private_ : sig
module Storage : sig
type t = storage
(** A storage value, same type as {!storage}. *)
val k_storage : t ref option Thread_local_storage_.key
(** Thread-local key holding an optional reference to a storage value.
    Worker threads set it to [Some r], where [r] refers to the storage
    of the task they are running. *)
val create : unit -> t
(** [create ()] returns an empty storage value. *)
val copy : t -> t
(** [copy st] returns a copy of [st]. *)
val dummy : t
(** Placeholder storage value, installed by workers once a task is done. *)
end
end
(**/**)

View file

@ -1,10 +1,10 @@
open Types_
module WSQ = Ws_deque_ module WSQ = Ws_deque_
module A = Atomic_ module A = Atomic_
module TLS = Thread_local_storage_ module TLS = Thread_local_storage_
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
let k_storage = Task_local_storage.Private_.Storage.k_storage
module Id = struct module Id = struct
type t = unit ref type t = unit ref
@ -17,7 +17,7 @@ end
type task_full = { type task_full = {
f: task; f: task;
name: string; name: string;
ls: task_ls; ls: Task_local_storage.storage;
} }
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
@ -27,7 +27,7 @@ type worker_state = {
mutable thread: Thread.t; mutable thread: Thread.t;
q: task_full WSQ.t; (** Work stealing queue *) q: task_full WSQ.t; (** Work stealing queue *)
mutable cur_span: int64; mutable cur_span: int64;
cur_ls: task_ls ref; (** Task storage *) cur_ls: Task_local_storage.storage ref; (** Task storage *)
rng: Random.State.t; rng: Random.State.t;
} }
(** State for a given worker. Only this worker is (** State for a given worker. Only this worker is
@ -127,7 +127,7 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
let run_another_task ls ~name task' = let run_another_task ls ~name task' =
let w = find_current_worker_ () in let w = find_current_worker_ () in
let ls' = Array.copy ls in let ls' = Task_local_storage.Private_.Storage.copy ls in
schedule_task_ self w ~name ~ls:ls' task' schedule_task_ self w ~name ~ls:ls' task'
in in
@ -154,11 +154,11 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
exit_span_ (); exit_span_ ();
after_task runner _ctx; after_task runner _ctx;
w.cur_ls := [||] w.cur_ls := Task_local_storage.Private_.Storage.dummy
let[@inline] run_async_ (self : state) ~name (f : task) : unit = let[@inline] run_async_ (self : state) ~name ~ls (f : task) : unit =
let w = find_current_worker_ () in let w = find_current_worker_ () in
schedule_task_ self w ~name ~ls:[||] f schedule_task_ self w ~name ~ls f
(* TODO: function to schedule many tasks from the outside. (* TODO: function to schedule many tasks from the outside.
- build a queue - build a queue
@ -276,7 +276,7 @@ type ('a, 'b) create_args =
'a 'a
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ = { f = ignore; ls = [||]; name = "DUMMY_TASK" } let dummy_task_ = { f = ignore; ls = Task_local_storage.Private_.Storage.dummy ; name = "DUMMY_TASK" }
let create ?(on_init_thread = default_thread_init_exit_) let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
@ -304,7 +304,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
cur_span = Tracing_.dummy_span; cur_span = Tracing_.dummy_span;
q = WSQ.create ~dummy:dummy_task_ (); q = WSQ.create ~dummy:dummy_task_ ();
rng = Random.State.make [| i |]; rng = Random.State.make [| i |];
cur_ls = ref [||]; cur_ls = ref Task_local_storage.Private_.Storage.dummy;
}) })
in in
@ -326,7 +326,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let runner = let runner =
Runner.For_runner_implementors.create Runner.For_runner_implementors.create
~shutdown:(fun ~wait () -> shutdown_ pool ~wait) ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
~run_async:(fun ~name f -> run_async_ pool ~name f) ~run_async:(fun ~name ~ls f -> run_async_ pool ~name ~ls f)
~size:(fun () -> size_ pool) ~size:(fun () -> size_ pool)
~num_tasks:(fun () -> num_tasks_ pool) ~num_tasks:(fun () -> num_tasks_ pool)
() ()
@ -346,7 +346,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let thread = Thread.self () in let thread = Thread.self () in
let t_id = Thread.id thread in let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id (); on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.set k_ls_values (Some w.cur_ls); TLS.set k_storage (Some w.cur_ls);
(* set thread name *) (* set thread name *)
Option.iter Option.iter