Add lock-free atomic hashtable for string->atomic int mapping

- Implements thread-safe hash table using only atomics and arrays
- Uses a fixed array of atomic slots, each holding an immutable string map (no open addressing or probing)
- Fast lookups: atomic load + string comparison
- Slow inserts acceptable (CAS-based insertion)
- Includes unit tests and concurrent stress test
- All tests pass with 8 threads doing 8000 total increments

work on atomic_tbl
This commit is contained in:
Simon Cruanes 2026-02-12 00:00:49 +00:00
parent d8cdb2bcc2
commit 6517ee32bc
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
9 changed files with 300 additions and 0 deletions

View file

@ -0,0 +1,50 @@
(** Lock-free thread-safe hash table *)

(* Each slot of the table stores an immutable map keyed by strings. *)
module Str_map = Map.Make (String)

(** The table itself: a fixed-size array of slots, each slot being an atomic
    reference to an immutable string map. *)
type 'a t = { entries: 'a Str_map.t Atomic.t array } [@@unboxed]

(* The slot count is a power of two ([1 lsl 7 = 128]), so a hash can be turned
   into a slot index with a bit mask ([hash land slot_mask]). *)
let n_slots_log = 7
let n_slots = 1 lsl n_slots_log
let slot_mask = n_slots - 1
(** Allocate a fresh table in which each of the [n_slots] slots starts out
    holding an empty map. *)
let create () : _ t =
  let fresh_slot (_ : int) = Atomic.make Str_map.empty in
  { entries = Array.init n_slots fresh_slot }
(** 32-bit FNV-1a hash of [s].

    The accumulator starts at the FNV offset basis; for each byte it is XOR-ed
    with the byte and then multiplied by the FNV prime, using [Int32]
    arithmetic (which wraps modulo 2{^32}). The result is converted with
    [Int32.to_int], so it may be negative: mask it before using it as an
    index. *)
let[@inline] hash_string s : int =
  (* 0x811c9dc5 (2166136261) is the 32-bit FNV offset basis. It is larger than
     [Int32.max_int], hence the hexadecimal literal, which denotes the
     intended 32-bit pattern. *)
  let h = ref 0x811c9dc5l in
  for i = 0 to String.length s - 1 do
    (* [i] stays within [0, String.length s - 1] by construction of the loop,
       so the unchecked access is in bounds. *)
    let c = Int32.of_int (Char.code (String.unsafe_get s i)) in
    (* 0x01000193 (16777619) is the 32-bit FNV prime. *)
    h := Int32.(mul (logxor !h c) 0x01000193l)
  done;
  Int32.to_int !h
(** [find_exn self key] reads the slot selected by the hash of [key] and looks
    [key] up in the map found there.
    @raise Not_found if [key] has no binding. *)
let[@inline] find_exn self key =
  let idx = hash_string key land slot_mask in
  Str_map.find key (Atomic.get self.entries.(idx))
