test(intmap): add some tests for CCIntMap, also improve style

This commit is contained in:
Simon Cruanes 2018-06-04 23:32:08 -05:00
parent 0c48cff2a1
commit ca0521512f
2 changed files with 109 additions and 55 deletions

View file

@ -17,12 +17,13 @@ module Bit : sig
val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *)
val lt : t -> t -> bool
val gt : t -> t -> bool
val equal_int : int -> t -> bool
end = struct
type t = int
let min_int = min_int
let equal = (=)
let equal : t -> t -> bool = Pervasives.(=)
let rec highest_bit_naive x m =
if x=m then m
@ -33,26 +34,45 @@ end = struct
let highest x =
if x<0 then min_int
else if Sys.word_size > 40 && x > mask_40_
then (* remove least significant 40 bits *)
else if Sys.word_size > 40 && x > mask_40_ then (
(* remove least significant 40 bits *)
let x' = x land (lnot (mask_40_ -1)) in
highest_bit_naive x' mask_40_
else if x> mask_20_
then (* small shortcut: remove least significant 20 bits *)
) else if x> mask_20_ then (
(* small shortcut: remove least significant 20 bits *)
let x' = x land (lnot (mask_20_ -1)) in
highest_bit_naive x' mask_20_
else highest_bit_naive x 1
) else (
highest_bit_naive x 1
)
let is_0 ~bit x = x land bit = 0
let is_1 ~bit x = x land bit = bit
let[@inline] is_0 ~bit x = x land bit = 0
let[@inline] is_1 ~bit x = x land bit = bit
let mask ~mask x = (x lor (mask -1)) land (lnot mask)
let[@inline] mask ~mask x = (x lor (mask -1)) land (lnot mask)
(* low endian: let mask_ x ~mask = x land (mask - 1) *)
let gt a b = (b != min_int) && (a = min_int || a > b)
let lt a b = gt b a
let[@inline] gt a b = (b != min_int) && (a = min_int || a > b)
let[@inline] lt a b = gt b a
let equal_int = Pervasives.(=)
end
(*$inject
let highest2 x : int =
let rec aux i =
if i=0 then i
else if 1 = (x lsr i) then 1 lsl i else aux (i-1)
in
if x<0 then min_int else aux (Sys.word_size-2)
*)
(*$QR & ~count:1_000
Q.int (fun x ->
if Bit.equal_int (highest2 x) (Bit.highest x) then true
else QCheck.Test.fail_reportf "x=%d, highest=%d, highest2=%d@." x
(Bit.highest x :> int) (highest2 x))
*)
type 'a t =
| E (* empty *)
| L of int * 'a (* leaf *)
@ -60,10 +80,11 @@ type 'a t =
let empty = E
let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
let[@inline] is_prefix_ ~prefix y ~bit =
prefix = Bit.mask y ~mask:bit
(*$inject
let _list_uniq = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b))
let _list_uniq l = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b)) l
*)
(*$Q
@ -92,7 +113,7 @@ let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
*)
(* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *)
let branching_bit_ a b = Bit.highest (a lxor b)
let[@inline] branching_bit_ a b = Bit.highest (a lxor b)
(* TODO use hint in branching_bit_ *)
@ -103,15 +124,14 @@ let check_invariants t =
| L (k, _) ->
List.for_all
(fun (prefix, switch, side) ->
is_prefix_ ~prefix k ~bit:switch
&&
match side with
is_prefix_ ~prefix k ~bit:switch &&
begin match side with
| `Left -> Bit.is_0 k ~bit:switch
| `Right -> Bit.is_1 k ~bit:switch
) path
end)
path
| N (prefix, switch, l, r) ->
check_keys ((prefix, switch, `Left) :: path) l
&&
check_keys ((prefix, switch, `Left) :: path) l &&
check_keys ((prefix, switch, `Right) :: path) r
in
check_keys [] t
@ -126,11 +146,13 @@ let rec find_exn k t = match t with
| L (k', v) when k = k' -> v
| L _ -> raise Not_found
| N (prefix, m, l, r) ->
if is_prefix_ ~prefix k ~bit:m
then if Bit.is_0 k ~bit:m
if is_prefix_ ~prefix k ~bit:m then (
if Bit.is_0 k ~bit:m
then find_exn k l
else find_exn k r
else raise Not_found
) else (
raise Not_found
)
(* XXX could test with lt_unsigned_? *)
@ -161,7 +183,7 @@ let mem k t =
List.for_all (fun (k,_) -> mem k m) l)
*)
let mk_node_ prefix switch l r = match l, r with
let[@inline] mk_node_ prefix switch l r = match l, r with
| E, o | o, E -> o
| _ -> N (prefix, switch, l, r)
@ -170,8 +192,7 @@ let mk_node_ prefix switch l r = match l, r with
let join_ t1 p1 t2 p2 =
let switch = branching_bit_ p1 p2 in
let prefix = Bit.mask p1 ~mask:switch in
if Bit.is_0 p1 ~bit:switch
then (
if Bit.is_0 p1 ~bit:switch then (
assert (Bit.is_1 p2 ~bit:switch);
mk_node_ prefix switch t1 t2
) else (
@ -179,7 +200,7 @@ let join_ t1 p1 t2 p2 =
mk_node_ prefix switch t2 t1
)
let singleton k v = L (k, v)
let[@inline] singleton k v = L (k, v)
(* c: conflict function *)
let rec insert_ c k v t = match t with
@ -189,11 +210,13 @@ let rec insert_ c k v t = match t with
then L (k, c ~old:v' v)
else join_ t k' (L (k, v)) k
| N (prefix, switch, l, r) ->
if is_prefix_ ~prefix k ~bit:switch
then if Bit.is_0 k ~bit:switch
if is_prefix_ ~prefix k ~bit:switch then (
if Bit.is_0 k ~bit:switch
then N(prefix, switch, insert_ c k v l, r)
else N(prefix, switch, l, insert_ c k v r)
else join_ (L(k,v)) k t prefix
) else (
join_ (L(k,v)) k t prefix
)
let add k v t = insert_ (fun ~old:_ v -> v) k v t
@ -207,11 +230,13 @@ let rec remove k t = match t with
| E -> E
| L (k', _) -> if k=k' then E else t
| N (prefix, switch, l, r) ->
if is_prefix_ ~prefix k ~bit:switch
then if Bit.is_0 k ~bit:switch
if is_prefix_ ~prefix k ~bit:switch then (
if Bit.is_0 k ~bit:switch
then mk_node_ prefix switch (remove k l) r
else mk_node_ prefix switch l (remove k r)
else t (* not present *)
) else (
t (* not present *)
)
(*$Q & ~count:20
Q.(list (pair int int)) (fun l -> \
@ -240,7 +265,9 @@ let update k f t =
let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2)
let rec equal ~eq a b = Pervasives.(==) a b || match a, b with
let rec equal ~eq a b =
Pervasives.(==) a b ||
begin match a, b with
| E, E -> true
| L (ka, va), L (kb, vb) -> ka = kb && eq va vb
| N (pa, sa, la, ra), N (pb, sb, lb, rb) ->
@ -248,6 +275,7 @@ let rec equal ~eq a b = Pervasives.(==) a b || match a, b with
| E, _
| N _, _
| L _, _ -> false
end
(*$Q
Q.(list (pair int bool)) ( fun l -> \
@ -289,26 +317,29 @@ let choose t =
try Some (choose_exn t)
with Not_found -> None
(** {2 Whole-collection operations} *)
let rec union f t1 t2 =
if Pervasives.(==) t1 t2 then t1
else match t1, t2 with
match t1, t2 with
| E, o | o, E -> o
| L (k, v), o
| o, L (k, v) ->
(* insert k, v into o *)
insert_ (fun ~old v -> f k old v) k v o
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && Bit.equal m1 m2
then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
then if Bit.is_0 p2 ~bit:m1
if p1 = p2 && Bit.equal m1 m2 then (
mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
if Bit.is_0 p2 ~bit:m1
then N (p1, m1, union f l1 t2, r1)
else N (p1, m1, l1, union f r1 t2)
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
then if Bit.is_0 p1 ~bit:m2
) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
if Bit.is_0 p1 ~bit:m2
then N (p2, m2, union f t1 l2, r2)
else N (p2, m2, l2, union f t1 r2)
else join_ t1 p1 t2 p2
) else (
join_ t1 p1 t2 p2
)
(*$Q & ~small:(fun (a,b) -> List.length a + List.length b)
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
@ -344,9 +375,30 @@ let rec union f t1 t2 =
equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l)))
*)
(*$inject
let union_l l1 l2 =
let l2' = List.filter (fun (x,_) -> not @@ List.mem_assoc x l1) l2 in
_list_uniq (l1 @ l2')
let inter_l l1 l2 =
let l2' = List.filter (fun (x,_) -> List.mem_assoc x l1) l2 in
_list_uniq l2'
*)
(*$QR
Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit)))
(fun (l1,l2) ->
union_l l1 l2 = _list_uniq @@ to_list (union (fun _ _ _ ->())(of_list l1) (of_list l2)))
*)
(*$QR
Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit)))
(fun (l1,l2) ->
inter_l l1 l2 = _list_uniq @@ to_list (inter (fun _ _ _ ->()) (of_list l1) (of_list l2)))
*)
let rec inter f a b =
if Pervasives.(==) a b then a
else match a, b with
match a, b with
| E, _ | _, E -> E
| L (k, v), o
| o, L (k, v) ->
@ -356,17 +408,17 @@ let rec inter f a b =
with Not_found -> E
end
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && Bit.equal m1 m2
then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
then if Bit.is_0 p2 ~bit:m1
if p1 = p2 && Bit.equal m1 m2 then (
mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
if Bit.is_0 p2 ~bit:m1
then inter f l1 b
else inter f r1 b
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
then if Bit.is_0 p1 ~bit:m2
) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
if Bit.is_0 p1 ~bit:m2
then inter f a l2
else inter f a r2
else E
) else E
(*$R
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (pp CCString.pp))
@ -465,11 +517,12 @@ let compare ~cmp a b =
| Some _, None -> 1
| None, Some _ -> -1
| Some (ka, va), Some (kb, vb) ->
if ka=kb
then
if ka=kb then (
let c = cmp va vb in
if c=0 then cmp_gen cmp a b else c
else compare ka kb
) else (
compare ka kb
)
in
cmp_gen cmp (to_gen a) (to_gen b)

View file

@ -118,6 +118,7 @@ module Bit : sig
type t = private int
val min_int : t
val highest : int -> t
val equal_int : int -> t -> bool
end
val check_invariants : _ t -> bool