From c9d22309d0ee01a0e6ab39b9d52fcd69c1149a37 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 9 Jun 2023 17:56:35 -0400 Subject: [PATCH] thread-local: add get_or_create --- src/thread_local.ml | 31 +++++++++++++++++++++++-------- src/thread_local.mli | 2 ++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/thread_local.ml b/src/thread_local.ml index 9578ffdf..a96ac414 100644 --- a/src/thread_local.ml +++ b/src/thread_local.ml @@ -31,7 +31,7 @@ let[@inline] get self = try Some (get_exn self) with Not_found -> None let[@inline] get_or ~default self = try get_exn self with Not_found -> default (* remove reference for the key *) -let[@inline] remove_ref_ self key : unit = +let remove_ref_ self key : unit = while let m = A.get self in let m' = Key_map_.remove key m in @@ -40,6 +40,15 @@ let[@inline] remove_ref_ self key : unit = Thread.yield () done +let set_ self key (r : _ ref) : unit = + while + let m = A.get self in + let m' = Key_map_.add key r m in + not (A.compare_and_set self m m') + do + Thread.yield () + done + (* get or associate a reference to [key], and return it. Also return a function to remove the reference if we just created it. *) let get_or_create_ref_ (self : _ t) key ~v : _ ref * _ option = @@ -50,15 +59,21 @@ let get_or_create_ref_ (self : _ t) key ~v : _ ref * _ option = r, Some old with Not_found -> let r = ref v in - while - let m = A.get self in - let m' = Key_map_.add key r m in - not (A.compare_and_set self m m') - do - Thread.yield () - done; + set_ self key r; r, None +let get_or_create ~create (self : 'a t) : 'a = + let key = get_key_ () in + try + let r = Key_map_.find key (A.get self) in + !r + with Not_found -> + Gc.finalise (fun _ -> remove_ref_ self key) (Thread.self ()); + let v = create () in + let r = ref v in + set_ self key r; + v + let with_ self v f = let key = get_key_ () in let r, old = get_or_create_ref_ self key ~v in diff --git a/src/thread_local.mli b/src/thread_local.mli index ecd162f2..1af43af9 100644 --- a/src/thread_local.mli +++ b/src/thread_local.mli @@ -15,6 +15,8 @@ val get_exn : 'a t -> 'a (** Like {!get} but fails with an exception @raise Not_found if no value was found *) +val get_or_create : create:(unit -> 'a) -> 'a t -> 'a + val with_ : 'a t -> 'a -> ('a option -> 'b) -> 'b (** [with_ var x f] sets [var] to [x] for this thread, calls [f prev] where [prev] is the value currently in [var] (if any), and