From 3d035e05cdc9dbf8fa95676ee0a2164e84c804e9 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 9 Sep 2015 23:13:56 +0200 Subject: [PATCH] wip: fix `CCWBTree.{split,merge}`; add tests --- src/data/CCWBTree.ml | 95 ++++++++++++++++++++++++++++++++++--------- src/data/CCWBTree.mli | 11 ++++- 2 files changed, 86 insertions(+), 20 deletions(-) diff --git a/src/data/CCWBTree.ml b/src/data/CCWBTree.ml index 99aee13b..b4fbda61 100644 --- a/src/data/CCWBTree.ml +++ b/src/data/CCWBTree.ml @@ -68,7 +68,15 @@ module type S = sig val merge : (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t (** Similar to {!Map.S.merge} *) - (* TODO: compare, equal *) + val extract_min : 'a t -> key * 'a * 'a t + (** [extract_min m] returns [k, v, m'] where [k,v] is the pair with the + smaller key in [m], and [m'] does not contain [k]. + @raise Not_found if the map is empty *) + + val extract_max : 'a t -> key * 'a * 'a t + (** [extract_max m] returns [k, v, m'] where [k,v] is the pair with the + highest key in [m], and [m'] does not contain [k]. + @raise Not_found if the map is empty *) val choose : 'a t -> (key * 'a) option @@ -101,6 +109,7 @@ module type S = sig val print : key printer -> 'a printer -> 'a t printer (**/**) + val node_ : key -> 'a -> 'a t -> 'a t -> 'a t val balanced : _ t -> bool (**/**) end @@ -152,7 +161,7 @@ module MakeFull(K : KEY) : S with type key = K.t = struct delta × (weight l + 1) ≥ weight r + 1 *) let is_balanced l r = - 5 * (weight l + 1) >= (weight r + 1) * 2 + 5 * (weight l + 1) >= 2 * (weight r + 1) (* gamma = 3/2 weight l + 1 < gamma × (weight r + 1) *) @@ -242,19 +251,19 @@ module MakeFull(K : KEY) : S with type key = K.t = struct *) (* extract min binding of the tree *) - let rec extract_min_ m = match m with + let rec extract_min m = match m with | E -> assert false | N (k, v, E, r, _) -> k, v, r | N (k, v, l, r, _) -> - let k', v', l' = extract_min_ l in + let k', v', l' = extract_min l in k', v', balance_l k v l' r (* extract max binding of the tree *) - let rec extract_max_ m = match m with + let rec extract_max m = match m with | E -> assert false | N (k, v, l, E, _) -> k, v, l | N (k, v, l, r, _) -> - let k', v', r' = extract_max_ r in + let k', v', r' = extract_max r in k', v', balance_r k v l r' let rec remove k m = match m with @@ -271,11 +280,11 @@ module MakeFull(K : KEY) : S with type key = K.t = struct then (* remove max element of [l] and put it at the root, then rebalance towards the left if needed *) - let k', v', l' = extract_max_ l in + let k', v', l' = extract_max l in balance_l k' v' l' r else (* remove min element of [r] and rebalance *) - let k', v', r' = extract_min_ r in + let k', v', r' = extract_min r in balance_r k' v' l r' end | n when n<0 -> balance_l k' v' (remove k l) r @@ -300,8 +309,6 @@ module MakeFull(K : KEY) : S with type key = K.t = struct | Some _, None -> remove k m | _, Some v -> add k v m - (* TODO union, intersection *) - let rec nth_exn i m = match m with | E -> raise Not_found | N (k, v, l, r, w) -> @@ -352,6 +359,14 @@ module MakeFull(K : KEY) : S with type key = K.t = struct if w=0 then raise Not_found; nth_exn (Random.State.int st w) m + (* make a node (k,v,l,r) but balances on whichever side requires it *) + let node_shallow_ k v l r = + if is_balanced l r + then if is_balanced r l + then mk_node_ k v l r + else balance_r k v l r + else balance_l k v l r + (* assume keys of [l] are smaller than [k] and [k] smaller than keys of [r], but do not assume anything about weights. returns a tree with l, r, and (k,v) *) @@ -360,24 +375,25 @@ module MakeFull(K : KEY) : S with type key = K.t = struct | E, o | o, E -> add k v o | N (kl, vl, ll, lr, wl), N (kr, vr, rl, rr, wr) -> - if is_balanced l r && is_balanced r l + let left = is_balanced l r in + if left && is_balanced r l then mk_node_ k v l r - else if wl <= wr - then balance_l kr vr (node_ k v l rl) rr - else balance_r kl vl ll (node_ k v lr r) + else if not left + then node_shallow_ kr vr (node_ k v l rl) rr + else node_shallow_ kl vl ll (node_ k v lr r) (* join two trees, assuming all keys of [l] are smaller than keys of [r] *) let join_ l r = match l, r with | E, E -> E - | E, _ -> r - | _, E -> l + | E, o -> o + | o, E -> o | N _, N _ -> if weight l <= weight r then - let k, v, r' = extract_min_ r in + let k, v, r' = extract_min r in node_ k v l r' else - let k, v, l' = extract_max_ l in + let k, v, l' = extract_max l in node_ k v l' r (* if [o_v = Some v], behave like [mk_node k v l r] @@ -398,6 +414,21 @@ module MakeFull(K : KEY) : S with type key = K.t = struct let rl, o, rr = split k r in join_ l rl, o, rr + (*$Q & ~small:List.length + Q.(list (pair small_int small_int)) ( fun lst -> \ + let module M = Make(CCInt) in \ + let lst = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) lst in \ + let m = M.of_list lst in \ + List.for_all (fun (k,v) -> \ + let l, v', r = M.split k m in \ + v' = Some v && \ + (M.to_seq l |> Sequence.for_all (fun (k',_) -> k' < k)) &&\ + (M.to_seq r |> Sequence.for_all (fun (k',_) -> k' > k)) &&\ + M.balanced m && \ + M.cardinal l + M.cardinal r + 1 = List.length lst) \ + lst) + *) + let rec merge f a b = match a, b with | E, E -> E | E, N (k, v, l, r, _) -> @@ -410,7 +441,8 @@ module MakeFull(K : KEY) : S with type key = K.t = struct if K.compare k1 k2 = 0 then (* easy case *) - mk_node_or_join_ k1 (f k1 (Some v1) (Some v2)) (merge f l1 l2) (merge f r1 r2) + mk_node_or_join_ k1 (f k1 (Some v1) (Some v2)) + (merge f l1 l2) (merge f r1 r2) else if w1 <= w2 then let l1', v1', r1' = split k2 a in @@ -421,6 +453,31 @@ module MakeFull(K : KEY) : S with type key = K.t = struct mk_node_or_join_ k1 (f k1 (Some v1) v2') (merge f l1 l2') (merge f r1 r2') + (*$R + let module M = Make(CCInt) in + let m1 = M.of_list [1, 1; 2, 2; 4, 4] in + let m2 = M.of_list [1, 1; 3, 3; 4, 4; 7, 7] in + let m = M.merge (fun k -> CCOpt.map2 (+)) m1 m2 in + assert_bool "balanced" (M.balanced m); + assert_equal + ~cmp:(CCList.equal (CCPair.equal CCInt.equal CCInt.equal)) + ~printer:CCFormat.(to_string (list (pair int int))) + [1, 2; 4, 8] + (M.to_list m |> List.sort Pervasives.compare) + *) + + (*$Q & ~small:(fun (l1,l2) -> List.length l1 + List.length l2) + Q.(let p = list (pair small_int small_int) in pair p p) (fun (l1, l2) -> \ + let module M = Make(CCInt) in \ + let eq x y = fst x = fst y in \ + let l1 = CCList.Set.uniq ~eq l1 and l2 = CCList.Set.uniq ~eq l2 in \ + let m1 = M.of_list l1 and m2 = M.of_list l2 in \ + let m = M.merge (fun _ v1 v2 -> match v1 with \ + | None -> v2 | Some _ as r -> r) m1 m2 in \ + List.for_all (fun (k,v) -> M.get_exn k m = v) l1 && \ + List.for_all (fun (k,v) -> M.mem k m1 || M.get_exn k m = v) l2) + *) + let cardinal m = fold (fun acc _ _ -> acc+1) 0 m let add_list m l = List.fold_left (fun acc (k,v) -> add k v acc) m l diff --git a/src/data/CCWBTree.mli b/src/data/CCWBTree.mli index 912d902f..7181a28d 100644 --- a/src/data/CCWBTree.mli +++ b/src/data/CCWBTree.mli @@ -73,7 +73,15 @@ module type S = sig val merge : (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t (** Similar to {!Map.S.merge} *) - (* TODO: compare, equal *) + val extract_min : 'a t -> key * 'a * 'a t + (** [extract_min m] returns [k, v, m'] where [k,v] is the pair with the + smaller key in [m], and [m'] does not contain [k]. + @raise Not_found if the map is empty *) + + val extract_max : 'a t -> key * 'a * 'a t + (** [extract_max m] returns [k, v, m'] where [k,v] is the pair with the + highest key in [m], and [m'] does not contain [k]. + @raise Not_found if the map is empty *) val choose : 'a t -> (key * 'a) option @@ -106,6 +114,7 @@ module type S = sig val print : key printer -> 'a printer -> 'a t printer (**/**) + val node_ : key -> 'a -> 'a t -> 'a t -> 'a t val balanced : _ t -> bool (**/**) end