diff --git a/src/data/CCHashTrie.ml b/src/data/CCHashTrie.ml index ad1d6c8f..337f268c 100644 --- a/src/data/CCHashTrie.ml +++ b/src/data/CCHashTrie.ml @@ -9,7 +9,7 @@ type 'a ktree = unit -> [`Nil | `Node of 'a * 'a ktree list] (** {2 Fixed-Size Arrays} *) module type FIXED_ARRAY = sig - type 'a t + type +'a t val create : 'a -> 'a t val length_log : int val length : int (* 2 power length_log *) @@ -68,29 +68,46 @@ end (** {2 Arrays} *) module A32 : FIXED_ARRAY = struct - type 'a t = 'a array + type +'a t = { dummy1: 'a; dummy2 : 'a } (* used for variance only *) + + (* NOTE for safety: + + the array and the record are both boxed types, in the heap + (since it has two fields it should not change in the future). + + using an array as covariant is safe because we ALWAYS copy before writing, + so we cannot put a wrong value in [a] by upcasting it and writing. + *) + + external hide_array_ : 'a array -> 'a t = "%identity" + external get_array_ : 'a t -> 'a array = "%identity" let length_log = 5 let length = 1 lsl length_log (* 32 *) - let create x = Array.make length x + let create x = hide_array_ (Array.make length x) - let get a i = a.(i) + let get a i = Array.get (get_array_ a) i let set a i x = - let a' = Array.copy a in + let a' = Array.copy (get_array_ a) in a'.(i) <- x; - a' + hide_array_ a' let update a i f = - let x = a.(i) in - let y = f a.(i) in - if x==y then a else set a i y + let x = Array.get (get_array_ a) i in + let y = f x in + if x==y then a + else ( + let a' = Array.copy (get_array_ a) in + a'.(i) <- y; + hide_array_ a' + ) - let iter = Array.iter + let iter f a = Array.iter f (get_array_ a) - let fold = Array.fold_left + let fold f acc a = Array.fold_left f acc (get_array_ a) end (** {2 Functors} *) @@ -105,11 +122,15 @@ module Make(Key : KEY) module Hash : sig type t = private int val make : Key.t -> t + val zero : t (* special "hash" *) + val is_0 : t -> bool val rem : t -> int (* [A.length_log] last bits *) val quotient : t -> t (* remove [A.length_log] last bits *) end = struct type t = int let make = Key.hash + let zero = 0 + let is_0 h = h==0 let rem h = h land (A.length - 1) let quotient h = h lsr A.length_log end @@ -126,13 +147,20 @@ module Make(Key : KEY) type 'a t = | E | L of Hash.t * 'a leaf (* same hash for all elements *) - | N of 'a t A.t + | N of 'a leaf * 'a t A.t (* leaf for hash=0, subnodes *) (* invariants: L [] --> E N [E, E,...., E] -> E *) + (* NOTE for safety: + + only allocate one empty array. It will contain only [E] for every + different value type + *) + let empty_arr_ = A.create E + let empty = E let is_empty = function @@ -153,10 +181,12 @@ module Make(Key : KEY) let rec get_exn_ k ~h m = match m with | E -> raise Not_found | L (_, l) -> get_exn_list_ k l - | N a -> - let i = Hash.rem h in - let h' = Hash.quotient h in - get_exn_ k ~h:h' (A.get a i) + | N (leaf, a) -> + if Hash.is_0 h then get_exn_list_ k leaf + else + let i = Hash.rem h in + let h' = Hash.quotient h in + get_exn_ k ~h:h' (A.get a i) let get_exn k m = get_exn_ k ~h:(hash_ k) m @@ -173,15 +203,24 @@ module Make(Key : KEY) if h=h' then L (h, add_list_ k v ~h l) else (* split into N *) - let a = A.create E in - (* put leaf in the right bucket *) - let i = Hash.rem h' in - let h'' = Hash.quotient h' in - let a = A.set a i (L (h'', l)) in + let a = empty_arr_ in + let a, leaf = + if Hash.is_0 h' then a, l + else + (* put leaf in the right bucket *) + let i = Hash.rem h' in + let h'' = Hash.quotient h' in + A.set a i (L (h'', l)), Nil + in (* then add new node *) - let a = add_to_array_ k v ~h a in - N a - | N a -> N (add_to_array_ k v ~h a) + let a, leaf = + if Hash.is_0 h then a, add_list_ k v ~h leaf + else add_to_array_ k v ~h a, leaf + in + N (leaf, a) + | N (leaf, a) -> + if Hash.is_0 h then N (add_list_ k v ~h leaf, a) + else N (leaf, add_to_array_ k v ~h a) (* [left] list nodes already visited *) and add_list_ k v ~h l = match l with @@ -208,6 +247,10 @@ module Make(Key : KEY) true with LocalExit -> false + let is_empty_list_ = function + | Nil -> true + | Cons _ -> false + let rec remove_list_ k l = match l with | Nil -> Nil | Cons (k', v', tail) -> @@ -218,17 +261,20 @@ module Make(Key : KEY) let rec remove_rec_ k ~h m = match m with | E -> E | L (h, l) -> - begin match remove_list_ k l with - | Nil -> E - | Cons _ as res -> L (h, res) - end - | N a -> - let i = Hash.rem h in - let h' = Hash.quotient h in - let a' = A.set a i (remove_rec_ k ~h:h' (A.get a i)) in - if is_empty_arr_ a' + let l = remove_list_ k l in + if is_empty_list_ l then E else L (h, l) + | N (leaf, a) -> + let leaf, a = + if Hash.is_0 h + then remove_list_ k leaf, a + else + let i = Hash.rem h in + let h' = Hash.quotient h in + leaf, A.set a i (remove_rec_ k ~h:h' (A.get a i)) + in + if is_empty_list_ leaf && is_empty_arr_ a then E - else N a' + else N (leaf, a) let remove k m = remove_rec_ k ~h:(hash_ k) m @@ -236,7 +282,7 @@ module Make(Key : KEY) let rec aux = function | E -> () | L (_,l) -> aux_list l - | N a -> A.iter aux a + | N (l,a) -> aux_list l; A.iter aux a and aux_list = function | Nil -> () | Cons (k, v, tl) -> f k v; aux_list tl @@ -247,7 +293,7 @@ module Make(Key : KEY) let rec aux acc t = match t with | E -> acc | L (_,l) -> aux_list acc l - | N a -> A.fold aux acc a + | N (l,a) -> let acc = aux_list acc l in A.fold aux acc a and aux_list acc l = match l with | Nil -> acc | Cons (k, v, tl) -> let acc = f acc k v in aux_list acc tl @@ -275,7 +321,7 @@ module Make(Key : KEY) let rec as_tree m () = match m with | E -> `Nil | L (h,l) -> `Node (`L ((h:>int), list_as_tree_ l), []) - | N a -> `Node (`N, array_as_tree_ a) + | N (l,a) -> `Node (`N, as_tree (L (Hash.zero, l)) :: array_as_tree_ a) and list_as_tree_ l = match l with | Nil -> [] | Cons (k, v, tail) -> (k,v) :: list_as_tree_ tail diff --git a/src/data/CCHashTrie.mli b/src/data/CCHashTrie.mli index 0082cd07..eb621e72 100644 --- a/src/data/CCHashTrie.mli +++ b/src/data/CCHashTrie.mli @@ -19,7 +19,7 @@ type 'a ktree = unit -> [`Nil | `Node of 'a * 'a ktree list] (** {2 Fixed-Size Arrays} *) module type FIXED_ARRAY = sig - type 'a t + type +'a t val create : 'a -> 'a t val length_log : int val length : int (* 2 power length_log *)