add more functions to CCHashTrie

This commit is contained in:
Simon Cruanes 2015-09-05 12:35:13 +02:00
parent c9a4bbd75a
commit 6f388b5d3c
3 changed files with 97 additions and 18 deletions

View file

@ -4,6 +4,7 @@
(** {1 Hash Tries} *) (** {1 Hash Tries} *)
type 'a sequence = ('a -> unit) -> unit type 'a sequence = ('a -> unit) -> unit
type 'a gen = unit -> 'a option
type 'a printer = Format.formatter -> 'a -> unit type 'a printer = Format.formatter -> 'a -> unit
type 'a ktree = unit -> [`Nil | `Node of 'a * 'a ktree list] type 'a ktree = unit -> [`Nil | `Node of 'a * 'a ktree list]
@ -15,14 +16,13 @@ module type FIXED_ARRAY = sig
val length : int (* 2 power length_log *) val length : int (* 2 power length_log *)
val get : 'a t -> int -> 'a val get : 'a t -> int -> 'a
val set : 'a t -> int -> 'a -> 'a t val set : 'a t -> int -> 'a -> 'a t
val update : 'a t -> int -> ('a -> 'a) -> 'a t val set : mut:bool -> 'a t -> int -> 'a -> 'a t
val update : mut:bool -> 'a t -> int -> ('a -> 'a) -> 'a t
val remove : empty:'a -> 'a t -> int -> 'a t (* put back [empty] there *) val remove : empty:'a -> 'a t -> int -> 'a t (* put back [empty] there *)
val iter : ('a -> unit) -> 'a t -> unit val iter : ('a -> unit) -> 'a t -> unit
val fold : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b val fold : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b
end end
(* TODO: add update again, to call popcount only once *)
module type S = sig module type S = sig
module A : FIXED_ARRAY module A : FIXED_ARRAY
@ -47,8 +47,18 @@ module type S = sig
val remove : key -> 'a t -> 'a t val remove : key -> 'a t -> 'a t
val update : key -> ('a option -> 'a option) -> 'a t -> 'a t
(** [update k f m] calls [f (Some v)] if [get k m = Some v], [f None]
otherwise. Then, if [f] returns [Some v'] it binds [k] to [v'],
if [f] returns [None] it removes [k] *)
val cardinal : _ t -> int val cardinal : _ t -> int
val choose : 'a t -> (key * 'a) option
val choose_exn : 'a t -> key * 'a
(** @raise Not_found if not pair was found *)
val iter : (key -> 'a -> unit) -> 'a t -> unit val iter : (key -> 'a -> unit) -> 'a t -> unit
val fold : ('b -> key -> 'a -> 'b) -> 'b -> 'a t -> 'b val fold : ('b -> key -> 'a -> 'b) -> 'b -> 'a t -> 'b
@ -67,6 +77,12 @@ module type S = sig
val to_seq : 'a t -> (key * 'a) sequence val to_seq : 'a t -> (key * 'a) sequence
val add_gen : 'a t -> (key * 'a) gen -> 'a t
val of_gen : (key * 'a) gen -> 'a t
val to_gen : 'a t -> (key * 'a) gen
(** {6 IO} *) (** {6 IO} *)
val print : key printer -> 'a printer -> 'a t printer val print : key printer -> 'a printer -> 'a t printer
@ -97,12 +113,12 @@ module A32 : FIXED_ARRAY = struct
let get a i = Array.get a i let get a i = Array.get a i
let set a i x = let set ~mut a i x =
let a' = Array.copy a in let a' = if mut then a else Array.copy a in
a'.(i) <- x; a'.(i) <- x;
a' a'
let update a i f = set a i (f (get a i)) let update ~mut a i f = set ~mut a i (f (get a i))
let remove ~empty a i = let remove ~empty a i =
let a' = Array.copy a in let a' = Array.copy a in
@ -176,7 +192,7 @@ module A_SPARSE : FIXED_ARRAY = struct
let real_idx = popcount (a.bits land (idx- 1)) in let real_idx = popcount (a.bits land (idx- 1)) in
a.arr.(real_idx) a.arr.(real_idx)
let set a i x = let set ~mut a i x =
let idx = 1 lsl i in let idx = 1 lsl i in
let real_idx = popcount (a.bits land (idx -1)) in let real_idx = popcount (a.bits land (idx -1)) in
if a.bits land idx = 0 if a.bits land idx = 0
@ -193,12 +209,12 @@ module A_SPARSE : FIXED_ARRAY = struct
{a with bits; arr} {a with bits; arr}
) else ( ) else (
(* replace element at [real_idx] *) (* replace element at [real_idx] *)
let arr = Array.copy a.arr in let arr = if mut then a.arr else Array.copy a.arr in
arr.(real_idx) <- x; arr.(real_idx) <- x;
{a with arr} {a with arr}
) )
let update a i f = let update ~mut a i f =
let idx = 1 lsl i in let idx = 1 lsl i in
let real_idx = popcount (a.bits land (idx -1)) in let real_idx = popcount (a.bits land (idx -1)) in
if a.bits land idx = 0 if a.bits land idx = 0
@ -218,7 +234,7 @@ module A_SPARSE : FIXED_ARRAY = struct
) else ( ) else (
let x = f a.arr.(real_idx) in let x = f a.arr.(real_idx) in
(* replace element at [real_idx] *) (* replace element at [real_idx] *)
let arr = Array.copy a.arr in let arr = if mut then a.arr else Array.copy a.arr in
arr.(real_idx) <- x; arr.(real_idx) <- x;
{a with arr} {a with arr}
) )
@ -357,7 +373,7 @@ module Make(Key : KEY)
| N (leaf, a) -> | N (leaf, a) ->
if Hash.is_0 h if Hash.is_0 h
then N (add_list_ k v leaf, a) then N (add_list_ k v leaf, a)
else N (leaf, add_to_array_ k v ~h a) else N (leaf, add_to_array_ ~mut:false k v ~h a)
(* make an array containing a leaf, and insert (k,v) in it *) (* make an array containing a leaf, and insert (k,v) in it *)
and make_array_ ~leaf ~h_leaf:h' k v ~h = and make_array_ ~leaf ~h_leaf:h' k v ~h =
@ -368,21 +384,21 @@ module Make(Key : KEY)
(* put leaf in the right bucket *) (* put leaf in the right bucket *)
let i = Hash.rem h' in let i = Hash.rem h' in
let h'' = Hash.quotient h' in let h'' = Hash.quotient h' in
A.set a i (L (h'', leaf)), Nil A.set ~mut:true a i (L (h'', leaf)), Nil
in in
(* then add new node *) (* then add new node *)
let a, leaf = let a, leaf =
if Hash.is_0 h then a, add_list_ k v leaf if Hash.is_0 h then a, add_list_ k v leaf
else add_to_array_ k v ~h a, leaf else add_to_array_ ~mut:true k v ~h a, leaf
in in
N (leaf, a) N (leaf, a)
(* add k->v to [a] *) (* add k->v to [a] *)
and add_to_array_ k v ~h a = and add_to_array_ ~mut k v ~h a =
(* insert in a bucket *) (* insert in a bucket *)
let i = Hash.rem h in let i = Hash.rem h in
let h' = Hash.quotient h in let h' = Hash.quotient h in
A.update a i (fun x -> add_ k v ~h:h' x) A.update ~mut a i (fun x -> add_ k v ~h:h' x)
let add k v m = add_ k v ~h:(hash_ k) m let add k v m = add_ k v ~h:(hash_ k) m
@ -422,7 +438,7 @@ module Make(Key : KEY)
let new_t = remove_rec_ k ~h:h' (A.get a i) in let new_t = remove_rec_ k ~h:h' (A.get a i) in
if is_empty new_t if is_empty new_t
then leaf, A.remove ~empty:E a i (* remove sub-tree *) then leaf, A.remove ~empty:E a i (* remove sub-tree *)
else leaf, A.set a i new_t else leaf, A.set ~mut:false a i new_t
in in
if is_empty_list_ leaf && is_empty_arr_ a if is_empty_list_ leaf && is_empty_arr_ a
then E then E
@ -430,6 +446,15 @@ module Make(Key : KEY)
let remove k m = remove_rec_ k ~h:(hash_ k) m let remove k m = remove_rec_ k ~h:(hash_ k) m
let update k f m =
let h = hash_ k in
let opt_v = try Some (get_exn_ k ~h m) with Not_found -> None in
match opt_v, f opt_v with
| None, None -> m
| Some _, Some v
| None, Some v -> add_ k v ~h m
| Some _, None -> remove_rec_ k ~h m
let iter f t = let iter f t =
let rec aux = function let rec aux = function
| E -> () | E -> ()
@ -471,6 +496,42 @@ module Make(Key : KEY)
let to_seq m yield = iter (fun k v -> yield (k,v)) m let to_seq m yield = iter (fun k v -> yield (k,v)) m
let rec add_gen m g = match g() with
| None -> m
| Some (k,v) -> add_gen (add k v m) g
let of_gen g = add_gen empty g
(* traverse the tree by increasing hash order, where the order compares
hashes lexicographically by A.length_log-wide chunks of bits,
least-significant chunks first *)
let to_gen m =
let st = Stack.create() in
Stack.push m st;
let rec next() =
if Stack.is_empty st then None
else match Stack.pop st with
| E -> next ()
| S (_,k,v) -> Some (k,v)
| L (_, Nil) -> next()
| L (h, Cons(k,v,tl)) ->
Stack.push (L (h, tl)) st; (* tail *)
Some (k,v)
| N (l, a) ->
A.iter
(fun sub -> Stack.push sub st)
a;
Stack.push (L (Hash.zero, l)) st; (* leaf *)
next()
in
next
let choose m = to_gen m ()
let choose_exn m = match choose m with
| None -> raise Not_found
| Some (k,v) -> k, v
let print ppk ppv out m = let print ppk ppv out m =
let first = ref true in let first = ref true in
iter iter

