move to thread-local-storage 0.2 with get/set API

This commit is contained in:
Simon Cruanes 2024-08-16 10:07:51 -04:00
parent 3388098fcc
commit 265d4f73dd
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
10 changed files with 112 additions and 89 deletions

View file

@ -30,6 +30,7 @@
(depopts (depopts
(trace (>= 0.6)) (trace (>= 0.6))
thread-local-storage) thread-local-storage)
(conflicts (thread-local-storage (< 0.2)))
(tags (tags
(thread pool domain futures fork-join))) (thread pool domain futures fork-join)))

View file

@ -22,6 +22,9 @@ depopts: [
"trace" {>= "0.6"} "trace" {>= "0.6"}
"thread-local-storage" "thread-local-storage"
] ]
conflicts: [
"thread-local-storage" {< "0.2"}
]
build: [ build: [
["dune" "subst"] {dev} ["dune" "subst"] {dev}
[ [

View file

@ -31,18 +31,17 @@ let schedule_ (self : state) (task : task_full) : unit =
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type worker_state = { mutable cur_ls: Task_local_storage.t option } type worker_state = { mutable cur_ls: Task_local_storage.t option }
let k_worker_state : worker_state option ref TLS.key = let k_worker_state : worker_state TLS.t = TLS.create ()
TLS.new_key (fun () -> ref None)
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let w = { cur_ls = None } in let w = { cur_ls = None } in
TLS.get k_worker_state := Some w; TLS.set k_worker_state w;
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; TLS.set Runner.For_runner_implementors.k_cur_runner runner;
let (AT_pair (before_task, after_task)) = around_task in let (AT_pair (before_task, after_task)) = around_task in
let on_suspend () = let on_suspend () =
match !(TLS.get k_worker_state) with match TLS.get_opt k_worker_state with
| Some { cur_ls = Some ls; _ } -> ls | Some { cur_ls = Some ls; _ } -> ls
| _ -> assert false | _ -> assert false
in in
@ -55,7 +54,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
| T_start { ls; _ } | T_resume { ls; _ } -> ls | T_start { ls; _ } | T_resume { ls; _ } -> ls
in in
w.cur_ls <- Some ls; w.cur_ls <- Some ls;
TLS.get k_cur_storage := Some ls; TLS.set k_cur_storage ls;
let _ctx = before_task runner in let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *) (* run the task now, catching errors, handling effects *)
@ -74,7 +73,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
on_exn e bt); on_exn e bt);
after_task runner _ctx; after_task runner _ctx;
w.cur_ls <- None; w.cur_ls <- None;
TLS.get k_cur_storage := None TLS.set k_cur_storage _dummy_ls
in in
let main_loop () = let main_loop () =

View file

@ -40,7 +40,7 @@ module For_runner_implementors = struct
let create ~size ~num_tasks ~shutdown ~run_async () : t = let create ~size ~num_tasks ~shutdown ~run_async () : t =
{ size; num_tasks; shutdown; run_async } { size; num_tasks; shutdown; run_async }
let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner let k_cur_runner : t TLS.t = Types_.k_cur_runner
end end
let dummy : t = let dummy : t =

View file

@ -73,7 +73,7 @@ module For_runner_implementors : sig
{b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x,
so that {!Fork_join} and other 5.x features work properly. *) so that {!Fork_join} and other 5.x features work properly. *)
val k_cur_runner : t option ref Thread_local_storage_.key val k_cur_runner : t Thread_local_storage_.t
(** Key that should be used by each runner to store itself in TLS (** Key that should be used by each runner to store itself in TLS
on every thread it controls, so that tasks running on these threads on every thread it controls, so that tasks running on these threads
can access the runner. This is necessary for {!get_current_runner} can access the runner. This is necessary for {!get_current_runner}

View file

@ -8,7 +8,7 @@ let key_count_ = A.make 0
type t = local_storage type t = local_storage
type ls_value += Dummy type ls_value += Dummy
let dummy : t = ref [||] let dummy : t = _dummy_ls
(** Resize array of TLS values *) (** Resize array of TLS values *)
let[@inline never] resize_ (cur : ls_value array ref) n = let[@inline never] resize_ (cur : ls_value array ref) n =
@ -57,7 +57,9 @@ let new_key (type t) ~init () : t key =
let[@inline] get_cur_ () : ls_value array ref = let[@inline] get_cur_ () : ls_value array ref =
match get_current_storage () with match get_current_storage () with
| Some r -> r | Some r ->
assert (r != dummy);
r
| None -> failwith "Task local storage must be accessed from within a runner." | None -> failwith "Task local storage must be accessed from within a runner."
let[@inline] get (key : 'a key) : 'a = let[@inline] get (key : 'a key) : 'a =

View file

@ -27,11 +27,9 @@ type runner = {
num_tasks: unit -> int; num_tasks: unit -> int;
} }
let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None) let k_cur_runner : runner TLS.t = TLS.create ()
let k_cur_storage : local_storage TLS.t = TLS.create ()
let k_cur_storage : local_storage option ref TLS.key = let _dummy_ls : local_storage = ref [||]
TLS.new_key (fun () -> ref None) let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage
let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner)
let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage)
let[@inline] create_local_storage () = ref [||] let[@inline] create_local_storage () = ref [||]

