diff --git a/src/data/CCWBTree.ml b/src/data/CCWBTree.ml index 8b395062..4237df32 100644 --- a/src/data/CCWBTree.ml +++ b/src/data/CCWBTree.ml @@ -221,14 +221,44 @@ module MakeFull(K : KEY) : S with type key = K.t = struct M.cardinal m = List.length l) *) + (* extract max binding of the tree *) + let rec extract_max_ m = match m with + | E -> assert false + | N (k, v, l, E, _) -> k, v, l + | N (k, v, l, r, _) -> + let k', v', r' = extract_max_ r in + k', v', balance_r k v l r' + let rec remove k m = match m with | E -> E | N (k', v', l, r, _) -> match K.compare k k' with - | 0 -> assert false (* TODO fix using a paper *) + | 0 -> + begin match l, r with + | E, E -> E + | E, o + | o, E -> o + | _, _ -> + (* remove max element of [l] and put it at the root, + then rebalance towards the left if needed *) + let k', v', l' = extract_max_ l in + balance_l k' v' l' r + end | n when n<0 -> balance_l k' v' (remove k l) r | _ -> balance_r k' v' l (remove k r) + (*$Q & ~small:List.length + Q.(list (pair small_int small_int)) (fun l -> \ + let module M = Make(CCInt) in \ + let m = M.of_list l in \ + List.for_all (fun (k,_) -> \ + M.mem k m && (let m' = M.remove k m in not (M.mem k m'))) l) + Q.(list (pair small_int small_int)) (fun l -> \ + let module M = Make(CCInt) in \ + let m = M.of_list l in \ + List.for_all (fun (k,_) -> let m' = M.remove k m in M.balanced m') l) + *) + (* TODO union, intersection *) let rec nth_exn i m = match m with