diff --git a/core/CCMultiSet.ml b/core/CCMultiSet.ml index 8e840586..3fa4c8c1 100644 --- a/core/CCMultiSet.ml +++ b/core/CCMultiSet.ml @@ -45,6 +45,12 @@ module type S = sig val remove : t -> elt -> t + val add_mult : t -> elt -> int -> t + + val remove_mult : t -> elt -> int -> t + + val update : t -> elt -> (int -> int) -> t + val min : t -> elt val max : t -> elt @@ -102,12 +108,31 @@ module Make(O : Set.OrderedType) = struct let n = count ms x in M.add x (n+1) ms - let remove ms x = + let add_mult ms x n = + if n < 0 then invalid_arg "CCMultiSet.add_mult"; + if n=0 + then ms + else M.add x (count ms x + n) ms + + let remove_mult ms x n = + if n < 0 then invalid_arg "CCMultiSet.remove_mult"; + let cur_n = count ms x in + let new_n = cur_n - n in + if new_n <= 0 + then M.remove x ms + else M.add x new_n ms + + let remove ms x = remove_mult ms x 1 + + let update ms x f = let n = count ms x in - match n with - | 0 -> ms - | 1 -> M.remove x ms - | _ -> M.add x (n-1) ms + match f n with + | 0 -> + if n=0 then ms else M.remove x ms + | n' -> + if n' < 0 + then invalid_arg "CCMultiSet.udpate" + else M.add x n' ms let min ms = fst (M.min_binding ms) @@ -197,3 +222,12 @@ module Make(O : Set.OrderedType) = struct seq (fun x -> m := add !m x); !m end + +(*$T + let module S = CCMultiSet.Make(String) in \ + S.count (S.add_mult S.empty "a" 5) "a" = 5 + let module S = CCMultiSet.Make(String) in \ + S.count (S.remove_mult (S.add_mult S.empty "a" 5) "a" 3) "a" = 2 + let module S = CCMultiSet.Make(String) in \ + S.count (S.remove_mult (S.add_mult S.empty "a" 4) "a" 6) "a" = 0 +*) diff --git a/core/CCMultiSet.mli b/core/CCMultiSet.mli index 89d32f83..826f394b 100644 --- a/core/CCMultiSet.mli +++ b/core/CCMultiSet.mli @@ -45,6 +45,23 @@ module type S = sig val remove : t -> elt -> t + val add_mult : t -> elt -> int -> t + (** [add_mult set x n] adds [n] occurrences of [x] to [set] + @raise Invalid_argument if [n < 0] + @since NEXT_RELEASE *) + + val remove_mult : t -> elt -> int -> t + (** [remove_mult set x n] removes at most [n] occurrences of [x] from [set] + @raise Invalid_argument if [n < 0] + @since NEXT_RELEASE *) + + val update : t -> elt -> (int -> int) -> t + (** [update set x f] calls [f n] where [n] is the current multiplicity + of [x] in [set] ([0] to indicate its absence); the result of [f n] + is the new multiplicity of [x]. + @raise Invalid_argument if [f n < 0] + @since NEXT_RELEASE *) + val min : t -> elt (** Minimal element w.r.t the total ordering on elements *)