diff --git a/README.md b/README.md index c03a7644..e7bd23ee 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ Documentation [here](http://cedeela.fr/~simon/software/containers). - `CCHashconsedSet`, a set structure with sharing of sub-structures - `CCGraph`, a small collection of graph algorithms - `CCBitField`, a type-safe implementation of bitfields that fit in `int` +- `CCWBTree`, a weight-balanced tree, implementing a map interface ### Containers.io diff --git a/_oasis b/_oasis index c50b3f31..2847d801 100644 --- a/_oasis +++ b/_oasis @@ -85,7 +85,7 @@ Library "containers_data" CCPersistentHashtbl, CCDeque, CCFQueue, CCBV, CCMixtbl, CCMixmap, CCRingBuffer, CCIntMap, CCPersistentArray, CCMixset, CCHashconsedSet, CCGraph, CCHashSet, CCBitField, - CCHashTrie, CCBloom + CCHashTrie, CCBloom, CCWBTree BuildDepends: bytes FindlibParent: containers FindlibName: data diff --git a/benchs/run_benchs.ml b/benchs/run_benchs.ml index 52795f40..e10646d2 100644 --- a/benchs/run_benchs.ml +++ b/benchs/run_benchs.ml @@ -287,6 +287,17 @@ module Tbl = struct let module U = MUT_OF_IMMUT(T) in (module U : MUT with type key = a) + let wbt : type a. a key_type -> (module MUT with type key = a) + = fun k -> + let (module K), name = arg_make k in + let module T = struct + let name = sprintf "wbt(%s)" name + include CCWBTree.Make(K) + let find = get_exn + end in + let module U = MUT_OF_IMMUT(T) in + (module U : MUT with type key = a) + let flat_hashtbl = let module T = CCFlatHashtbl.Make(CCInt) in let module U = struct @@ -328,6 +339,7 @@ module Tbl = struct ; persistent_hashtbl ; poly_hashtbl ; map Int + ; wbt Int ; flat_hashtbl ; hashtrie Int ; hamt Int @@ -336,6 +348,7 @@ module Tbl = struct let modules_string = [ hashtbl_make Str ; map Str + ; wbt Str ; hashtrie Str ; hamt Str ] diff --git a/doc/intro.txt b/doc/intro.txt index 16b9db22..0cfd1dbf 100644 --- a/doc/intro.txt +++ b/doc/intro.txt @@ -81,6 +81,7 @@ CCPersistentArray CCPersistentHashtbl CCRingBuffer CCTrie +CCWBTree } {4 Containers.io} diff --git a/src/data/CCWBTree.ml b/src/data/CCWBTree.ml new file mode 100644 index 00000000..fe4cd5f5 --- /dev/null +++ b/src/data/CCWBTree.ml @@ -0,0 +1,332 @@ +(* This file is free software, part of containers. See file "license" for more details. *) + +(** {1 Weight-Balanced Tree} *) + +type 'a sequence = ('a -> unit) -> unit +type 'a gen = unit -> 'a option +type 'a printer = Format.formatter -> 'a -> unit + +module type ORD = sig + type t + val compare : t -> t -> int +end + +module type KEY = sig + include ORD + val weight : t -> int +end + +(** {2 Signature} *) + +module type S = sig + type key + + type 'a t + + val empty : 'a t + + val mem : key -> _ t -> bool + + val get : key -> 'a t -> 'a option + + val get_exn : key -> 'a t -> 'a + (** @raise Not_found if the key is not present *) + + val nth : int -> 'a t -> (key * 'a) option + (** [nth i m] returns the [i]-th [key, value] in the ascending + order. Complexity is [O(log (cardinal m))] *) + + val nth_exn : int -> 'a t -> key * 'a + (** @raise Not_found if the index is invalid *) + + val add : key -> 'a -> 'a t -> 'a t + + val remove : key -> 'a t -> 'a t + + val cardinal : _ t -> int + + val weight : _ t -> int + + val fold : ('b -> key -> 'a -> 'b) -> 'b -> 'a t -> 'b + + val iter : (key -> 'a -> unit) -> 'a t -> unit + + val choose : 'a t -> (key * 'a) option + + val choose_exn : 'a t -> key * 'a + (** @raise Not_found if the tree is empty *) + + val random_choose : Random.State.t -> 'a t -> key * 'a + (** Randomly choose a (key,value) pair within the tree, using weights + as probability weights + @raise Not_found if the tree is empty *) + + val add_list : 'a t -> (key * 'a) list -> 'a t + + val of_list : (key * 'a) list -> 'a t + + val to_list : 'a t -> (key * 'a) list + + val add_seq : 'a t -> (key * 'a) sequence -> 'a t + + val of_seq : (key * 'a) sequence -> 'a t + + val to_seq : 'a t -> (key * 'a) sequence + + val add_gen : 'a t -> (key * 'a) gen -> 'a t + + val of_gen : (key * 'a) gen -> 'a t + + val to_gen : 'a t -> (key * 'a) gen + + val print : key printer -> 'a printer -> 'a t printer + + (**/**) + val balanced : _ t -> bool + (**/**) +end + +module MakeFull(K : KEY) : S with type key = K.t = struct + type key = K.t + + type weight = int + + type 'a t = + | E + | N of key * 'a * 'a t * 'a t * weight + + let empty = E + + let rec get_exn k m = match m with + | E -> raise Not_found + | N (k', v, l, r, _) -> + match K.compare k k' with + | 0 -> v + | n when n<0 -> get_exn k l + | _ -> get_exn k r + + let get k m = + try Some (get_exn k m) + with Not_found -> None + + let mem k m = + try ignore (get_exn k m); true + with Not_found -> false + + let singleton k v = + N (k, v, E, E, K.weight k) + + let weight = function + | E -> 0 + | N (_, _, _, _, w) -> w + + (* balancing parameters *) + + (* delta=5/2 + delta × (weight l + 1) ≥ weight r + 1 + *) + let is_balanced l r = + 5 * (weight l + 1) >= (weight r + 1) * 2 + + (* gamma = 3/2 + weight l + 1 < gamma × (weight r + 1) *) + let is_single l r = + 2 * (weight l + 1) < 3 * (weight r + 1) + + (* debug function *) + let rec balanced = function + | E -> true + | N (_, _, l, r, _) -> + is_balanced l r && + is_balanced r l && + balanced l && + balanced r + + (* smart constructor *) + let mk_node_ k v l r = + N (k, v, l, r, weight l + weight r + K.weight k) + + let single_l k1 v1 t1 t2 = match t2 with + | E -> assert false + | N (k2, v2, t2, t3, _) -> + mk_node_ k2 v2 (mk_node_ k1 v1 t1 t2) t3 + + let double_l k1 v1 t1 t2 = match t2 with + | N (k2, v2, N (k3, v3, t2, t3, _), t4, _) -> + mk_node_ k3 v3 (mk_node_ k1 v1 t1 t2) (mk_node_ k2 v2 t3 t4) + | _ -> assert false + + let rotate_l k v l r = match r with + | E -> assert false + | N (_, _, rl, rr, _) -> + if is_single rl rr + then single_l k v l r + else double_l k v l r + + (* balance towards left *) + let balance_l k v l r = + if is_balanced l r then mk_node_ k v l r + else rotate_l k v l r + + let single_r k1 v1 t1 t2 = match t1 with + | E -> assert false + | N (k2, v2, t11, t12, _) -> + mk_node_ k2 v2 t11 (mk_node_ k1 v1 t12 t2) + + let double_r k1 v1 t1 t2 = match t1 with + | N (k2, v2, t11, N (k3, v3, t121, t122, _), _) -> + mk_node_ k3 v3 (mk_node_ k2 v2 t11 t121) (mk_node_ k1 v1 t122 t2) + | _ -> assert false + + let rotate_r k v l r = match l with + | E -> assert false + | N (_, _, ll, lr, _) -> + if is_single lr ll + then single_r k v l r + else double_r k v l r + + (* balance toward right *) + let balance_r k v l r = + if is_balanced r l then mk_node_ k v l r + else rotate_r k v l r + + let rec add k v m = match m with + | E -> singleton k v + | N (k', v', l, r, _) -> + match K.compare k k' with + | 0 -> mk_node_ k v l r + | n when n<0 -> balance_r k' v' (add k v l) r + | _ -> balance_l k' v' l (add k v r) + + (*$Q & ~small:List.length + Q.(list (pair small_int bool)) (fun l -> \ + let module M = Make(CCInt) in \ + let m = M.of_list l in \ + M.balanced m) + Q.(list (pair small_int small_int)) (fun l -> \ + let l = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) l in \ + let module M = Make(CCInt) in \ + let m = M.of_list l in \ + List.for_all (fun (k,v) -> M.get_exn k m = v) l) + Q.(list (pair small_int small_int)) (fun l -> \ + let l = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) l in \ + let module M = Make(CCInt) in \ + let m = M.of_list l in \ + M.cardinal m = List.length l) + *) + + let rec remove k m = match m with + | E -> E + | N (k', v', l, r, _) -> + match K.compare k k' with + | 0 -> assert false (* TODO fix using a paper *) + | n when n<0 -> balance_l k' v' (remove k l) r + | _ -> balance_r k' v' l (remove k r) + + (* TODO union, intersection *) + + let rec nth_exn i m = match m with + | E -> raise Not_found + | N (k, v, l, r, w) -> + let c = i - weight l in + match c with + | 0 -> k, v + | n when n<0 -> nth_exn i l (* search left *) + | _ -> + (* means c< K.weight k *) + if i None + + (*$T + let module M = Make(CCInt) in \ + let m = CCList.(0 -- 1000 |> map (fun i->i,i) |> M.of_list) in \ + List.for_all (fun i -> M.nth_exn i m = (i,i)) CCList.(0--1000) + *) + + let rec fold f acc m = match m with + | E -> acc + | N (k, v, l, r, _) -> + let acc = fold f acc l in + let acc = f acc k v in + fold f acc r + + let rec iter f m = match m with + | E -> () + | N (k, v, l, r, _) -> + iter f l; + f k v; + iter f r + + let choose_exn = function + | E -> raise Not_found + | N (k, v, _, _, _) -> k, v + + let choose = function + | E -> None + | N (k, v, _, _, _) -> Some (k,v) + + (* pick an index within [0.. weight m-1] and get the element with + this index *) + let random_choose st m = + let w = weight m in + if w=0 then raise Not_found; + nth_exn (Random.State.int st w) m + + let cardinal m = fold (fun acc _ _ -> acc+1) 0 m + + let add_list m l = List.fold_left (fun acc (k,v) -> add k v acc) m l + + let of_list l = add_list empty l + + let to_list m = fold (fun acc k v -> (k,v) :: acc) [] m + + let add_seq m seq = + let m = ref m in + seq (fun (k,v) -> m := add k v !m); + !m + + let of_seq s = add_seq empty s + + let to_seq m yield = iter (fun k v -> yield (k,v)) m + + let rec add_gen m g = match g() with + | None -> m + | Some (k,v) -> add_gen (add k v m) g + + let of_gen g = add_gen empty g + + let to_gen m = + let st = Stack.create () in + Stack.push m st; + let rec next() = + if Stack.is_empty st then None + else match Stack.pop st with + | E -> next () + | N (k, v, l, r, _) -> + Stack.push r st; + Stack.push l st; + Some (k,v) + in next + + let print pp_k pp_v fmt m = + let start = "[" and stop = "]" and arrow = "->" and sep = ","in + Format.pp_print_string fmt start; + let first = ref true in + iter + (fun k v -> + if !first then first := false else Format.pp_print_string fmt sep; + pp_k fmt k; + Format.pp_print_string fmt arrow; + pp_v fmt v; + Format.pp_print_cut fmt () + ) m; + Format.pp_print_string fmt stop +end + +module Make(X : ORD) = MakeFull(struct + include X + let weight _ = 1 +end) diff --git a/src/data/CCWBTree.mli b/src/data/CCWBTree.mli new file mode 100644 index 00000000..87eb975e --- /dev/null +++ b/src/data/CCWBTree.mli @@ -0,0 +1,99 @@ + +(* This file is free software, part of containers. See file "license" for more details. *) + +(** {1 Weight-Balanced Tree} + + {b status: experimental} + + @since NEXT_RELEASE *) + +type 'a sequence = ('a -> unit) -> unit +type 'a gen = unit -> 'a option +type 'a printer = Format.formatter -> 'a -> unit + +module type ORD = sig + type t + val compare : t -> t -> int +end + +module type KEY = sig + include ORD + val weight : t -> int +end + +(** {2 Signature} *) + +module type S = sig + type key + + type 'a t + + val empty : 'a t + + val mem : key -> _ t -> bool + + val get : key -> 'a t -> 'a option + + val get_exn : key -> 'a t -> 'a + (** @raise Not_found if the key is not present *) + + val nth : int -> 'a t -> (key * 'a) option + (** [nth i m] returns the [i]-th [key, value] in the ascending + order. Complexity is [O(log (cardinal m))] *) + + val nth_exn : int -> 'a t -> key * 'a + (** @raise Not_found if the index is invalid *) + + val add : key -> 'a -> 'a t -> 'a t + + val remove : key -> 'a t -> 'a t + + val cardinal : _ t -> int + + val weight : _ t -> int + + val fold : ('b -> key -> 'a -> 'b) -> 'b -> 'a t -> 'b + + val iter : (key -> 'a -> unit) -> 'a t -> unit + + val choose : 'a t -> (key * 'a) option + + val choose_exn : 'a t -> key * 'a + (** @raise Not_found if the tree is empty *) + + val random_choose : Random.State.t -> 'a t -> key * 'a + (** Randomly choose a (key,value) pair within the tree, using weights + as probability weights + @raise Not_found if the tree is empty *) + + val add_list : 'a t -> (key * 'a) list -> 'a t + + val of_list : (key * 'a) list -> 'a t + + val to_list : 'a t -> (key * 'a) list + + val add_seq : 'a t -> (key * 'a) sequence -> 'a t + + val of_seq : (key * 'a) sequence -> 'a t + + val to_seq : 'a t -> (key * 'a) sequence + + val add_gen : 'a t -> (key * 'a) gen -> 'a t + + val of_gen : (key * 'a) gen -> 'a t + + val to_gen : 'a t -> (key * 'a) gen + + val print : key printer -> 'a printer -> 'a t printer + + (**/**) + val balanced : _ t -> bool + (**/**) +end + +(** {2 Functor} *) + +module Make(X : ORD) : S with type key = X.t + +module MakeFull(X : KEY) : S with type key = X.t +(** Use the custom [X.weight] function *)