From 132414ba9dfce8812784f5688ffbda36a68ac696 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 2 Sep 2015 11:59:33 +0200 Subject: [PATCH] add tests to `CCIntMap`, add type safety, and fix various bugs in `{union,inter}` --- src/data/CCIntMap.ml | 227 ++++++++++++++++++++++++++++++++---------- src/data/CCIntMap.mli | 17 +++- 2 files changed, 190 insertions(+), 54 deletions(-) diff --git a/src/data/CCIntMap.ml b/src/data/CCIntMap.ml index 39560a27..aba70b84 100644 --- a/src/data/CCIntMap.ml +++ b/src/data/CCIntMap.ml @@ -29,54 +29,113 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. (* "Fast Mergeable Integer Maps", Okasaki & Gill. We use big-endian trees. *) +(** Masks with exactly one bit active *) +module Bit : sig + type t = private int + val highest : int -> t + val min_int : t + val is_0 : bit:t -> int -> bool + val is_1 : bit:t -> int -> bool + val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *) + val lt : t -> t -> bool + val gt : t -> t -> bool +end = struct + type t = int + + let min_int = min_int + + let rec highest_bit_naive x m = + if x=m then m + else highest_bit_naive (x land (lnot m)) (2*m) + + let mask_20_ = 1 lsl 20 + let mask_40_ = 1 lsl 40 + + let highest x = + if x<0 then min_int + else if Sys.word_size > 40 && x > mask_40_ + then (* remove least significant 40 bits *) + let x' = x land (lnot (mask_40_ -1)) in + highest_bit_naive x' mask_40_ + else if x> mask_20_ + then (* small shortcut: remove least significant 20 bits *) + let x' = x land (lnot (mask_20_ -1)) in + highest_bit_naive x' mask_20_ + else highest_bit_naive x 1 + + let is_0 ~bit x = x land bit = 0 + let is_1 ~bit x = x land bit = bit + + let mask ~mask x = (x lor (mask -1)) land (lnot mask) + (* low endian: let mask_ x ~mask = x land (mask - 1) *) + + let gt a b = (b != min_int) && (a = min_int || a > b) + let lt a b = gt b a +end + type 'a t = | E (* empty *) | L of int * 'a (* leaf *) - | N of int (* common prefix *) * int (* bit switch *) * 'a t * 'a t + | N of int (* common prefix *) * Bit.t (* bit switch *) * 'a t * 'a t let empty = E -let bit_is_0_ x ~bit = x land bit = 0 - -let mask_ x ~mask = (x lor (mask -1)) land (lnot mask) -(* low endian: let mask_ x ~mask = x land (mask - 1) *) - -let is_prefix_ ~prefix y ~bit = prefix = mask_ y ~mask:bit - -(* loop down until x=lowest_bit_ x *) -let rec highest_bit_naive x m = - if m = 0 then 0 - else if x land m = 0 then highest_bit_naive x (m lsr 1) - else m - -let highest_bit = - (* the highest representable 2^n *) - let max_log = 1 lsl (Sys.word_size - 2) in - fun x -> - if x > 1 lsl 20 - then (* small shortcut: remove least significant 20 bits *) - let x' = x land (lnot ((1 lsl 20) -1)) in - highest_bit_naive x' max_log - else highest_bit_naive x max_log +let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit (*$Q Q.int (fun i -> \ - let b = highest_bit i in \ - i < 0 || (b <= i && (i-b) < b)) + let b = Bit.highest i in \ + ((b:>int) land i = (b:>int)) && (i < 0 || ((b:>int) <= i && (i-(b:>int)) < (b:>int)))) + Q.int (fun i -> (Bit.highest i = Bit.min_int) = (i < 0)) + Q.int (fun i -> ((Bit.highest i:>int) < 0) = (Bit.highest i = Bit.min_int)) + Q.int (fun i -> let j = (Bit.highest i :> int) in j land (j-1) = 0) *) +(*$T + (Bit.highest min_int :> int) = min_int + (Bit.highest 2 :> int) = 2 + (Bit.highest 17 :> int) = 16 + (Bit.highest 300 :> int) = 256 + *) + (* helper: let b_of_i i = let rec f acc i = - if i=0 then acc else let q, r = i/2, i mod 2 + if i=0 then acc else let q, r = i/2, abs (i mod 2) in f (r::acc) q in f [] i;; *) (* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *) -let branching_bit_ a b = - highest_bit (a lxor b) +let branching_bit_ a b = Bit.highest (a lxor b) + +(* TODO use hint in branching_bit_ *) + +let check_invariants t = + (* check that keys are prefixed by every node in their path *) + let rec check_keys path t = match t with + | E -> true + | L (k, _) -> + List.for_all + (fun (prefix, switch, side) -> + is_prefix_ ~prefix k ~bit:switch + && + match side with + | `Left -> Bit.is_0 k ~bit:switch + | `Right -> Bit.is_1 k ~bit:switch + ) path + | N (prefix, switch, l, r) -> + check_keys ((prefix, switch, `Left) :: path) l + && + check_keys ((prefix, switch, `Right) :: path) r + in + check_keys [] t + +(*$Q + Q.(list (pair int bool)) (fun l -> \ + check_invariants (of_list l)) +*) let rec find_exn k t = match t with | E -> raise Not_found @@ -84,11 +143,13 @@ let rec find_exn k t = match t with | L _ -> raise Not_found | N (prefix, m, l, r) -> if is_prefix_ ~prefix k ~bit:m - then if bit_is_0_ k ~bit:m + then if Bit.is_0 k ~bit:m then find_exn k l else find_exn k r else raise Not_found + (* TODO test with lt_unsigned_ *) + (* FIXME: valid if k < 0? if k <= prefix (* search tree *) then find_exn k l @@ -99,10 +160,23 @@ let find k t = try Some (find_exn k t) with Not_found -> None +(*$Q + Q.(list (pair int int)) (fun l -> \ + let l = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) l in \ + let m = of_list l in \ + List.for_all (fun (k,v) -> find k m = Some v) l) +*) + let mem k t = try ignore (find_exn k t); true with Not_found -> false +(*$Q + Q.(list (pair int int)) (fun l -> \ + let m = of_list l in \ + List.for_all (fun (k,_) -> mem k m) l) +*) + let mk_node_ prefix switch l r = match l, r with | E, o | o, E -> o | _ -> N (prefix, switch, l, r) @@ -111,10 +185,15 @@ let mk_node_ prefix switch l r = match l, r with (p1 and p2 do not overlap) *) let join_ t1 p1 t2 p2 = let switch = branching_bit_ p1 p2 in - let prefix = mask_ p1 ~mask:switch in - if bit_is_0_ p1 ~bit:switch - then mk_node_ prefix switch t1 t2 - else (assert (bit_is_0_ p2 ~bit:switch); mk_node_ prefix switch t2 t1) + let prefix = Bit.mask p1 ~mask:switch in + if Bit.is_0 p1 ~bit:switch + then ( + assert (Bit.is_1 p2 ~bit:switch); + mk_node_ prefix switch t1 t2 + ) else ( + assert (Bit.is_0 p2 ~bit:switch); + mk_node_ prefix switch t2 t1 + ) let singleton k v = L (k, v) @@ -127,7 +206,7 @@ let rec insert_ c k v t = match t with else join_ t k' (L (k, v)) k | N (prefix, switch, l, r) -> if is_prefix_ ~prefix k ~bit:switch - then if bit_is_0_ k ~bit:switch + then if Bit.is_0 k ~bit:switch then N(prefix, switch, insert_ c k v l, r) else N(prefix, switch, l, insert_ c k v r) else join_ (L(k,v)) k t prefix @@ -145,11 +224,17 @@ let rec remove k t = match t with | L (k', _) -> if k=k' then E else t | N (prefix, switch, l, r) -> if is_prefix_ ~prefix k ~bit:switch - then if bit_is_0_ k ~bit:switch + then if Bit.is_0 k ~bit:switch then mk_node_ prefix switch (remove k l) r else mk_node_ prefix switch l (remove k r) else t (* not present *) +(*$Q & ~count:20 + Q.(list (pair int int)) (fun l -> \ + let l = CCList.Set.uniq l in let m = of_list l in \ + List.for_all (fun (k,_) -> mem k m && not (mem k (remove k m))) l) +*) + let update k f t = try let v = find_exn k t in @@ -162,6 +247,8 @@ let update k f t = | None -> t | Some v -> add k v t +(* TODO test *) + let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2) let rec equal ~eq a b = match a, b with @@ -201,7 +288,8 @@ let choose t = try Some (choose_exn t) with Not_found -> None -let rec union f a b = match a, b with +(* TODO fix *) +let rec union f t1 t2 = match t1, t2 with | E, o | o, E -> o | L (k, v), o | o, L (k, v) -> @@ -210,16 +298,43 @@ let rec union f a b = match a, b with | N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> if p1 = p2 && m1 = m2 then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2) - else if m1 < m2 && is_prefix_ ~prefix:p2 p1 ~bit:m1 - then if bit_is_0_ p2 ~bit:m1 - then N (p1, m1, union f l1 b, r1) - else N (p1, m1, l1, union f r1 b) - else if m1 > m2 && is_prefix_ ~prefix:p1 p2 ~bit:m2 - then if bit_is_0_ p1 ~bit:m2 - then N (p2, m2, union f l2 a, r2) - else N (p2, m2, l2, union f r2 a) - else join_ a p1 b p2 + else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 + then if Bit.is_0 p2 ~bit:m1 + then N (p1, m1, union f l1 t2, r1) + else N (p1, m1, l1, union f r1 t2) + else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 + then if Bit.is_0 p1 ~bit:m2 + then N (p2, m2, union f t1 l2, r2) + else N (p2, m2, l2, union f t1 r2) + else join_ t1 p1 t2 p2 +(*$Q & ~small:(fun (a,b) -> List.length a + List.length b) + Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \ + check_invariants (union (fun _ _ x -> x) (of_list l1) (of_list l2))) + Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \ + check_invariants (inter (fun _ _ x -> x) (of_list l1) (of_list l2))) +*) + +(*$R + assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print)) + (of_list [1, "1"; 2, "2"; 3, "3"; 4, "4"]) + (union (fun _ a b -> a) + (of_list [1, "1"; 3, "3"]) (of_list [2, "2"; 4, "4"])); +*) + +(*$R + assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print)) + (of_list [1, "1"; 2, "2"; 3, "3"; 4, "4"]) + (union (fun _ a b -> a) + (of_list [1, "1"; 2, "2"; 3, "3"]) (of_list [2, "2"; 4, "4"])) +*) + +(*$Q + Q.(list (pair int bool)) (fun l -> \ + equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l))) +*) + +(* TODO fix *) let rec inter f a b = match a, b with | E, _ | _, E -> E | L (k, v), o @@ -232,16 +347,28 @@ let rec inter f a b = match a, b with | N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> if p1 = p2 && m1 = m2 then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2) - else if m1 < m2 && is_prefix_ ~prefix:p2 p1 ~bit:m1 - then if bit_is_0_ p2 ~bit:m1 + else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 + then if Bit.is_0 p2 ~bit:m1 then inter f l1 b else inter f r1 b - else if m1 > m2 && is_prefix_ ~prefix:p1 p2 ~bit:m2 - then if bit_is_0_ p1 ~bit:m2 + else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 + then if Bit.is_0 p1 ~bit:m2 then inter f l2 a else inter f r2 a else E +(*$R + assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print)) + (singleton 2 "2") + (inter (fun _ a b -> a) + (of_list [1, "1"; 2, "2"; 3, "3"]) (of_list [2, "2"; 4, "4"])) +*) + +(*$Q + Q.(list (pair int bool)) (fun l -> \ + equal ~eq:(=) (of_list l) (inter (fun _ a _ -> a) (of_list l)(of_list l))) +*) + (* TODO: write tests *) (** {2 Whole-collection operations} *) @@ -375,7 +502,7 @@ let rec as_tree t () = match t with | E -> `Nil | L (k, v) -> `Node (`Leaf (k, v), []) | N (prefix, switch, l, r) -> - `Node (`Node (prefix, switch), [as_tree l; as_tree r]) + `Node (`Node (prefix, (switch:>int)), [as_tree l; as_tree r]) (** {2 IO} *) diff --git a/src/data/CCIntMap.mli b/src/data/CCIntMap.mli index 0c010138..9d62dc21 100644 --- a/src/data/CCIntMap.mli +++ b/src/data/CCIntMap.mli @@ -61,6 +61,7 @@ val compare : cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int val update : int -> ('a option -> 'a option) -> 'a t -> 'a t val cardinal : _ t -> int +(** Number of bindings in the map. Linear time *) val iter : (int -> 'a -> unit) -> 'a t -> unit @@ -114,10 +115,6 @@ val of_klist : (int * 'a) klist -> 'a t val to_klist : 'a t -> (int * 'a) klist (** @since NEXT_RELEASE *) -(** Helpers *) - -val highest_bit : int -> int - type 'a tree = unit -> [`Nil | `Node of 'a * 'a tree list] val as_tree : 'a t -> [`Node of int * int | `Leaf of int * 'a ] tree @@ -129,3 +126,15 @@ type 'a printer = Format.formatter -> 'a -> unit val print : 'a printer -> 'a t printer (** @since NEXT_RELEASE *) +(** Helpers *) + +(**/**) + +module Bit : sig + type t = private int + val min_int : t + val highest : int -> t +end +val check_invariants : _ t -> bool + +(**/**)