diff --git a/src/thread_local_storage.real.ml b/src/thread_local_storage.real.ml index 70d7a558..768e2c35 100644 --- a/src/thread_local_storage.real.ml +++ b/src/thread_local_storage.real.ml @@ -5,8 +5,13 @@ module A = Atomic_ (* sanity check *) let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) +type view = .. +type view += Sentinel + type 'a key = { index: int; (** Unique index for this key. *) + unwrap: view -> 'a option; + wrap: 'a -> view; compute: unit -> 'a; (** Initializer for values for this key. Called at most once per thread. *) @@ -15,12 +20,17 @@ type 'a key = { (** Counter used to allocate new keys *) let counter = A.make 0 -(** Value used to detect a TLS slot that was not initialized yet *) -let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter - -let new_key compute : _ key = +let new_key (type a) compute : a key = + let module M = struct + type view += V of a + end in let index = A.fetch_and_add counter 1 in - { index; compute } + let wrap x = M.V x in + let unwrap = function + | M.V x -> Some x + | _ -> None + in + { index; compute; wrap; unwrap } type thread_internal_state = { _id: int; (** Thread ID (here for padding reasons) *) @@ -42,13 +52,13 @@ let ceil_pow_2_minus_1 (n : int) : int = n (** Grow the array so that [index] is valid. *) -let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array = +let[@inline never] grow_tls (old : view array) (index : int) : view array = let new_length = ceil_pow_2_minus_1 (index + 1) in - let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in + let new_ = Array.make new_length Sentinel in Array.blit old 0 new_ 0 (Array.length old); new_ -let[@inline] get_tls_ (index : int) : Obj.t array = +let[@inline] get_tls_ (index : int) : view array = let thread : thread_internal_state = Obj.magic (Thread.self ()) in let tls = thread.tls in if Obj.is_int tls then ( @@ -56,7 +66,7 @@ let[@inline] get_tls_ (index : int) : Obj.t array = thread.tls <- Obj.magic new_tls; new_tls ) else ( - let tls = (Obj.magic tls : Obj.t array) in + let tls = (Obj.magic tls : view array) in if index < Array.length tls then tls else ( @@ -69,14 +79,17 @@ let[@inline] get_tls_ (index : int) : Obj.t array = let get key = let tls = get_tls_ key.index in let value = Array.unsafe_get tls key.index in - if value != sentinel_value_for_uninit_tls_ () then - Obj.magic value - else ( - let value = key.compute () in - Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)); - value - ) + + match key.unwrap value with + | Some v -> v + | None -> + (match value with + | Sentinel -> + let value = key.compute () in + Array.unsafe_set tls key.index (key.wrap value); + value + | _ -> assert false) let set key value = let tls = get_tls_ key.index in - Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)) + Array.unsafe_set tls key.index (key.wrap value)