(** Slow path of [find_or_add]: bind [k] in [slot] unless it is already bound,
    and return whichever value ends up bound to [k].

    [init] is invoked at most once per call: the value it produced is kept
    across failed compare-and-set attempts and reused on the next attempt
    instead of being rebuilt. If another writer binds [k] in the meantime, that
    writer's value is returned and the locally created one is dropped. *)
let add_ slot k init =
  (* [pending] is the value already produced by [init] in an earlier attempt
     of this same call, if any. *)
  let rec attempt pending =
    let m = Atomic.get slot in
    match Str_map.find k m with
    | v -> v
    | exception Not_found ->
      let v =
        match pending with
        | Some v -> v
        | None -> init ()
      in
      (* The swap only succeeds if [slot] still holds the very map [m] we
         read, i.e. nobody updated the slot in between. *)
      if Atomic.compare_and_set slot m (Str_map.add k v m) then
        v
      else (
        Trace_util.Domain_util.cpu_relax ();
        attempt (Some v)
      )
  in
  attempt None
(** [find_or_add self k init] returns the value bound to [k]. When the first
    lookup finds nothing, insertion is delegated to [add_], which may call
    [init] to build the value. *)
let[@inline] find_or_add self k init =
  let slot = self.entries.(hash_string k land slot_mask) in
  (* Fast path: one atomic read followed by a map lookup. The handler runs
     outside the [try], so exceptions raised while inserting propagate. *)
  try Str_map.find k (Atomic.get slot) with Not_found -> add_ slot k init
(** Option-returning counterpart of [find_exn]: [None] when [k] is unbound. *)
let find self k =
  match find_exn self k with
  | v -> Some v
  | exception Not_found -> None

View file

@ -0,0 +1,17 @@
(** Lock-free thread-safe hash table mapping strings to values.

    Very simple, the goal is to minimize contention. This is append-only. *)

type 'a t
(** Abstract table from [string] keys to values of type ['a]. This signature
    offers no operation to remove or replace a binding. *)

val create : unit -> 'a t
(** Create a new, empty table. *)

val find_or_add : 'a t -> string -> (unit -> 'a) -> 'a
(** Find the value for key, or add it using init function. Thread-safe. Returns
    the same value for same key across all threads. *)

val find_exn : 'a t -> string -> 'a
(** Find the value for key
    @raise Not_found if not present *)

val find : 'a t -> string -> 'a option
(** Same lookup as {!find_exn}, but returns [None] instead of raising when the
    key is not present. *)

View file

@ -0,0 +1,69 @@
(** Aggressive race condition test *)
open Trace_landmarks
(* Spawn one worker per recommended domain, let them all increment counters
   stored in a single shared table, then check that the counters of the shared
   keys add up to the number of shared-key increments actually performed. *)
let () =
  let n_domains = Domain.recommended_domain_count () in
  Printf.printf "Testing with %d cores available\n" n_domains;
  let table = Trace_landmarks.Atomic_tbl.create () in
  let iterations = 500_000 in
  let shared_keys = 10 in
  (* Number of shared-key increments performed, summed over all workers. *)
  let expected_shared = Atomic.make 0 in
  let shared_key i = Printf.sprintf "shared_%d" i in
  (* Call [cpu_relax] [n] times in a row. *)
  let rec spin n =
    if n > 0 then (
      Trace_util.Domain_util.cpu_relax ();
      spin (n - 1)
    )
  in
  let worker domain_id () =
    (* [hits] counts this worker's increments that targeted a shared key. *)
    let rec run i hits =
      if i >= iterations then
        hits
      else (
        let is_shared = i mod 3 = 0 in
        let key =
          if is_shared then
            shared_key (i mod shared_keys)
          else
            Printf.sprintf "domain_%d_key_%d" domain_id (i mod 50)
        in
        Atomic.incr (Atomic_tbl.find_or_add table key (fun () -> Atomic.make 0));
        if i mod 1000 = 0 then spin 50;
        run (i + 1) (if is_shared then hits + 1 else hits)
      )
    in
    let hits = run 0 0 in
    Printf.printf "Domain %d: Completed %d iterations\n%!" domain_id iterations;
    ignore (Atomic.fetch_and_add expected_shared hits : int)
  in
  let t0 = Unix.gettimeofday () in
  List.init n_domains (fun i -> Domain.spawn (worker i))
  |> List.iter Domain.join;
  let elapsed = Unix.gettimeofday () -. t0 in
  Printf.printf "\n=== Results ===\n";
  Printf.printf "elapsed time: %.3f seconds\n" elapsed;
  (* Verify shared keys: print each counter found and sum them up. *)
  let observed_shared =
    let rec sum i acc =
      if i >= shared_keys then
        acc
      else (
        let key = shared_key i in
        match Atomic_tbl.find table key with
        | None -> sum (i + 1) acc
        | Some counter ->
          let count = Atomic.get counter in
          Printf.printf " %s: %d\n" key count;
          sum (i + 1) (acc + count)
      )
    in
    sum 0 0
  in
  Printf.printf "\nShared key increments: %d\n" observed_shared;
  let total_iterations = n_domains * iterations in
  Printf.printf "\n%d iterations in %.3f seconds (%.4f/s)\n" total_iterations
    elapsed
    (float total_iterations /. elapsed);
  let expected = Atomic.get expected_shared in
  if observed_shared = expected then
    Printf.printf "\n✓ Race condition test PASSED!\n"
  else (
    Printf.eprintf "ERROR: Race condition detected! (expected %d, got %d)\n"
      expected observed_shared;
    exit 1
  )

View file

@ -0,0 +1,34 @@
(** Sequential version for comparison *)
open Trace_landmarks
(* Single-domain run: eight simulated "threads" execute one after another
   against the same table, and the wall-clock time of the whole run is
   printed. *)
let () =
  Printf.printf "Running SEQUENTIAL version\n%!";
  let table = Atomic_tbl.create () in
  let iterations = 50000 in
  let shared_keys = 10 in
  (* Every third iteration targets one of the shared keys; the others use a
     key specific to the (thread, iteration) pair. *)
  let key_of thread_id i =
    if i mod 3 <> 0 then
      Printf.sprintf "domain_%d_key_%d" thread_id i
    else
      Printf.sprintf "shared_%d" (i mod shared_keys)
  in
  (* A small computation whose result is thrown away. *)
  let busy_work () =
    ignore (List.fold_left ( + ) 0 (List.init 100 (fun x -> x)) : int)
  in
  let step thread_id i =
    let counter =
      Atomic_tbl.find_or_add table (key_of thread_id i) (fun () ->
          Atomic.make 0)
    in
    Atomic.incr counter;
    if i mod 1000 = 0 then busy_work ()
  in
  let t0 = Unix.gettimeofday () in
  for thread_id = 0 to 7 do
    for i = 0 to iterations - 1 do
      step thread_id i
    done
  done;
  let elapsed = Unix.gettimeofday () -. t0 in
  Printf.printf "Elapsed time: %.3f seconds\n" elapsed;
  Printf.printf "✓ Sequential test completed\n"

10
src/landmarks/benchs/dune Normal file
View file

@ -0,0 +1,10 @@
; Benchmark executable built from atomic_tbl_race.ml (uses Domain and Unix).
(executable
 (name atomic_tbl_race)
 (modules atomic_tbl_race)
 (optional) ; domains
 (libraries trace_landmarks unix))
; Benchmark executable built from atomic_tbl_sequential.ml.
(executable
 (name atomic_tbl_sequential)
 (modules atomic_tbl_sequential)
 ; trace_landmarks is itself an optional library, so this executable must be
 ; skipped (rather than break the build) when that library is unavailable.
 (optional)
 (libraries trace_landmarks unix))

81
src/landmarks/data.ml Normal file
View file

@ -0,0 +1,81 @@
(** Basic data types for Landmarks profiling export *)

type gc_info = {
  minor_words: float;  (** First component of [Gc.counters ()] when built by [get_gc_info] *)
  promoted_words: float;  (** Second component of [Gc.counters ()] when built by [get_gc_info] *)
  major_words: float;  (** Third component of [Gc.counters ()] when built by [get_gc_info] *)
}
(** Basic GC statistics *)
(** Encode a [gc_info] as a JSON object with one float member per record
    field, each member being named after its field. *)
let gc_info_to_yojson (gc : gc_info) : Yojson.Safe.t =
  let { minor_words; promoted_words; major_words } = gc in
  `Assoc
    [
      "minor_words", `Float minor_words;
      "promoted_words", `Float promoted_words;
      "major_words", `Float major_words;
    ]
type timing = {
  start_time: float;  (** Start timestamp (seconds) *)
  end_time: float;  (** End timestamp (seconds) *)
  duration: float;  (** Duration in seconds *)
  cpu_time: float;  (** CPU time in seconds *)
}
(** Timing information.

    NOTE(review): nothing in this type ties [duration] to
    [end_time -. start_time]; confirm that producers keep them consistent. *)
(** Encode a [timing] as a JSON object with one float member per record field,
    each member being named after its field. *)
let timing_to_yojson (t : timing) : Yojson.Safe.t =
  let { start_time; end_time; duration; cpu_time } = t in
  `Assoc
    [
      "start_time", `Float start_time;
      "end_time", `Float end_time;
      "duration", `Float duration;
      "cpu_time", `Float cpu_time;
    ]
type landmark = {
  name: string;  (** Name of the landmark *)
  location: string option;  (** Optional location; left out of the JSON output when [None] *)
  timing: timing;  (** Timing information for this measurement *)
  gc_before: gc_info;  (** Presumably the GC counters sampled before the measured section; confirm with producers *)
  gc_after: gc_info;  (** Presumably the GC counters sampled after the measured section; confirm with producers *)
  call_count: int;  (** Call count, exported as a JSON integer *)
}
(** A single landmark measurement *)
(** Encode a [landmark] as a JSON object. The ["location"] member comes last
    and is present only when the landmark carries a location. *)
let landmark_to_yojson (lm : landmark) : Yojson.Safe.t =
  let location_member =
    match lm.location with
    | Some loc -> [ "location", `String loc ]
    | None -> []
  in
  let mandatory_members =
    [
      "name", `String lm.name;
      "timing", timing_to_yojson lm.timing;
      "gc_before", gc_info_to_yojson lm.gc_before;
      "gc_after", gc_info_to_yojson lm.gc_after;
      "call_count", `Int lm.call_count;
    ]
  in
  `Assoc (mandatory_members @ location_member)
type landmark_collection = {
  landmarks: landmark list;  (** The measurements, exported as a JSON list in this order *)
  total_time: float;  (** Total time; presumably in seconds like [timing], confirm with producers *)
  timestamp: float;  (** Timestamp of the collection; unit and epoch are not fixed by this file *)
}
(** A collection of landmarks *)
(** Encode a [landmark_collection] as a JSON object: the landmarks become a
    JSON list (in list order), the two remaining fields become floats. *)
let landmark_collection_to_yojson (coll : landmark_collection) : Yojson.Safe.t =
  let landmarks_json = List.map landmark_to_yojson coll.landmarks in
  `Assoc
    [
      "landmarks", `List landmarks_json;
      "total_time", `Float coll.total_time;
      "timestamp", `Float coll.timestamp;
    ]
(** Snapshot the current GC counters, as reported by [Gc.counters], into a
    [gc_info] record. *)
let get_gc_info () : gc_info =
  let minor, promoted, major = Gc.counters () in
  { minor_words = minor; promoted_words = promoted; major_words = major }

6
src/landmarks/dune Normal file
View file

@ -0,0 +1,6 @@
; Library exposing the Data and Atomic_tbl modules, published as
; trace.landmarks.
(library
 (name trace_landmarks)
 (public_name trace.landmarks)
 (modules data atomic_tbl)
 (optional) ; mtime
 (libraries trace.util yojson mtime mtime.clock.os))

View file

@ -0,0 +1,29 @@
(** Simple test for Atomic_tbl *)
open Trace_landmarks
(* Single-threaded smoke test: insertion, physical identity of the stored
   values, and lookup of present and absent keys. *)
let () =
  let tbl = Atomic_tbl.create () in
  let lookup_or_insert key initial =
    Atomic_tbl.find_or_add tbl key (fun () -> ref initial)
  in
  (* Insert and verify identity *)
  let foo = lookup_or_insert "foo" 42 in
  let bar = lookup_or_insert "bar" 99 in
  let foo_again = lookup_or_insert "foo" 999 in
  (* Same key returns the very same value; the second initialiser is unused. *)
  assert (foo == foo_again);
  assert (foo != bar);
  assert (!foo = 42);
  assert (!bar = 99);
  assert (!foo_again = 42);
  (* Test find *)
  assert (
    match Atomic_tbl.find tbl "foo" with
    | Some cell -> cell == foo
    | None -> false);
  assert (
    match Atomic_tbl.find tbl "nonexistent" with
    | Some _ -> false
    | None -> true);
  print_endline "all Atomic_tbl tests passed!"

4
test/landmarks/dune Normal file
View file

@ -0,0 +1,4 @@
; Unit test for Atomic_tbl (atomic_tbl_test.ml).
; NOTE(review): trace_landmarks is declared (optional); confirm how this test
; should behave when that library is unavailable.
(test
 (name atomic_tbl_test)
 (modules atomic_tbl_test)
 (libraries trace_landmarks))