View file

@ -65,11 +65,10 @@ let num_tasks_ (self : state) : int =
(** TLS, used by worker to store their specific state (** TLS, used by worker to store their specific state
and be able to retrieve it from tasks when we schedule new and be able to retrieve it from tasks when we schedule new
sub-tasks. *) sub-tasks. *)
let k_worker_state : worker_state option ref TLS.key = let k_worker_state : worker_state TLS.t = TLS.create ()
TLS.new_key (fun () -> ref None)
let[@inline] find_current_worker_ () : worker_state option = let[@inline] find_current_worker_ () : worker_state option =
!(TLS.get k_worker_state) TLS.get_opt k_worker_state
(** Try to wake up a waiter, if there's any. *) (** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit = let[@inline] try_wake_someone_ (self : state) : unit =
@ -121,7 +120,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
in in
w.cur_ls <- Some ls; w.cur_ls <- Some ls;
TLS.get k_cur_storage := Some ls; TLS.set k_cur_storage ls;
let _ctx = before_task runner in let _ctx = before_task runner in
let[@inline] on_suspend () : _ ref = let[@inline] on_suspend () : _ ref =
@ -166,7 +165,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
after_task runner _ctx; after_task runner _ctx;
w.cur_ls <- None; w.cur_ls <- None;
TLS.get k_cur_storage := None TLS.set k_cur_storage _dummy_ls
let run_async_ (self : state) ~ls (f : task) : unit = let run_async_ (self : state) ~ls (f : task) : unit =
let w = find_current_worker_ () in let w = find_current_worker_ () in
@ -222,8 +221,8 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
(** Main loop for a worker thread. *) (** Main loop for a worker thread. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; TLS.set Runner.For_runner_implementors.k_cur_runner runner;
TLS.get k_worker_state := Some w; TLS.set k_worker_state w;
let rec main () : unit = let rec main () : unit =
worker_run_self_tasks_ self ~runner w; worker_run_self_tasks_ self ~runner w;
@ -358,7 +357,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let thread = Thread.self () in let thread = Thread.self () in
let t_id = Thread.id thread in let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id (); on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.get k_cur_storage := None; TLS.set k_cur_storage _dummy_ls;
(* set thread name *) (* set thread name *)
Option.iter Option.iter

View file

@ -1,21 +1,13 @@
(** Thread local storage *) (** Thread local storage *)
(* TODO: alias this to the library if present *) type 'a t
(** A TLS slot for values of type ['a]. This allows the storage of a
single value of type ['a] per thread. *)
type 'a key val create : unit -> 'a t
(** A TLS key for values of type ['a]. This allows the
storage of a single value of type ['a] per thread. *)
val new_key : (unit -> 'a) -> 'a key val get : 'a t -> 'a
(** Allocate a new, generative key. (** @raise Failure if not present *)
When the key is used for the first time on a thread,
the function is called to produce it.
This should only ever be called at toplevel to produce val get_opt : 'a t -> 'a option
constants, do not use it in a loop. *) val set : 'a t -> 'a -> unit
val get : 'a key -> 'a
(** Get the value for the current thread. *)
val set : 'a key -> 'a -> unit
(** Set the value for the current thread. *)

View file

@ -1,36 +1,13 @@
(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *) (* vendored from https://github.com/c-cube/thread-local-storage *)
module A = Atomic_
(* sanity check *) (* sanity check *)
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())
type 'a key = { type 'a t = int
index: int; (** Unique index for this key. *) (** Unique index for this TLS slot. *)
compute: unit -> 'a;
(** Initializer for values for this key. Called at most
once per thread. *)
}
(** Counter used to allocate new keys *) let tls_length index =
let counter = A.make 0 let ceil_pow_2_minus_1 (n : int) : int =
(** Value used to detect a TLS slot that was not initialized yet *)
let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter
let new_key compute : _ key =
let index = A.fetch_and_add counter 1 in
{ index; compute }
type thread_internal_state = {
_id: int; (** Thread ID (here for padding reasons) *)
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
}
(** A partial representation of the internal type [Thread.t], allowing
us to access the second field (unused after the thread
has started) and stash TLS data in it. *)
let ceil_pow_2_minus_1 (n : int) : int =
let n = n lor (n lsr 1) in let n = n lor (n lsr 1) in
let n = n lor (n lsr 2) in let n = n lor (n lsr 2) in
let n = n lor (n lsr 4) in let n = n lor (n lsr 4) in
@ -40,43 +17,95 @@ let ceil_pow_2_minus_1 (n : int) : int =
n lor (n lsr 32) n lor (n lsr 32)
else else
n n
in
let size = ceil_pow_2_minus_1 (index + 1) in
assert (size > index);
size
(** Counter used to allocate new keys *)
let counter = Atomic.make 0
(** Value used to detect a TLS slot that was not initialized yet.
Because [counter] is private and lives forever, no other
object the user can see will have the same address. *)
let sentinel_value_for_uninit_tls : Obj.t = Obj.repr counter
external max_wosize : unit -> int = "caml_sys_const_max_wosize"
let max_word_size = max_wosize ()
let create () : _ t =
let index = Atomic.fetch_and_add counter 1 in
if tls_length index <= max_word_size then
index
else (
(* Some platforms have a small max word size. *)
ignore (Atomic.fetch_and_add counter (-1));
failwith "Thread_local_storage.create: out of TLS slots"
)
type thread_internal_state = {
_id: int; (** Thread ID (here for padding reasons) *)
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
_other: Obj.t;
(** Here to avoid lying to ocamlopt/flambda about the size of [Thread.t] *)
}
(** A partial representation of the internal type [Thread.t], allowing
us to access the second field (unused after the thread
has started) and stash TLS data in it. *)
let[@inline] get_raw index : Obj.t =
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
let tls = thread.tls in
if Obj.is_block tls && index < Array.length (Obj.obj tls : Obj.t array) then
Array.unsafe_get (Obj.obj tls : Obj.t array) index
else
sentinel_value_for_uninit_tls
let[@inline never] tls_error () =
failwith "Thread_local_storage.get: TLS entry not initialised"
let[@inline] get slot =
let v = get_raw slot in
if v != sentinel_value_for_uninit_tls then
Obj.obj v
else
tls_error ()
let[@inline] get_opt slot =
let v = get_raw slot in
if v != sentinel_value_for_uninit_tls then
Some (Obj.obj v)
else
None
(** Allocating and setting *)
(** Grow the array so that [index] is valid. *) (** Grow the array so that [index] is valid. *)
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array = let grow (old : Obj.t array) (index : int) : Obj.t array =
let new_length = ceil_pow_2_minus_1 (index + 1) in let new_length = tls_length index in
let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in let new_ = Array.make new_length sentinel_value_for_uninit_tls in
Array.blit old 0 new_ 0 (Array.length old); Array.blit old 0 new_ 0 (Array.length old);
new_ new_
let[@inline] get_tls_ (index : int) : Obj.t array = let get_tls_with_capacity index : Obj.t array =
let thread : thread_internal_state = Obj.magic (Thread.self ()) in let thread : thread_internal_state = Obj.magic (Thread.self ()) in
let tls = thread.tls in let tls = thread.tls in
if Obj.is_int tls then ( if Obj.is_int tls then (
let new_tls = grow_tls [||] index in let new_tls = grow [||] index in
thread.tls <- Obj.magic new_tls; thread.tls <- Obj.repr new_tls;
new_tls new_tls
) else ( ) else (
let tls = (Obj.magic tls : Obj.t array) in let tls = (Obj.obj tls : Obj.t array) in
if index < Array.length tls then if index < Array.length tls then
tls tls
else ( else (
let new_tls = grow_tls tls index in let new_tls = grow tls index in
thread.tls <- Obj.magic new_tls; thread.tls <- Obj.repr new_tls;
new_tls new_tls
) )
) )
let get key = let[@inline] set slot value : unit =
let tls = get_tls_ key.index in let tls = get_tls_with_capacity slot in
let value = Array.unsafe_get tls key.index in Array.unsafe_set tls slot (Obj.repr (Sys.opaque_identity value))
if value != sentinel_value_for_uninit_tls_ () then
Obj.magic value
else (
let value = key.compute () in
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
value
)
let set key value =
let tls = get_tls_ key.index in
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value))