cleanup the local storage APIs

This commit is contained in:
Simon Cruanes 2024-03-01 18:38:41 -05:00
parent 953947f694
commit 45b8aa9999
12 changed files with 120 additions and 120 deletions

View file

@ -2,15 +2,14 @@ open Types_
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
let k_storage = Task_local_storage.Private_.Storage.k_storage
type task_full = type task_full =
| T_start of { | T_start of {
ls: Task_local_storage.storage ref; ls: Task_local_storage.t;
f: task; f: task;
} }
| T_resume : { | T_resume : {
ls: Task_local_storage.storage ref; ls: Task_local_storage.t;
k: 'a -> unit; k: 'a -> unit;
x: 'a; x: 'a;
} }
@ -30,7 +29,7 @@ let schedule_ (self : state) (task : task_full) : unit =
try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type worker_state = { mutable cur_ls: Task_local_storage.storage ref option } type worker_state = { mutable cur_ls: Task_local_storage.t option }
let k_worker_state : worker_state option ref TLS.key = let k_worker_state : worker_state option ref TLS.key =
TLS.new_key (fun () -> ref None) TLS.new_key (fun () -> ref None)
@ -56,7 +55,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
| T_start { ls; _ } | T_resume { ls; _ } -> ls | T_start { ls; _ } | T_resume { ls; _ } -> ls
in in
w.cur_ls <- Some ls; w.cur_ls <- Some ls;
TLS.set k_storage (Some ls); TLS.get k_cur_storage := Some ls;
let _ctx = before_task runner in let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *) (* run the task now, catching errors, handling effects *)
@ -75,7 +74,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
on_exn e bt); on_exn e bt);
after_task runner _ctx; after_task runner _ctx;
w.cur_ls <- None; w.cur_ls <- None;
TLS.set k_storage None TLS.get k_cur_storage := None
in in
let main_loop () = let main_loop () =

View file

@ -1,18 +1,15 @@
open Types_ open Types_
include Runner include Runner
(* convenient alias *)
let k_ls = Task_local_storage.Private_.Storage.k_storage
let run_async_ ~ls:cur_ls f = let run_async_ ~ls:cur_ls f =
TLS.set k_ls (Some cur_ls); TLS.get k_cur_storage := Some cur_ls;
try try
let x = f () in let x = f () in
TLS.set k_ls None; TLS.get k_cur_storage := None;
x x
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
TLS.set k_ls None; TLS.get k_cur_storage := None;
Printexc.raise_with_backtrace e bt Printexc.raise_with_backtrace e bt
let runner : t = let runner : t =

View file

