add CCLocal_storage to containers-thread

Basic local storage emulation using an atomic-protected map. The goal
is for this to work both with 4.xx threads, 5.xx domains, and 5.xx
threads running on a given domain.
This commit is contained in:
Simon Cruanes 2022-09-14 18:49:46 -04:00
parent 00d344e09e
commit 3e105434d9
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
4 changed files with 165 additions and 0 deletions

View file

@ -0,0 +1,89 @@
module A = CCAtomic
[@@@ifge 5.00]
type key = int*int
let get_key_ () : key =
Domain.id (Domain.self()), Thread.id (Thread.self())
module Key_map_ = CCMap.Make(struct
type t = key
let compare : t -> t -> int = compare
end)
[@@@else_]
type key = int
let get_key_ () : key =
Thread.id (Thread.self())
module Key_map_ = CCMap.Make(struct
type t = key
let compare : t -> t -> int = CCInt.compare
end)
[@@@endif]
type 'a t = 'a Key_map_.t A.t
let create () : _ t = A.make Key_map_.empty
let[@inline] n_entries self = Key_map_.cardinal (A.get self)
let get (self: _ t ) : _ option =
let m = A.get self in
let key = get_key_ () in
Key_map_.get key m
let get_exn self =
let m = A.get self in
let key = get_key_ () in
Key_map_.find key m
let get_or ~default self =
try get_exn self
with Not_found -> default
let set (self: _ t ) v : unit =
let key = get_key_ () in
while
let m = A.get self in
let m' = Key_map_.add key v m in
not (A.compare_and_set self m m')
do () done
let set_get (self: _ t ) v : _ option =
let key = get_key_ () in
let rec loop () =
let m = A.get self in
let m' = Key_map_.add key v m in
if A.compare_and_set self m m' then Key_map_.get key m
else loop()
in loop ()
let remove self =
let key = get_key_ () in
while
let m = A.get self in
let m' = Key_map_.remove key m in
not (A.compare_and_set self m m')
do () done
let[@inline] set_opt_ self v =
match v with
| None -> remove self
| Some v' -> set self v'
let with_ self x f =
let old = set_get self x in
try
let r = f() in
set_opt_ self old;
r
with e ->
let bt = Printexc.get_raw_backtrace () in
set_opt_ self old;
Printexc.raise_with_backtrace e bt

View file

@ -0,0 +1,46 @@
(** Thread/Domain local storage
This allows the creation of global state that is per-domain or per-thread.
{b status} experimental
@since NEXT_RELEASE
*)
type 'a t
val create : unit -> 'a t
(** Create new storage *)
val get : 'a t -> 'a option
(** Get the content for this thread, if any. *)
val get_exn : 'a t -> 'a
(** Same as {!get}, but fails if no data was associated to this thread.
@raise Not_found if the data is not there. *)
val get_or : default:'a -> 'a t -> 'a
(** Same as {!get} but returns [default] if no data is associated
to this thread. *)
val set : 'a t -> 'a -> unit
(** Set content for this thread. *)
val set_get : 'a t -> 'a -> 'a option
(** Set content for this thread, and return the old value. *)
val remove : 'a t -> unit
(** Remove value *)
val n_entries : _ t -> int
(** Number of entries in the map currently.
Be aware that some threads might
have exited without cleaning up behind them. See {!with_} for
scope-protected modification of the variable that will cleanup
properly (like {!Fun.protect}).
*)
val with_ : 'a t -> 'a -> (unit -> 'b) -> 'b
(** [with_ var x f] sets [var] to [x] for this thread, calls [f()], and
then restores the old value of [var] for this thread. *)

View file

@ -6,4 +6,5 @@ Containers_testlib.run_all ~descr:"containers-thread"
T_semaphore.Test.get (); T_semaphore.Test.get ();
T_thread.Test.get (); T_thread.Test.get ();
T_timer.Test.get (); T_timer.Test.get ();
T_local_storage.Test.get ();
] ]

View file

@ -0,0 +1,29 @@
module Test = (val Containers_testlib.make ~__FILE__ ())
open Test
module L = CCLocal_storage;;
t @@ fun () ->
let var = L.create () in
let sum_of_res = CCAtomic.make 0 in
let n = 1_000 in
let run1 () =
L.with_ var 0 @@ fun () ->
for _i = 1 to n do
let x = L.get_exn var in
Thread.yield ();
L.set var (x + 1)
done;
ignore (CCAtomic.fetch_and_add sum_of_res (L.get_exn var) : int)
in
let threads = Array.init 16 (fun _ -> Thread.create run1 ()) in
Array.iter Thread.join threads;
assert_equal ~printer:string_of_int (n * 16) (CCAtomic.get sum_of_res);
(* cleanup *)
assert_equal ~printer:string_of_int 0 (L.n_entries var);
true