From 3e105434d95d9db44324e780283247aad5dff4aa Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 14 Sep 2022 18:49:46 -0400 Subject: [PATCH] add `CCLocal_storage` to containers-thread Basic local storage emulation using an atomic-protected map. The goal is for this to work both with 4.xx threads, 5.xx domains, and 5.xx threads running on a given domain. --- src/threads/CCLocal_storage.ml | 89 +++++++++++++++++++++++++++++++++ src/threads/CCLocal_storage.mli | 46 +++++++++++++++++ tests/thread/t.ml | 1 + tests/thread/t_local_storage.ml | 29 +++++++++++ 4 files changed, 165 insertions(+) create mode 100644 src/threads/CCLocal_storage.ml create mode 100644 src/threads/CCLocal_storage.mli create mode 100644 tests/thread/t_local_storage.ml diff --git a/src/threads/CCLocal_storage.ml b/src/threads/CCLocal_storage.ml new file mode 100644 index 00000000..27b6c8c8 --- /dev/null +++ b/src/threads/CCLocal_storage.ml @@ -0,0 +1,89 @@ + +module A = CCAtomic + +[@@@ifge 5.00] + +type key = int*int + +let get_key_ () : key = + Domain.id (Domain.self()), Thread.id (Thread.self()) + +module Key_map_ = CCMap.Make(struct + type t = key + let compare : t -> t -> int = compare +end) + +[@@@else_] + +type key = int + +let get_key_ () : key = + Thread.id (Thread.self()) + +module Key_map_ = CCMap.Make(struct + type t = key + let compare : t -> t -> int = CCInt.compare +end) + +[@@@endif] + +type 'a t = 'a Key_map_.t A.t + +let create () : _ t = A.make Key_map_.empty + +let[@inline] n_entries self = Key_map_.cardinal (A.get self) + +let get (self: _ t ) : _ option = + let m = A.get self in + let key = get_key_ () in + Key_map_.get key m + +let get_exn self = + let m = A.get self in + let key = get_key_ () in + Key_map_.find key m + +let get_or ~default self = + try get_exn self + with Not_found -> default + +let set (self: _ t ) v : unit = + let key = get_key_ () in + while + let m = A.get self in + let m' = Key_map_.add key v m in + not (A.compare_and_set self m m') + do () done + +let set_get (self: _ t ) v : _ option = + let key = get_key_ () in + let rec loop () = + let m = A.get self in + let m' = Key_map_.add key v m in + if A.compare_and_set self m m' then Key_map_.get key m + else loop() + in loop () + +let remove self = + let key = get_key_ () in + while + let m = A.get self in + let m' = Key_map_.remove key m in + not (A.compare_and_set self m m') + do () done + +let[@inline] set_opt_ self v = + match v with + | None -> remove self + | Some v' -> set self v' + +let with_ self x f = + let old = set_get self x in + try + let r = f() in + set_opt_ self old; + r + with e -> + let bt = Printexc.get_raw_backtrace () in + set_opt_ self old; + Printexc.raise_with_backtrace e bt diff --git a/src/threads/CCLocal_storage.mli b/src/threads/CCLocal_storage.mli new file mode 100644 index 00000000..1374e5ff --- /dev/null +++ b/src/threads/CCLocal_storage.mli @@ -0,0 +1,46 @@ +(** Thread/Domain local storage + + This allows the creation of global state that is per-domain or per-thread. + + {b status} experimental + + @since NEXT_RELEASE +*) + +type 'a t + +val create : unit -> 'a t +(** Create new storage *) + +val get : 'a t -> 'a option +(** Get the content for this thread, if any. *) + +val get_exn : 'a t -> 'a +(** Same as {!get}, but fails if no data was associated to this thread. + @raise Not_found if the data is not there. *) + +val get_or : default:'a -> 'a t -> 'a +(** Same as {!get} but returns [default] if no data is associated + to this thread. *) + +val set : 'a t -> 'a -> unit +(** Set content for this thread. *) + +val set_get : 'a t -> 'a -> 'a option +(** Set content for this thread, and return the old value. *) + +val remove : 'a t -> unit +(** Remove value *) + +val n_entries : _ t -> int +(** Number of entries in the map currently. + + Be aware that some threads might + have exited without cleaning up behind them. See {!with_} for + scope-protected modification of the variable that will cleanup + properly (like {!Fun.protect}). +*) + +val with_ : 'a t -> 'a -> (unit -> 'b) -> 'b +(** [with_ var x f] sets [var] to [x] for this thread, calls [f()], and + then restores the old value of [var] for this thread. *) diff --git a/tests/thread/t.ml b/tests/thread/t.ml index 21518fff..743881b5 100644 --- a/tests/thread/t.ml +++ b/tests/thread/t.ml @@ -6,4 +6,5 @@ Containers_testlib.run_all ~descr:"containers-thread" T_semaphore.Test.get (); T_thread.Test.get (); T_timer.Test.get (); + T_local_storage.Test.get (); ] diff --git a/tests/thread/t_local_storage.ml b/tests/thread/t_local_storage.ml new file mode 100644 index 00000000..e2f6cd3e --- /dev/null +++ b/tests/thread/t_local_storage.ml @@ -0,0 +1,29 @@ +module Test = (val Containers_testlib.make ~__FILE__ ()) +open Test +module L = CCLocal_storage;; + +t @@ fun () -> +let var = L.create () in + +let sum_of_res = CCAtomic.make 0 in +let n = 1_000 in + +let run1 () = + L.with_ var 0 @@ fun () -> + for _i = 1 to n do + let x = L.get_exn var in + Thread.yield (); + L.set var (x + 1) + done; + ignore (CCAtomic.fetch_and_add sum_of_res (L.get_exn var) : int) +in + +let threads = Array.init 16 (fun _ -> Thread.create run1 ()) in +Array.iter Thread.join threads; + +assert_equal ~printer:string_of_int (n * 16) (CCAtomic.get sum_of_res); + +(* cleanup *) +assert_equal ~printer:string_of_int 0 (L.n_entries var); + +true