diff --git a/src/data/CCWBTree.ml b/src/data/CCWBTree.ml index 0292126f..99aee13b 100644 --- a/src/data/CCWBTree.ml +++ b/src/data/CCWBTree.ml @@ -60,6 +60,16 @@ module type S = sig val iter : (key -> 'a -> unit) -> 'a t -> unit + val split : key -> 'a t -> 'a t * 'a option * 'a t + (** [split k t] returns [l, o, r] where [l] is the part of the map + with keys smaller than [k], [r] has keys bigger than [k], + and [o = Some v] if [k, v] belonged to the map *) + + val merge : (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t + (** Similar to {!Map.S.merge} *) + + (* TODO: compare, equal *) + val choose : 'a t -> (key * 'a) option val choose_exn : 'a t -> key * 'a @@ -342,6 +352,75 @@ module MakeFull(K : KEY) : S with type key = K.t = struct if w=0 then raise Not_found; nth_exn (Random.State.int st w) m + (* assume keys of [l] are smaller than [k] and [k] smaller than keys of [r], + but do not assume anything about weights. + returns a tree with l, r, and (k,v) *) + let rec node_ k v l r = match l, r with + | E, E -> mk_node_ k v E E + | E, o + | o, E -> add k v o + | N (kl, vl, ll, lr, wl), N (kr, vr, rl, rr, wr) -> + if is_balanced l r && is_balanced r l + then mk_node_ k v l r + else if wl <= wr + then balance_l kr vr (node_ k v l rl) rr + else balance_r kl vl ll (node_ k v lr r) + + (* join two trees, assuming all keys of [l] are smaller than keys of [r] *) + let join_ l r = match l, r with + | E, E -> E + | E, _ -> r + | _, E -> l + | N _, N _ -> + if weight l <= weight r + then + let k, v, r' = extract_min_ r in + node_ k v l r' + else + let k, v, l' = extract_max_ l in + node_ k v l' r + + (* if [o_v = Some v], behave like [mk_node k v l r] + else behave like [join_ l r] *) + let mk_node_or_join_ k o_v l r = match o_v with + | None -> join_ l r + | Some v -> node_ k v l r + + let rec split k m = match m with + | E -> E, None, E + | N (k', v', l, r, _) -> + match K.compare k k' with + | 0 -> l, Some v', r + | n when n<0 -> + let ll, o, lr = split k l in + ll, o, join_ lr r + | _ -> + let rl, o, rr = split k r in + join_ l rl, o, rr + + let rec merge f a b = match a, b with + | E, E -> E + | E, N (k, v, l, r, _) -> + let v' = f k None (Some v) in + mk_node_or_join_ k v' (merge f E l) (merge f E r) + | N (k, v, l, r, _), E -> + let v' = f k (Some v) None in + mk_node_or_join_ k v' (merge f l E) (merge f r E) + | N (k1, v1, l1, r1, w1), N (k2, v2, l2, r2, w2) -> + if K.compare k1 k2 = 0 + then + (* easy case *) + mk_node_or_join_ k1 (f k1 (Some v1) (Some v2)) (merge f l1 l2) (merge f r1 r2) + else if w1 <= w2 + then + let l1', v1', r1' = split k2 a in + mk_node_or_join_ k2 (f k2 v1' (Some v2)) + (merge f l1' l2) (merge f r1' r2) + else + let l2', v2', r2' = split k1 b in + mk_node_or_join_ k1 (f k1 (Some v1) v2') + (merge f l1 l2') (merge f r1 r2') + let cardinal m = fold (fun acc _ _ -> acc+1) 0 m let add_list m l = List.fold_left (fun acc (k,v) -> add k v acc) m l diff --git a/src/data/CCWBTree.mli b/src/data/CCWBTree.mli index 6219c46f..912d902f 100644 --- a/src/data/CCWBTree.mli +++ b/src/data/CCWBTree.mli @@ -65,6 +65,16 @@ module type S = sig val iter : (key -> 'a -> unit) -> 'a t -> unit + val split : key -> 'a t -> 'a t * 'a option * 'a t + (** [split k t] returns [l, o, r] where [l] is the part of the map + with keys smaller than [k], [r] has keys bigger than [k], + and [o = Some v] if [k, v] belonged to the map *) + + val merge : (key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t + (** Similar to {!Map.S.merge} *) + + (* TODO: compare, equal *) + val choose : 'a t -> (key * 'a) option val choose_exn : 'a t -> key * 'a