safer, but slower, TLS

This commit is contained in:
Simon Cruanes 2023-10-27 17:07:30 -04:00
parent def384b4f8
commit 5ab96aabbc
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -5,8 +5,13 @@ module A = Atomic_
(* sanity check *) (* sanity check *)
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())
type view = ..
type view += Sentinel
type 'a key = { type 'a key = {
index: int; (** Unique index for this key. *) index: int; (** Unique index for this key. *)
unwrap: view -> 'a option;
wrap: 'a -> view;
compute: unit -> 'a; compute: unit -> 'a;
(** Initializer for values for this key. Called at most (** Initializer for values for this key. Called at most
once per thread. *) once per thread. *)
@ -15,12 +20,17 @@ type 'a key = {
(** Counter used to allocate new keys *) (** Counter used to allocate new keys *)
let counter = A.make 0 let counter = A.make 0
(** Value used to detect a TLS slot that was not initialized yet *) let new_key (type a) compute : a key =
let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter let module M = struct
type view += V of a
let new_key compute : _ key = end in
let index = A.fetch_and_add counter 1 in let index = A.fetch_and_add counter 1 in
{ index; compute } let wrap x = M.V x in
let unwrap = function
| M.V x -> Some x
| _ -> None
in
{ index; compute; wrap; unwrap }
type thread_internal_state = { type thread_internal_state = {
_id: int; (** Thread ID (here for padding reasons) *) _id: int; (** Thread ID (here for padding reasons) *)
@ -42,13 +52,13 @@ let ceil_pow_2_minus_1 (n : int) : int =
n n
(** Grow the array so that [index] is valid. *) (** Grow the array so that [index] is valid. *)
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array = let[@inline never] grow_tls (old : view array) (index : int) : view array =
let new_length = ceil_pow_2_minus_1 (index + 1) in let new_length = ceil_pow_2_minus_1 (index + 1) in
let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in let new_ = Array.make new_length Sentinel in
Array.blit old 0 new_ 0 (Array.length old); Array.blit old 0 new_ 0 (Array.length old);
new_ new_
let[@inline] get_tls_ (index : int) : Obj.t array = let[@inline] get_tls_ (index : int) : view array =
let thread : thread_internal_state = Obj.magic (Thread.self ()) in let thread : thread_internal_state = Obj.magic (Thread.self ()) in
let tls = thread.tls in let tls = thread.tls in
if Obj.is_int tls then ( if Obj.is_int tls then (
@ -56,7 +66,7 @@ let[@inline] get_tls_ (index : int) : Obj.t array =
thread.tls <- Obj.magic new_tls; thread.tls <- Obj.magic new_tls;
new_tls new_tls
) else ( ) else (
let tls = (Obj.magic tls : Obj.t array) in let tls = (Obj.magic tls : view array) in
if index < Array.length tls then if index < Array.length tls then
tls tls
else ( else (
@ -69,14 +79,17 @@ let[@inline] get_tls_ (index : int) : Obj.t array =
let get key = let get key =
let tls = get_tls_ key.index in let tls = get_tls_ key.index in
let value = Array.unsafe_get tls key.index in let value = Array.unsafe_get tls key.index in
if value != sentinel_value_for_uninit_tls_ () then
Obj.magic value match key.unwrap value with
else ( | Some v -> v
let value = key.compute () in | None ->
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)); (match value with
value | Sentinel ->
) let value = key.compute () in
Array.unsafe_set tls key.index (key.wrap value);
value
| _ -> assert false)
let set key value = let set key value =
let tls = get_tls_ key.index in let tls = get_tls_ key.index in
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)) Array.unsafe_set tls key.index (key.wrap value)