feat(intmap): add CCIntMap.{filter,filter_map,merge}

This commit is contained in:
Simon Cruanes 2018-06-04 23:36:15 -05:00
parent ca0521512f
commit 5523ed428c
2 changed files with 142 additions and 1 deletions

View file

@ -440,8 +440,132 @@ let rec inter f a b =
equal ~eq:(=) (inter f (inter f m1 m2) m3) (inter f m1 (inter f m2 m3)))
*)
let rec disjoint_union_ t1 t2 : _ t = match t1, t2 with
| E, o | o, E -> o
| L (k,v), o
| o, L(k,v) -> insert_ (fun ~old:_ _ -> assert false) k v o
| N (p1,m1,l1,r1), N(p2,m2,l2,r2) ->
if p1 = p2 && Bit.equal m1 m2 then (
mk_node_ p1 m1 (disjoint_union_ l1 l2) (disjoint_union_ r1 r2)
) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
if Bit.is_0 p2 ~bit:m1
then mk_node_ p1 m1 (disjoint_union_ l1 t2) r1
else mk_node_ p1 m1 l1 (disjoint_union_ r1 t2)
) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
if Bit.is_0 p1 ~bit:m2
then mk_node_ p2 m2 (disjoint_union_ t1 l2) r2
else mk_node_ p2 m2 l2 (disjoint_union_ t1 r2)
) else (
join_ t1 p1 t2 p2
)
(** {2 Whole-collection operations} *)
let rec filter f m = match m with
| E -> E
| L (k,v) ->
if f k v then m else E
| N (_,_,l,r) ->
disjoint_union_ (filter f l) (filter f r)
(*$QR
Q.(pair (fun2 Observable.int Observable.int bool) (small_list (pair int int))) (fun (f,l) ->
let QCheck.Fun(_,f) = f in
_list_uniq (List.filter (fun (x,y) -> f x y) l) =
(_list_uniq @@ to_list @@ filter f @@ of_list l)
)
*)
let rec filter_map f m = match m with
| E -> E
| L (k,v) ->
begin match f k v with
| None -> E
| Some v' -> L(k,v')
end
| N (_,_,l,r) ->
disjoint_union_ (filter_map f l) (filter_map f r)
(*$QR
Q.(pair (fun2 Observable.int Observable.int @@ option bool) (small_list (pair int int))) (fun (f,l) ->
let QCheck.Fun(_,f) = f in
_list_uniq (CCList.filter_map (fun (x,y) -> CCOpt.map (CCPair.make x) @@ f x y) l) =
(_list_uniq @@ to_list @@ filter_map f @@ of_list l)
)
*)
let rec merge ~f t1 t2 : _ t =
let merge1 t =
filter_map (fun k v -> f k (`Left v)) t
and merge2 t =
filter_map (fun k v -> f k (`Right v)) t
and add_some k opt m = match opt with
| None -> m
| Some v -> insert_ (fun ~old:_ _ -> assert false) k v m
in
match t1, t2 with
| E, o -> merge2 o
| o, E -> merge1 o
| L (k, v), o ->
let others = merge2 (remove k o) in
add_some k
(try f k (`Both (v,find_exn k o))
with Not_found -> f k (`Left v)) others
| o, L (k, v) ->
let others = merge1 (remove k o) in
add_some k
(try f k (`Both (find_exn k o,v))
with Not_found -> f k (`Right v)) others
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && Bit.equal m1 m2 then (
mk_node_ p1 m1 (merge ~f l1 l2) (merge ~f r1 r2)
) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
if Bit.is_0 p2 ~bit:m1
then mk_node_ p1 m1 (merge ~f l1 t2) (merge1 r1)
else mk_node_ p1 m1 (merge1 l1) (merge ~f r1 t2)
) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
if Bit.is_0 p1 ~bit:m2
then mk_node_ p2 m2 (merge ~f t1 l2) (merge2 r2)
else mk_node_ p2 m2 (merge2 l2) (merge ~f t1 r2)
) else (
join_ (merge1 t1) p1 (merge2 t2) p2
)
(*$inject
let merge_union _x o = match o with
| `Left v | `Right v | `Both (v,_) -> Some v
let merge_inter _x o = match o with
| `Left _ | `Right _ -> None
| `Both (v,_) -> Some v
*)
(*$QR
Q.(let p = small_list (pair small_int small_int) in pair p p) (fun (l1,l2) ->
check_invariants
(merge ~f:merge_union (of_list l1) (of_list l2)))
*)
(*$QR
Q.(let p = small_list (pair small_int small_int) in pair p p) (fun (l1,l2) ->
check_invariants
(merge ~f:merge_inter (of_list l1) (of_list l2)))
*)
(*$QR
Q.(let p = small_list (pair small_int unit) in pair p p) (fun (l1,l2) ->
let l1 = _list_uniq l1 and l2 = _list_uniq l2 in
equal Pervasives.(=)
(union (fun _ v1 _ -> v1) (of_list l1) (of_list l2))
(merge ~f:merge_union (of_list l1) (of_list l2)))
*)
(*$QR
Q.(let p = small_list (pair small_int unit) in pair p p) (fun (l1,l2) ->
let l1 = _list_uniq l1 and l2 = _list_uniq l2 in
equal Pervasives.(=)
(inter (fun _ v1 _ -> v1) (of_list l1) (of_list l2))
(merge ~f:merge_inter (of_list l1) (of_list l2)))
*)
(** {2 Conversions} *)
type 'a sequence = ('a -> unit) -> unit
type 'a gen = unit -> 'a option

View file

@ -37,6 +37,14 @@ val compare : cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int
val update : int -> ('a option -> 'a option) -> 'a t -> 'a t
val filter : (int -> 'a -> bool) -> 'a t -> 'a t
(** Filter values using the given predicate
@since NEXT_RELEASE *)
val filter_map : (int -> 'a -> 'b option) -> 'a t -> 'b t
(** Filter-map values using the given function
@since NEXT_RELEASE *)
val cardinal : _ t -> int
(** Number of bindings in the map. Linear time. *)
@ -59,6 +67,15 @@ val union : (int -> 'a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
val inter : (int -> 'a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
val merge :
f:(int -> [`Left of 'a | `Right of 'b | `Both of 'a * 'b] -> 'c option) ->
'a t -> 'b t -> 'c t
(** [merge ~f m1 m2] merges [m1] and [m2] together, calling [f] once on every
key that occurs in at least one of [m1] and [m2].
if [f k binding = Some c] then [k -> c] is part of the result,
else [k] is not part of the result.
@since NEXT_RELEASE *)
(** {2 Whole-collection operations} *)
type 'a sequence = ('a -> unit) -> unit