diff --git a/src/data/CCIntMap.ml b/src/data/CCIntMap.ml index 26fba781..76464bf6 100644 --- a/src/data/CCIntMap.ml +++ b/src/data/CCIntMap.ml @@ -440,8 +440,132 @@ let rec inter f a b = equal ~eq:(=) (inter f (inter f m1 m2) m3) (inter f m1 (inter f m2 m3))) *) +let rec disjoint_union_ t1 t2 : _ t = match t1, t2 with + | E, o | o, E -> o + | L (k,v), o + | o, L(k,v) -> insert_ (fun ~old:_ _ -> assert false) k v o + | N (p1,m1,l1,r1), N(p2,m2,l2,r2) -> + if p1 = p2 && Bit.equal m1 m2 then ( + mk_node_ p1 m1 (disjoint_union_ l1 l2) (disjoint_union_ r1 r2) + ) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then ( + if Bit.is_0 p2 ~bit:m1 + then mk_node_ p1 m1 (disjoint_union_ l1 t2) r1 + else mk_node_ p1 m1 l1 (disjoint_union_ r1 t2) + ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then ( + if Bit.is_0 p1 ~bit:m2 + then mk_node_ p2 m2 (disjoint_union_ t1 l2) r2 + else mk_node_ p2 m2 l2 (disjoint_union_ t1 r2) + ) else ( + join_ t1 p1 t2 p2 + ) -(** {2 Whole-collection operations} *) +let rec filter f m = match m with + | E -> E + | L (k,v) -> + if f k v then m else E + | N (_,_,l,r) -> + disjoint_union_ (filter f l) (filter f r) + +(*$QR + Q.(pair (fun2 Observable.int Observable.int bool) (small_list (pair int int))) (fun (f,l) -> + let QCheck.Fun(_,f) = f in + _list_uniq (List.filter (fun (x,y) -> f x y) l) = + (_list_uniq @@ to_list @@ filter f @@ of_list l) + ) +*) + +let rec filter_map f m = match m with + | E -> E + | L (k,v) -> + begin match f k v with + | None -> E + | Some v' -> L(k,v') + end + | N (_,_,l,r) -> + disjoint_union_ (filter_map f l) (filter_map f r) + +(*$QR + Q.(pair (fun2 Observable.int Observable.int @@ option bool) (small_list (pair int int))) (fun (f,l) -> + let QCheck.Fun(_,f) = f in + _list_uniq (CCList.filter_map (fun (x,y) -> CCOpt.map (CCPair.make x) @@ f x y) l) = + (_list_uniq @@ to_list @@ filter_map f @@ of_list l) + ) +*) + +let rec merge ~f t1 t2 : _ t = + let merge1 t = + filter_map (fun k v -> f k (`Left v)) t + and merge2 t = + filter_map (fun k v -> f k (`Right v)) t + and add_some k opt m = match opt with + | None -> m + | Some v -> insert_ (fun ~old:_ _ -> assert false) k v m + in + match t1, t2 with + | E, o -> merge2 o + | o, E -> merge1 o + | L (k, v), o -> + let others = merge2 (remove k o) in + add_some k + (try f k (`Both (v,find_exn k o)) + with Not_found -> f k (`Left v)) others + | o, L (k, v) -> + let others = merge1 (remove k o) in + add_some k + (try f k (`Both (find_exn k o,v)) + with Not_found -> f k (`Right v)) others + | N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> + if p1 = p2 && Bit.equal m1 m2 then ( + mk_node_ p1 m1 (merge ~f l1 l2) (merge ~f r1 r2) + ) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then ( + if Bit.is_0 p2 ~bit:m1 + then mk_node_ p1 m1 (merge ~f l1 t2) (merge1 r1) + else mk_node_ p1 m1 (merge1 l1) (merge ~f r1 t2) + ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then ( + if Bit.is_0 p1 ~bit:m2 + then mk_node_ p2 m2 (merge ~f t1 l2) (merge2 r2) + else mk_node_ p2 m2 (merge2 l2) (merge ~f t1 r2) + ) else ( + join_ (merge1 t1) p1 (merge2 t2) p2 + ) + +(*$inject + let merge_union _x o = match o with + | `Left v | `Right v | `Both (v,_) -> Some v + let merge_inter _x o = match o with + | `Left _ | `Right _ -> None + | `Both (v,_) -> Some v +*) + +(*$QR + Q.(let p = small_list (pair small_int small_int) in pair p p) (fun (l1,l2) -> + check_invariants + (merge ~f:merge_union (of_list l1) (of_list l2))) + *) + +(*$QR + Q.(let p = small_list (pair small_int small_int) in pair p p) (fun (l1,l2) -> + check_invariants + (merge ~f:merge_inter (of_list l1) (of_list l2))) + *) + +(*$QR + Q.(let p = small_list (pair small_int unit) in pair p p) (fun (l1,l2) -> + let l1 = _list_uniq l1 and l2 = _list_uniq l2 in + equal Pervasives.(=) + (union (fun _ v1 _ -> v1) (of_list l1) (of_list l2)) + (merge ~f:merge_union (of_list l1) (of_list l2))) + *) + +(*$QR + Q.(let p = small_list (pair small_int unit) in pair p p) (fun (l1,l2) -> + let l1 = _list_uniq l1 and l2 = _list_uniq l2 in + equal Pervasives.(=) + (inter (fun _ v1 _ -> v1) (of_list l1) (of_list l2)) + (merge ~f:merge_inter (of_list l1) (of_list l2))) + *) + +(** {2 Conversions} *) type 'a sequence = ('a -> unit) -> unit type 'a gen = unit -> 'a option diff --git a/src/data/CCIntMap.mli b/src/data/CCIntMap.mli index 04874954..4cd15d9a 100644 --- a/src/data/CCIntMap.mli +++ b/src/data/CCIntMap.mli @@ -37,6 +37,14 @@ val compare : cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int val update : int -> ('a option -> 'a option) -> 'a t -> 'a t +val filter : (int -> 'a -> bool) -> 'a t -> 'a t +(** Filter values using the given predicate + @since NEXT_RELEASE *) + +val filter_map : (int -> 'a -> 'b option) -> 'a t -> 'b t +(** Filter-map values using the given function + @since NEXT_RELEASE *) + val cardinal : _ t -> int (** Number of bindings in the map. Linear time. *) @@ -59,6 +67,15 @@ val union : (int -> 'a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t val inter : (int -> 'a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t +val merge : + f:(int -> [`Left of 'a | `Right of 'b | `Both of 'a * 'b] -> 'c option) -> + 'a t -> 'b t -> 'c t +(** [merge ~f m1 m2] merges [m1] and [m2] together, calling [f] once on every + key that occurs in at least one of [m1] and [m2]. + if [f k binding = Some c] then [k -> c] is part of the result, + else [k] is not part of the result. + @since NEXT_RELEASE *) + (** {2 Whole-collection operations} *) type 'a sequence = ('a -> unit) -> unit