View file

@ -17,6 +17,7 @@
*) *)
type 'a sequence = ('a -> unit) -> unit type 'a sequence = ('a -> unit) -> unit
type 'a gen = unit -> 'a option
type 'a printer = Format.formatter -> 'a -> unit type 'a printer = Format.formatter -> 'a -> unit
type 'a ktree = unit -> [`Nil | `Node of 'a * 'a ktree list] type 'a ktree = unit -> [`Nil | `Node of 'a * 'a ktree list]
@ -27,8 +28,8 @@ module type FIXED_ARRAY = sig
val length_log : int val length_log : int
val length : int (* 2 power length_log *) val length : int (* 2 power length_log *)
val get : 'a t -> int -> 'a val get : 'a t -> int -> 'a
val set : 'a t -> int -> 'a -> 'a t val set : mut:bool -> 'a t -> int -> 'a -> 'a t
val update : 'a t -> int -> ('a -> 'a) -> 'a t val update : mut:bool -> 'a t -> int -> ('a -> 'a) -> 'a t
val remove : empty:'a -> 'a t -> int -> 'a t (* put back [empty] there *) val remove : empty:'a -> 'a t -> int -> 'a t (* put back [empty] there *)
val iter : ('a -> unit) -> 'a t -> unit val iter : ('a -> unit) -> 'a t -> unit
val fold : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b val fold : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b
@ -59,8 +60,18 @@ module type S = sig
val remove : key -> 'a t -> 'a t val remove : key -> 'a t -> 'a t
val update : key -> ('a option -> 'a option) -> 'a t -> 'a t
(** [update k f m] calls [f (Some v)] if [get k m = Some v], [f None]
otherwise. Then, if [f] returns [Some v'] it binds [k] to [v'],
if [f] returns [None] it removes [k] *)
val cardinal : _ t -> int val cardinal : _ t -> int
val choose : 'a t -> (key * 'a) option
val choose_exn : 'a t -> key * 'a
(** @raise Not_found if not pair was found *)
val iter : (key -> 'a -> unit) -> 'a t -> unit val iter : (key -> 'a -> unit) -> 'a t -> unit
val fold : ('b -> key -> 'a -> 'b) -> 'b -> 'a t -> 'b val fold : ('b -> key -> 'a -> 'b) -> 'b -> 'a t -> 'b
@ -79,6 +90,12 @@ module type S = sig
val to_seq : 'a t -> (key * 'a) sequence val to_seq : 'a t -> (key * 'a) sequence
val add_gen : 'a t -> (key * 'a) gen -> 'a t
val of_gen : (key * 'a) gen -> 'a t
val to_gen : 'a t -> (key * 'a) gen
(** {6 IO} *) (** {6 IO} *)
val print : key printer -> 'a printer -> 'a t printer val print : key printer -> 'a printer -> 'a t printer

View file

@ -70,6 +70,7 @@ val fold : (int -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
val choose : 'a t -> (int * 'a) option val choose : 'a t -> (int * 'a) option
val choose_exn : 'a t -> int * 'a val choose_exn : 'a t -> int * 'a
(** @raise Not_found if not pair was found *)
val union : (int -> 'a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t val union : (int -> 'a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t