@ -26,15 +26,13 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t
to run the thread. This ensures that we don't always pick the same domain to run the thread. This ensures that we don't always pick the same domain
to run all the various threads needed in an application (timers, event loops, etc.) *) to run all the various threads needed in an application (timers, event loops, etc.) *)
val run_async : val run_async : ?ls:Task_local_storage.t -> Runner.t -> (unit -> unit) -> unit
?ls:Task_local_storage.storage ref -> Runner.t -> (unit -> unit) -> unit
(** [run_async runner task] schedules the task to run (** [run_async runner task] schedules the task to run
on the given runner. This means [task()] will be executed on the given runner. This means [task()] will be executed
at some point in the future, possibly in another thread. at some point in the future, possibly in another thread.
@since 0.5 *) @since 0.5 *)
val run_wait_block : val run_wait_block : ?ls:Task_local_storage.t -> Runner.t -> (unit -> 'a) -> 'a
?ls:Task_local_storage.storage ref -> Runner.t -> (unit -> 'a) -> 'a
(** [run_wait_block runner f] schedules [f] for later execution (** [run_wait_block runner f] schedules [f] for later execution
on the runner, like {!run_async}. on the runner, like {!run_async}.
It then blocks the current thread until [f()] is done executing, It then blocks the current thread until [f()] is done executing,

View file

@ -1,9 +1,10 @@
open Types_
module TLS = Thread_local_storage_ module TLS = Thread_local_storage_
type task = unit -> unit type task = unit -> unit
type t = { type t = runner = {
run_async: ls:Task_local_storage.storage ref -> task -> unit; run_async: ls:local_storage -> task -> unit;
shutdown: wait:bool -> unit -> unit; shutdown: wait:bool -> unit -> unit;
size: unit -> int; size: unit -> int;
num_tasks: unit -> int; num_tasks: unit -> int;
@ -11,9 +12,7 @@ type t = {
exception Shutdown exception Shutdown
let[@inline] run_async let[@inline] run_async ?(ls = create_local_storage ()) (self : t) f : unit =
?(ls = ref @@ Task_local_storage.Private_.Storage.create ()) (self : t) f :
unit =
self.run_async ~ls f self.run_async ~ls f
let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true () let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true ()
@ -41,8 +40,8 @@ module For_runner_implementors = struct
let create ~size ~num_tasks ~shutdown ~run_async () : t = let create ~size ~num_tasks ~shutdown ~run_async () : t =
{ size; num_tasks; shutdown; run_async } { size; num_tasks; shutdown; run_async }
let k_cur_runner : t option ref TLS.key = TLS.new_key (fun () -> ref None) let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner
end end
let[@inline] get_current_runner () : _ option = let get_current_runner = get_current_runner
!(TLS.get For_runner_implementors.k_cur_runner) let get_current_storage = get_current_storage

View file

@ -33,15 +33,14 @@ val shutdown_without_waiting : t -> unit
exception Shutdown exception Shutdown
val run_async : ?ls:Task_local_storage.storage ref -> t -> task -> unit val run_async : ?ls:Task_local_storage.t -> t -> task -> unit
(** [run_async pool f] schedules [f] for later execution on the runner (** [run_async pool f] schedules [f] for later execution on the runner
in one of the threads. [f()] will run on one of the runner's in one of the threads. [f()] will run on one of the runner's
worker threads/domains. worker threads/domains.
@param ls if provided, run the task with this initial local storage @param ls if provided, run the task with this initial local storage
@raise Shutdown if the runner was shut down before [run_async] was called. *) @raise Shutdown if the runner was shut down before [run_async] was called. *)
val run_wait_block : val run_wait_block : ?ls:Task_local_storage.t -> t -> (unit -> 'a) -> 'a
?ls:Task_local_storage.storage ref -> t -> (unit -> 'a) -> 'a
(** [run_wait_block pool f] schedules [f] for later execution (** [run_wait_block pool f] schedules [f] for later execution
on the pool, like {!run_async}. on the pool, like {!run_async}.
It then blocks the current thread until [f()] is done executing, It then blocks the current thread until [f()] is done executing,
@ -61,7 +60,7 @@ module For_runner_implementors : sig
size:(unit -> int) -> size:(unit -> int) ->
num_tasks:(unit -> int) -> num_tasks:(unit -> int) ->
shutdown:(wait:bool -> unit -> unit) -> shutdown:(wait:bool -> unit -> unit) ->
run_async:(ls:Task_local_storage.storage ref -> task -> unit) -> run_async:(ls:Task_local_storage.t -> task -> unit) ->
unit -> unit ->
t t
(** Create a new runner. (** Create a new runner.
@ -80,3 +79,7 @@ val get_current_runner : unit -> t option
(** Access the current runner. This returns [Some r] if the call (** Access the current runner. This returns [Some r] if the call
happens on a thread that belongs in a runner. happens on a thread that belongs in a runner.
@since 0.5 *) @since 0.5 *)
val get_current_storage : unit -> Task_local_storage.t option
(** [get_current_storage runner] gets the local storage
for the currently running task. *)

View file

@ -5,7 +5,44 @@ type 'a key = 'a ls_key
let key_count_ = A.make 0 let key_count_ = A.make 0
type storage = task_ls type t = local_storage
type ls_value += Dummy
let dummy : t = ref [||]
(** Resize array of TLS values *)
let[@inline never] resize_ (cur : ls_value array ref) n =
if n > Sys.max_array_length then failwith "too many task local storage keys";
let len = Array.length !cur in
let new_ls =
Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy
in
Array.blit !cur 0 new_ls 0 len;
cur := new_ls
module Direct = struct
type nonrec t = t
let create = create_local_storage
let[@inline] copy (self : t) = ref (Array.copy !self)
let get (type a) (self : t) ((module K) : a key) : a =
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
match !self.(K.offset) with
| K.V x -> (* common case first *) x
| Dummy ->
(* first time we access this *)
let v = K.init () in
!self.(K.offset) <- K.V v;
v
| _ -> assert false
let set (type a) (self : t) ((module K) : a key) (v : a) : unit =
assert (self != dummy);
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
!self.(K.offset) <- K.V v;
()
end
let new_key (type t) ~init () : t key = let new_key (type t) ~init () : t key =
let offset = A.fetch_and_add key_count_ 1 in let offset = A.fetch_and_add key_count_ 1 in
@ -18,68 +55,25 @@ let new_key (type t) ~init () : t key =
end : LS_KEY end : LS_KEY
with type t = t) with type t = t)
type ls_value += Dummy
(** Resize array of TLS values *)
let[@inline never] resize_ (cur : ls_value array ref) n =
if n > Sys.max_array_length then failwith "too many task local storage keys";
let len = Array.length !cur in
let new_ls =
Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy
in
Array.blit !cur 0 new_ls 0 len;
cur := new_ls
let[@inline] get_cur_ () : ls_value array ref = let[@inline] get_cur_ () : ls_value array ref =
match TLS.get k_ls_values with match get_current_storage () with
| Some r -> r | Some r -> r
| None -> failwith "Task local storage must be accessed from within a runner." | None -> failwith "Task local storage must be accessed from within a runner."
let get_from_ (type a) cur ((module K) : a key) : a =
if K.offset >= Array.length !cur then resize_ cur (K.offset + 1);
match !cur.(K.offset) with
| K.V x -> (* common case first *) x
| Dummy ->
(* first time we access this *)
let v = K.init () in
!cur.(K.offset) <- K.V v;
v
| _ -> assert false
let[@inline] get (key : 'a key) : 'a = let[@inline] get (key : 'a key) : 'a =
let cur = get_cur_ () in let cur = get_cur_ () in
get_from_ cur key Direct.get cur key
let[@inline] get_opt key = let[@inline] get_opt key =
match TLS.get k_ls_values with match get_current_storage () with
| None -> None | None -> None
| Some cur -> Some (get_from_ cur key) | Some cur -> Some (Direct.get cur key)
let set_into_ (type a) cur ((module K) : a key) (v : a) : unit =
if K.offset >= Array.length !cur then resize_ cur (K.offset + 1);
!cur.(K.offset) <- K.V v;
()
let[@inline] set key v : unit = let[@inline] set key v : unit =
let cur = get_cur_ () in let cur = get_cur_ () in
set_into_ cur key v Direct.set cur key v
let with_value key x f = let with_value key x f =
let old = get key in let old = get key in
set key x; set key x;
Fun.protect ~finally:(fun () -> set key old) f Fun.protect ~finally:(fun () -> set key old) f
module Private_ = struct
module Storage = struct
type t = storage
let k_storage = k_ls_values
let[@inline] create () = [||]
let[@inline] get_cur_opt () = TLS.get k_storage
let copy = Array.copy
let get = get_from_
let set = set_into_
let[@inline] copy_of_current () = copy @@ !(get_cur_ ())
let dummy = [||]
end
end

View file

@ -8,8 +8,11 @@
@since NEXT_RELEASE @since NEXT_RELEASE
*) *)
type storage type t = Types_.local_storage
(** Underlying storage for a task *) (** Underlying storage for a task. This is mutable and
not thread-safe. *)
val dummy : t
type 'a key type 'a key
(** A key used to access a particular (typed) storage slot on every task. *) (** A key used to access a particular (typed) storage slot on every task. *)
@ -49,22 +52,12 @@ val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b
to [f()]. When [f()] returns (or fails), [k] is restored to [f()]. When [f()] returns (or fails), [k] is restored
to its old value. *) to its old value. *)
(**/**) (** Direct access to values from a storage handle *)
module Direct : sig
val get : t -> 'a key -> 'a
(** Access a key *)
(** Private API *) val set : t -> 'a key -> 'a -> unit
module Private_ : sig
module Storage : sig
type t = storage
val get : t ref -> 'a key -> 'a
val set : t ref -> 'a key -> 'a -> unit
val k_storage : t ref option Thread_local_storage_.key
val get_cur_opt : unit -> t ref option
val create : unit -> t val create : unit -> t
val copy : t -> t val copy : t -> t
val copy_of_current : unit -> t
val dummy : t
end
end end
(**/**)

View file

@ -16,11 +16,21 @@ end
type 'a ls_key = (module LS_KEY with type t = 'a) type 'a ls_key = (module LS_KEY with type t = 'a)
(** A LS key (task local storage) *) (** A LS key (task local storage) *)
type task_ls = ls_value array type task = unit -> unit
type local_storage = ls_value array ref
(** Store the current LS values for the current thread. type runner = {
run_async: ls:local_storage -> task -> unit;
shutdown: wait:bool -> unit -> unit;
size: unit -> int;
num_tasks: unit -> int;
}
A worker thread is going to cycle through many tasks, each of which let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None)
has its own storage. This key allows tasks running on the worker
to access their own storage *) let k_cur_storage : local_storage option ref TLS.key =
let k_ls_values : task_ls ref option TLS.key = TLS.new_key (fun () -> None) TLS.new_key (fun () -> ref None)
let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner)
let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage)
let[@inline] create_local_storage () = ref [||]

View file

@ -1,10 +1,10 @@
open Types_
module WSQ = Ws_deque_ module WSQ = Ws_deque_
module A = Atomic_ module A = Atomic_
module TLS = Thread_local_storage_ module TLS = Thread_local_storage_
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
let k_storage = Task_local_storage.Private_.Storage.k_storage
module Id = struct module Id = struct
type t = unit ref type t = unit ref
@ -18,11 +18,11 @@ type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type task_full = type task_full =
| T_start of { | T_start of {
ls: Task_local_storage.storage ref; ls: Task_local_storage.t;
f: task; f: task;
} }
| T_resume : { | T_resume : {
ls: Task_local_storage.storage ref; ls: Task_local_storage.t;
k: 'a -> unit; k: 'a -> unit;
x: 'a; x: 'a;
} }
@ -32,7 +32,7 @@ type worker_state = {
pool_id_: Id.t; (** Unique per pool *) pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t; mutable thread: Thread.t;
q: task_full WSQ.t; (** Work stealing queue *) q: task_full WSQ.t; (** Work stealing queue *)
mutable cur_ls: Task_local_storage.storage ref option; (** Task storage *) mutable cur_ls: Task_local_storage.t option; (** Task storage *)
rng: Random.State.t; rng: Random.State.t;
} }
(** State for a given worker. Only this worker is (** State for a given worker. Only this worker is
@ -121,7 +121,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
in in
w.cur_ls <- Some ls; w.cur_ls <- Some ls;
TLS.set k_storage (Some ls); TLS.get k_cur_storage := Some ls;
let _ctx = before_task runner in let _ctx = before_task runner in
let[@inline] on_suspend () : _ ref = let[@inline] on_suspend () : _ ref =
@ -136,7 +136,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
| Some w when Id.equal w.pool_id_ self.id_ -> Some w | Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None | _ -> None
in in
let ls' = ref @@ Task_local_storage.Private_.Storage.copy !ls in let ls' = Task_local_storage.Direct.copy ls in
schedule_task_ self ~w @@ T_start { ls = ls'; f = task' } schedule_task_ self ~w @@ T_start { ls = ls'; f = task' }
in in
@ -166,7 +166,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
after_task runner _ctx; after_task runner _ctx;
w.cur_ls <- None; w.cur_ls <- None;
TLS.set k_storage None TLS.get k_cur_storage := None
let run_async_ (self : state) ~ls (f : task) : unit = let run_async_ (self : state) ~ls (f : task) : unit =
let w = find_current_worker_ () in let w = find_current_worker_ () in
@ -289,7 +289,7 @@ type ('a, 'b) create_args =
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ : task_full = let dummy_task_ : task_full =
T_start { f = ignore; ls = ref Task_local_storage.Private_.Storage.dummy } T_start { f = ignore; ls = Task_local_storage.dummy }
let create ?(on_init_thread = default_thread_init_exit_) let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
@ -358,7 +358,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let thread = Thread.self () in let thread = Thread.self () in
let t_id = Thread.id thread in let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id (); on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.set k_storage None; TLS.get k_cur_storage := None;
(* set thread name *) (* set thread name *)
Option.iter Option.iter

View file

@ -16,7 +16,7 @@ module Private_ = struct
state: 'a state A.t; (** Current state in the lifetime of the fiber *) state: 'a state A.t; (** Current state in the lifetime of the fiber *)
res: 'a Fut.t; res: 'a Fut.t;
runner: Runner.t; runner: Runner.t;
ls: Task_local_storage.storage ref; ls: Task_local_storage.t;
} }
and 'a state = and 'a state =
@ -248,7 +248,7 @@ let spawn_ ~ls (Nursery n) (f : nursery -> 'a) : 'a t =
let spawn (Nursery n) ?(protect = true) f : _ t = let spawn (Nursery n) ?(protect = true) f : _ t =
(* spawn [f()] with a copy of our local storage *) (* spawn [f()] with a copy of our local storage *)
let ls = ref (Task_local_storage.Private_.Storage.copy !(n.ls)) in let ls = Task_local_storage.Direct.copy n.ls in
let child = spawn_ ~ls (Nursery n) f in let child = spawn_ ~ls (Nursery n) f in
add_child_ ~protect n child; add_child_ ~protect n child;
child child
@ -259,6 +259,8 @@ let[@inline] spawn_ignore n ?protect f : unit =
module Nursery = struct module Nursery = struct
type t = nursery type t = nursery
let[@inline] runner (Nursery n) = n.runner
let[@inline] await (Nursery n) : unit = let[@inline] await (Nursery n) : unit =
ignore (await n); ignore (await n);
() ()
@ -266,17 +268,13 @@ module Nursery = struct
let cancel_with (Nursery n) ebt : unit = resolve_as_failed_ n ebt let cancel_with (Nursery n) ebt : unit = resolve_as_failed_ n ebt
let with_create_top ~on () f = let with_create_top ~on () f =
let n = let n = create_ ~ls:(Task_local_storage.Direct.create ()) ~runner:on () in
create_
~ls:(ref @@ Task_local_storage.Private_.Storage.create ())
~runner:on ()
in
Fun.protect ~finally:(fun () -> resolve_ok_ n ()) (fun () -> f (Nursery n)) Fun.protect ~finally:(fun () -> resolve_ok_ n ()) (fun () -> f (Nursery n))
let with_create_sub ~protect (Nursery parent : t) f = let with_create_sub ~protect (Nursery parent : t) f =
let n = let n =
create_ create_
~ls:(ref @@ Task_local_storage.Private_.Storage.copy !(parent.ls)) ~ls:(Task_local_storage.Direct.copy parent.ls)
~runner:parent.runner () ~runner:parent.runner ()
in in
add_child_ ~protect parent n; add_child_ ~protect parent n;

View file

@ -27,6 +27,9 @@ type cancel_callback = Exn_bt.t -> unit
module Nursery : sig module Nursery : sig
type t type t
val runner : t -> Runner.t
(** Recover the runner this nursery uses to spawn fibers *)
val await : t -> unit val await : t -> unit
(** Await for the nursery to exit. *) (** Await for the nursery to exit. *)
@ -59,7 +62,7 @@ module Private_ : sig
state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *) state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *)
res: 'a Fut.t; res: 'a Fut.t;
runner: Runner.t; runner: Runner.t;
ls: Task_local_storage.storage ref; ls: Task_local_storage.t;
} }
(** Type definition, exposed so that {!any} can be unboxed. (** Type definition, exposed so that {!any} can be unboxed.
Please do not rely on that. *) Please do not rely on that. *)

6
src/fib/moonpool_fib.ml Normal file
View file

@ -0,0 +1,6 @@
(** Fiber for moonpool *)
module Fiber = Fiber
module Fls = Fls
module Handle = Handle
include Fiber