mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-07 19:55:31 -05:00
wip: fix CCWBTree.{split,merge}; add tests
This commit is contained in:
parent
5e5d192448
commit
3d035e05cd
2 changed files with 86 additions and 20 deletions
|
|
@ -68,7 +68,15 @@ module type S = sig
|
|||
val merge : (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t
|
||||
(** Similar to {!Map.S.merge} *)
|
||||
|
||||
(* TODO: compare, equal *)
|
||||
val extract_min : 'a t -> key * 'a * 'a t
|
||||
(** [extract_min m] returns [k, v, m'] where [k,v] is the pair with the
|
||||
smaller key in [m], and [m'] does not contain [k].
|
||||
@raise Not_found if the map is empty *)
|
||||
|
||||
val extract_max : 'a t -> key * 'a * 'a t
|
||||
(** [extract_max m] returns [k, v, m'] where [k,v] is the pair with the
|
||||
highest key in [m], and [m'] does not contain [k].
|
||||
@raise Not_found if the map is empty *)
|
||||
|
||||
val choose : 'a t -> (key * 'a) option
|
||||
|
||||
|
|
@ -101,6 +109,7 @@ module type S = sig
|
|||
val print : key printer -> 'a printer -> 'a t printer
|
||||
|
||||
(**/**)
|
||||
val node_ : key -> 'a -> 'a t -> 'a t -> 'a t
|
||||
val balanced : _ t -> bool
|
||||
(**/**)
|
||||
end
|
||||
|
|
@ -152,7 +161,7 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
delta × (weight l + 1) ≥ weight r + 1
|
||||
*)
|
||||
let is_balanced l r =
|
||||
5 * (weight l + 1) >= (weight r + 1) * 2
|
||||
5 * (weight l + 1) >= 2 * (weight r + 1)
|
||||
|
||||
(* gamma = 3/2
|
||||
weight l + 1 < gamma × (weight r + 1) *)
|
||||
|
|
@ -242,19 +251,19 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
*)
|
||||
|
||||
(* extract min binding of the tree *)
|
||||
let rec extract_min_ m = match m with
|
||||
let rec extract_min m = match m with
|
||||
| E -> assert false
|
||||
| N (k, v, E, r, _) -> k, v, r
|
||||
| N (k, v, l, r, _) ->
|
||||
let k', v', l' = extract_min_ l in
|
||||
let k', v', l' = extract_min l in
|
||||
k', v', balance_l k v l' r
|
||||
|
||||
(* extract max binding of the tree *)
|
||||
let rec extract_max_ m = match m with
|
||||
let rec extract_max m = match m with
|
||||
| E -> assert false
|
||||
| N (k, v, l, E, _) -> k, v, l
|
||||
| N (k, v, l, r, _) ->
|
||||
let k', v', r' = extract_max_ r in
|
||||
let k', v', r' = extract_max r in
|
||||
k', v', balance_r k v l r'
|
||||
|
||||
let rec remove k m = match m with
|
||||
|
|
@ -271,11 +280,11 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
then
|
||||
(* remove max element of [l] and put it at the root,
|
||||
then rebalance towards the left if needed *)
|
||||
let k', v', l' = extract_max_ l in
|
||||
let k', v', l' = extract_max l in
|
||||
balance_l k' v' l' r
|
||||
else
|
||||
(* remove min element of [r] and rebalance *)
|
||||
let k', v', r' = extract_min_ r in
|
||||
let k', v', r' = extract_min r in
|
||||
balance_r k' v' l r'
|
||||
end
|
||||
| n when n<0 -> balance_l k' v' (remove k l) r
|
||||
|
|
@ -300,8 +309,6 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
| Some _, None -> remove k m
|
||||
| _, Some v -> add k v m
|
||||
|
||||
(* TODO union, intersection *)
|
||||
|
||||
let rec nth_exn i m = match m with
|
||||
| E -> raise Not_found
|
||||
| N (k, v, l, r, w) ->
|
||||
|
|
@ -352,6 +359,14 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
if w=0 then raise Not_found;
|
||||
nth_exn (Random.State.int st w) m
|
||||
|
||||
(* make a node (k,v,l,r) but balances on whichever side requires it *)
|
||||
let node_shallow_ k v l r =
|
||||
if is_balanced l r
|
||||
then if is_balanced r l
|
||||
then mk_node_ k v l r
|
||||
else balance_r k v l r
|
||||
else balance_l k v l r
|
||||
|
||||
(* assume keys of [l] are smaller than [k] and [k] smaller than keys of [r],
|
||||
but do not assume anything about weights.
|
||||
returns a tree with l, r, and (k,v) *)
|
||||
|
|
@ -360,24 +375,25 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
| E, o
|
||||
| o, E -> add k v o
|
||||
| N (kl, vl, ll, lr, wl), N (kr, vr, rl, rr, wr) ->
|
||||
if is_balanced l r && is_balanced r l
|
||||
let left = is_balanced l r in
|
||||
if left && is_balanced r l
|
||||
then mk_node_ k v l r
|
||||
else if wl <= wr
|
||||
then balance_l kr vr (node_ k v l rl) rr
|
||||
else balance_r kl vl ll (node_ k v lr r)
|
||||
else if not left
|
||||
then node_shallow_ kr vr (node_ k v l rl) rr
|
||||
else node_shallow_ kl vl ll (node_ k v lr r)
|
||||
|
||||
(* join two trees, assuming all keys of [l] are smaller than keys of [r] *)
|
||||
let join_ l r = match l, r with
|
||||
| E, E -> E
|
||||
| E, _ -> r
|
||||
| _, E -> l
|
||||
| E, o -> o
|
||||
| o, E -> o
|
||||
| N _, N _ ->
|
||||
if weight l <= weight r
|
||||
then
|
||||
let k, v, r' = extract_min_ r in
|
||||
let k, v, r' = extract_min r in
|
||||
node_ k v l r'
|
||||
else
|
||||
let k, v, l' = extract_max_ l in
|
||||
let k, v, l' = extract_max l in
|
||||
node_ k v l' r
|
||||
|
||||
(* if [o_v = Some v], behave like [mk_node k v l r]
|
||||
|
|
@ -398,6 +414,21 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
let rl, o, rr = split k r in
|
||||
join_ l rl, o, rr
|
||||
|
||||
(*$Q & ~small:List.length
|
||||
Q.(list (pair small_int small_int)) ( fun lst -> \
|
||||
let module M = Make(CCInt) in \
|
||||
let lst = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) lst in \
|
||||
let m = M.of_list lst in \
|
||||
List.for_all (fun (k,v) -> \
|
||||
let l, v', r = M.split k m in \
|
||||
v' = Some v && \
|
||||
(M.to_seq l |> Sequence.for_all (fun (k',_) -> k' < k)) &&\
|
||||
(M.to_seq r |> Sequence.for_all (fun (k',_) -> k' > k)) &&\
|
||||
M.balanced m && \
|
||||
M.cardinal l + M.cardinal r + 1 = List.length lst) \
|
||||
lst)
|
||||
*)
|
||||
|
||||
let rec merge f a b = match a, b with
|
||||
| E, E -> E
|
||||
| E, N (k, v, l, r, _) ->
|
||||
|
|
@ -410,7 +441,8 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
if K.compare k1 k2 = 0
|
||||
then
|
||||
(* easy case *)
|
||||
mk_node_or_join_ k1 (f k1 (Some v1) (Some v2)) (merge f l1 l2) (merge f r1 r2)
|
||||
mk_node_or_join_ k1 (f k1 (Some v1) (Some v2))
|
||||
(merge f l1 l2) (merge f r1 r2)
|
||||
else if w1 <= w2
|
||||
then
|
||||
let l1', v1', r1' = split k2 a in
|
||||
|
|
@ -421,6 +453,31 @@ module MakeFull(K : KEY) : S with type key = K.t = struct
|
|||
mk_node_or_join_ k1 (f k1 (Some v1) v2')
|
||||
(merge f l1 l2') (merge f r1 r2')
|
||||
|
||||
(*$R
|
||||
let module M = Make(CCInt) in
|
||||
let m1 = M.of_list [1, 1; 2, 2; 4, 4] in
|
||||
let m2 = M.of_list [1, 1; 3, 3; 4, 4; 7, 7] in
|
||||
let m = M.merge (fun k -> CCOpt.map2 (+)) m1 m2 in
|
||||
assert_bool "balanced" (M.balanced m);
|
||||
assert_equal
|
||||
~cmp:(CCList.equal (CCPair.equal CCInt.equal CCInt.equal))
|
||||
~printer:CCFormat.(to_string (list (pair int int)))
|
||||
[1, 2; 4, 8]
|
||||
(M.to_list m |> List.sort Pervasives.compare)
|
||||
*)
|
||||
|
||||
(*$Q & ~small:(fun (l1,l2) -> List.length l1 + List.length l2)
|
||||
Q.(let p = list (pair small_int small_int) in pair p p) (fun (l1, l2) -> \
|
||||
let module M = Make(CCInt) in \
|
||||
let eq x y = fst x = fst y in \
|
||||
let l1 = CCList.Set.uniq ~eq l1 and l2 = CCList.Set.uniq ~eq l2 in \
|
||||
let m1 = M.of_list l1 and m2 = M.of_list l2 in \
|
||||
let m = M.merge (fun _ v1 v2 -> match v1 with \
|
||||
| None -> v2 | Some _ as r -> r) m1 m2 in \
|
||||
List.for_all (fun (k,v) -> M.get_exn k m = v) l1 && \
|
||||
List.for_all (fun (k,v) -> M.mem k m1 || M.get_exn k m = v) l2)
|
||||
*)
|
||||
|
||||
let cardinal m = fold (fun acc _ _ -> acc+1) 0 m
|
||||
|
||||
let add_list m l = List.fold_left (fun acc (k,v) -> add k v acc) m l
|
||||
|
|
|
|||
|
|
@ -73,7 +73,15 @@ module type S = sig
|
|||
val merge : (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t
|
||||
(** Similar to {!Map.S.merge} *)
|
||||
|
||||
(* TODO: compare, equal *)
|
||||
val extract_min : 'a t -> key * 'a * 'a t
|
||||
(** [extract_min m] returns [k, v, m'] where [k,v] is the pair with the
|
||||
smaller key in [m], and [m'] does not contain [k].
|
||||
@raise Not_found if the map is empty *)
|
||||
|
||||
val extract_max : 'a t -> key * 'a * 'a t
|
||||
(** [extract_max m] returns [k, v, m'] where [k,v] is the pair with the
|
||||
highest key in [m], and [m'] does not contain [k].
|
||||
@raise Not_found if the map is empty *)
|
||||
|
||||
val choose : 'a t -> (key * 'a) option
|
||||
|
||||
|
|
@ -106,6 +114,7 @@ module type S = sig
|
|||
val print : key printer -> 'a printer -> 'a t printer
|
||||
|
||||
(**/**)
|
||||
val node_ : key -> 'a -> 'a t -> 'a t -> 'a t
|
||||
val balanced : _ t -> bool
|
||||
(**/**)
|
||||
end
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue