test(intmap): add some tests for CCIntMap, also improve style

This commit is contained in:
Simon Cruanes 2018-06-04 23:32:08 -05:00
parent 0c48cff2a1
commit ca0521512f
2 changed files with 109 additions and 55 deletions

View file

@ -17,12 +17,13 @@ module Bit : sig
val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *) val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *)
val lt : t -> t -> bool val lt : t -> t -> bool
val gt : t -> t -> bool val gt : t -> t -> bool
val equal_int : int -> t -> bool
end = struct end = struct
type t = int type t = int
let min_int = min_int let min_int = min_int
let equal = (=) let equal : t -> t -> bool = Pervasives.(=)
let rec highest_bit_naive x m = let rec highest_bit_naive x m =
if x=m then m if x=m then m
@ -33,26 +34,45 @@ end = struct
let highest x = let highest x =
if x<0 then min_int if x<0 then min_int
else if Sys.word_size > 40 && x > mask_40_ else if Sys.word_size > 40 && x > mask_40_ then (
then (* remove least significant 40 bits *) (* remove least significant 40 bits *)
let x' = x land (lnot (mask_40_ -1)) in let x' = x land (lnot (mask_40_ -1)) in
highest_bit_naive x' mask_40_ highest_bit_naive x' mask_40_
else if x> mask_20_ ) else if x> mask_20_ then (
then (* small shortcut: remove least significant 20 bits *) (* small shortcut: remove least significant 20 bits *)
let x' = x land (lnot (mask_20_ -1)) in let x' = x land (lnot (mask_20_ -1)) in
highest_bit_naive x' mask_20_ highest_bit_naive x' mask_20_
else highest_bit_naive x 1 ) else (
highest_bit_naive x 1
)
let is_0 ~bit x = x land bit = 0 let[@inline] is_0 ~bit x = x land bit = 0
let is_1 ~bit x = x land bit = bit let[@inline] is_1 ~bit x = x land bit = bit
let mask ~mask x = (x lor (mask -1)) land (lnot mask) let[@inline] mask ~mask x = (x lor (mask -1)) land (lnot mask)
(* low endian: let mask_ x ~mask = x land (mask - 1) *) (* low endian: let mask_ x ~mask = x land (mask - 1) *)
let gt a b = (b != min_int) && (a = min_int || a > b) let[@inline] gt a b = (b != min_int) && (a = min_int || a > b)
let lt a b = gt b a let[@inline] lt a b = gt b a
let equal_int = Pervasives.(=)
end end
(*$inject
let highest2 x : int =
let rec aux i =
if i=0 then i
else if 1 = (x lsr i) then 1 lsl i else aux (i-1)
in
if x<0 then min_int else aux (Sys.word_size-2)
*)
(*$QR & ~count:1_000
Q.int (fun x ->
if Bit.equal_int (highest2 x) (Bit.highest x) then true
else QCheck.Test.fail_reportf "x=%d, highest=%d, highest2=%d@." x
(Bit.highest x :> int) (highest2 x))
*)
type 'a t = type 'a t =
| E (* empty *) | E (* empty *)
| L of int * 'a (* leaf *) | L of int * 'a (* leaf *)
@ -60,10 +80,11 @@ type 'a t =
let empty = E let empty = E
let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit let[@inline] is_prefix_ ~prefix y ~bit =
prefix = Bit.mask y ~mask:bit
(*$inject (*$inject
let _list_uniq = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b)) let _list_uniq l = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b)) l
*) *)
(*$Q (*$Q
@ -92,7 +113,7 @@ let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
*) *)
(* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *) (* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *)
let branching_bit_ a b = Bit.highest (a lxor b) let[@inline] branching_bit_ a b = Bit.highest (a lxor b)
(* TODO use hint in branching_bit_ *) (* TODO use hint in branching_bit_ *)
@ -103,15 +124,14 @@ let check_invariants t =
| L (k, _) -> | L (k, _) ->
List.for_all List.for_all
(fun (prefix, switch, side) -> (fun (prefix, switch, side) ->
is_prefix_ ~prefix k ~bit:switch is_prefix_ ~prefix k ~bit:switch &&
&& begin match side with
match side with
| `Left -> Bit.is_0 k ~bit:switch | `Left -> Bit.is_0 k ~bit:switch
| `Right -> Bit.is_1 k ~bit:switch | `Right -> Bit.is_1 k ~bit:switch
) path end)
path
| N (prefix, switch, l, r) -> | N (prefix, switch, l, r) ->
check_keys ((prefix, switch, `Left) :: path) l check_keys ((prefix, switch, `Left) :: path) l &&
&&
check_keys ((prefix, switch, `Right) :: path) r check_keys ((prefix, switch, `Right) :: path) r
in in
check_keys [] t check_keys [] t
@ -126,11 +146,13 @@ let rec find_exn k t = match t with
| L (k', v) when k = k' -> v | L (k', v) when k = k' -> v
| L _ -> raise Not_found | L _ -> raise Not_found
| N (prefix, m, l, r) -> | N (prefix, m, l, r) ->
if is_prefix_ ~prefix k ~bit:m if is_prefix_ ~prefix k ~bit:m then (
then if Bit.is_0 k ~bit:m if Bit.is_0 k ~bit:m
then find_exn k l then find_exn k l
else find_exn k r else find_exn k r
else raise Not_found ) else (
raise Not_found
)
(* XXX could test with lt_unsigned_? *) (* XXX could test with lt_unsigned_? *)
@ -161,7 +183,7 @@ let mem k t =
List.for_all (fun (k,_) -> mem k m) l) List.for_all (fun (k,_) -> mem k m) l)
*) *)
let mk_node_ prefix switch l r = match l, r with let[@inline] mk_node_ prefix switch l r = match l, r with
| E, o | o, E -> o | E, o | o, E -> o
| _ -> N (prefix, switch, l, r) | _ -> N (prefix, switch, l, r)
@ -170,8 +192,7 @@ let mk_node_ prefix switch l r = match l, r with
let join_ t1 p1 t2 p2 = let join_ t1 p1 t2 p2 =
let switch = branching_bit_ p1 p2 in let switch = branching_bit_ p1 p2 in
let prefix = Bit.mask p1 ~mask:switch in let prefix = Bit.mask p1 ~mask:switch in
if Bit.is_0 p1 ~bit:switch if Bit.is_0 p1 ~bit:switch then (
then (
assert (Bit.is_1 p2 ~bit:switch); assert (Bit.is_1 p2 ~bit:switch);
mk_node_ prefix switch t1 t2 mk_node_ prefix switch t1 t2
) else ( ) else (
@ -179,7 +200,7 @@ let join_ t1 p1 t2 p2 =
mk_node_ prefix switch t2 t1 mk_node_ prefix switch t2 t1
) )
let singleton k v = L (k, v) let[@inline] singleton k v = L (k, v)
(* c: conflict function *) (* c: conflict function *)
let rec insert_ c k v t = match t with let rec insert_ c k v t = match t with
@ -189,11 +210,13 @@ let rec insert_ c k v t = match t with
then L (k, c ~old:v' v) then L (k, c ~old:v' v)
else join_ t k' (L (k, v)) k else join_ t k' (L (k, v)) k
| N (prefix, switch, l, r) -> | N (prefix, switch, l, r) ->
if is_prefix_ ~prefix k ~bit:switch if is_prefix_ ~prefix k ~bit:switch then (
then if Bit.is_0 k ~bit:switch if Bit.is_0 k ~bit:switch
then N(prefix, switch, insert_ c k v l, r) then N(prefix, switch, insert_ c k v l, r)
else N(prefix, switch, l, insert_ c k v r) else N(prefix, switch, l, insert_ c k v r)
else join_ (L(k,v)) k t prefix ) else (
join_ (L(k,v)) k t prefix
)
let add k v t = insert_ (fun ~old:_ v -> v) k v t let add k v t = insert_ (fun ~old:_ v -> v) k v t
@ -207,11 +230,13 @@ let rec remove k t = match t with
| E -> E | E -> E
| L (k', _) -> if k=k' then E else t | L (k', _) -> if k=k' then E else t
| N (prefix, switch, l, r) -> | N (prefix, switch, l, r) ->
if is_prefix_ ~prefix k ~bit:switch if is_prefix_ ~prefix k ~bit:switch then (
then if Bit.is_0 k ~bit:switch if Bit.is_0 k ~bit:switch
then mk_node_ prefix switch (remove k l) r then mk_node_ prefix switch (remove k l) r
else mk_node_ prefix switch l (remove k r) else mk_node_ prefix switch l (remove k r)
else t (* not present *) ) else (
t (* not present *)
)
(*$Q & ~count:20 (*$Q & ~count:20
Q.(list (pair int int)) (fun l -> \ Q.(list (pair int int)) (fun l -> \
@ -240,7 +265,9 @@ let update k f t =
let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2) let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2)
let rec equal ~eq a b = Pervasives.(==) a b || match a, b with let rec equal ~eq a b =
Pervasives.(==) a b ||
begin match a, b with
| E, E -> true | E, E -> true
| L (ka, va), L (kb, vb) -> ka = kb && eq va vb | L (ka, va), L (kb, vb) -> ka = kb && eq va vb
| N (pa, sa, la, ra), N (pb, sb, lb, rb) -> | N (pa, sa, la, ra), N (pb, sb, lb, rb) ->
@ -248,6 +275,7 @@ let rec equal ~eq a b = Pervasives.(==) a b || match a, b with
| E, _ | E, _
| N _, _ | N _, _
| L _, _ -> false | L _, _ -> false
end
(*$Q (*$Q
Q.(list (pair int bool)) ( fun l -> \ Q.(list (pair int bool)) ( fun l -> \
@ -289,26 +317,29 @@ let choose t =
try Some (choose_exn t) try Some (choose_exn t)
with Not_found -> None with Not_found -> None
(** {2 Whole-collection operations} *)
let rec union f t1 t2 = let rec union f t1 t2 =
if Pervasives.(==) t1 t2 then t1 match t1, t2 with
else match t1, t2 with
| E, o | o, E -> o | E, o | o, E -> o
| L (k, v), o | L (k, v), o
| o, L (k, v) -> | o, L (k, v) ->
(* insert k, v into o *) (* insert k, v into o *)
insert_ (fun ~old v -> f k old v) k v o insert_ (fun ~old v -> f k old v) k v o
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> | N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && Bit.equal m1 m2 if p1 = p2 && Bit.equal m1 m2 then (
then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2) mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 ) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
then if Bit.is_0 p2 ~bit:m1 if Bit.is_0 p2 ~bit:m1
then N (p1, m1, union f l1 t2, r1) then N (p1, m1, union f l1 t2, r1)
else N (p1, m1, l1, union f r1 t2) else N (p1, m1, l1, union f r1 t2)
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
then if Bit.is_0 p1 ~bit:m2 if Bit.is_0 p1 ~bit:m2
then N (p2, m2, union f t1 l2, r2) then N (p2, m2, union f t1 l2, r2)
else N (p2, m2, l2, union f t1 r2) else N (p2, m2, l2, union f t1 r2)
else join_ t1 p1 t2 p2 ) else (
join_ t1 p1 t2 p2
)
(*$Q & ~small:(fun (a,b) -> List.length a + List.length b) (*$Q & ~small:(fun (a,b) -> List.length a + List.length b)
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \ Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
@ -344,9 +375,30 @@ let rec union f t1 t2 =
equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l))) equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l)))
*) *)
(*$inject
let union_l l1 l2 =
let l2' = List.filter (fun (x,_) -> not @@ List.mem_assoc x l1) l2 in
_list_uniq (l1 @ l2')
let inter_l l1 l2 =
let l2' = List.filter (fun (x,_) -> List.mem_assoc x l1) l2 in
_list_uniq l2'
*)
(*$QR
Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit)))
(fun (l1,l2) ->
union_l l1 l2 = _list_uniq @@ to_list (union (fun _ _ _ ->())(of_list l1) (of_list l2)))
*)
(*$QR
Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit)))
(fun (l1,l2) ->
inter_l l1 l2 = _list_uniq @@ to_list (inter (fun _ _ _ ->()) (of_list l1) (of_list l2)))
*)
let rec inter f a b = let rec inter f a b =
if Pervasives.(==) a b then a match a, b with
else match a, b with
| E, _ | _, E -> E | E, _ | _, E -> E
| L (k, v), o | L (k, v), o
| o, L (k, v) -> | o, L (k, v) ->
@ -356,17 +408,17 @@ let rec inter f a b =
with Not_found -> E with Not_found -> E
end end
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> | N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && Bit.equal m1 m2 if p1 = p2 && Bit.equal m1 m2 then (
then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2) mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 ) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
then if Bit.is_0 p2 ~bit:m1 if Bit.is_0 p2 ~bit:m1
then inter f l1 b then inter f l1 b
else inter f r1 b else inter f r1 b
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
then if Bit.is_0 p1 ~bit:m2 if Bit.is_0 p1 ~bit:m2
then inter f a l2 then inter f a l2
else inter f a r2 else inter f a r2
else E ) else E
(*$R (*$R
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (pp CCString.pp)) assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (pp CCString.pp))
@ -465,11 +517,12 @@ let compare ~cmp a b =
| Some _, None -> 1 | Some _, None -> 1
| None, Some _ -> -1 | None, Some _ -> -1
| Some (ka, va), Some (kb, vb) -> | Some (ka, va), Some (kb, vb) ->
if ka=kb if ka=kb then (
then
let c = cmp va vb in let c = cmp va vb in
if c=0 then cmp_gen cmp a b else c if c=0 then cmp_gen cmp a b else c
else compare ka kb ) else (
compare ka kb
)
in in
cmp_gen cmp (to_gen a) (to_gen b) cmp_gen cmp (to_gen a) (to_gen b)

View file

@ -118,6 +118,7 @@ module Bit : sig
type t = private int type t = private int
val min_int : t val min_int : t
val highest : int -> t val highest : int -> t
val equal_int : int -> t -> bool
end end
val check_invariants : _ t -> bool val check_invariants : _ t -> bool