mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-06 03:05:28 -05:00
test(intmap): add some tests for CCIntMap, also improve style
This commit is contained in:
parent
0c48cff2a1
commit
ca0521512f
2 changed files with 109 additions and 55 deletions
|
|
@ -17,12 +17,13 @@ module Bit : sig
|
|||
val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *)
|
||||
val lt : t -> t -> bool
|
||||
val gt : t -> t -> bool
|
||||
val equal_int : int -> t -> bool
|
||||
end = struct
|
||||
type t = int
|
||||
|
||||
let min_int = min_int
|
||||
|
||||
let equal = (=)
|
||||
let equal : t -> t -> bool = Pervasives.(=)
|
||||
|
||||
let rec highest_bit_naive x m =
|
||||
if x=m then m
|
||||
|
|
@ -33,26 +34,45 @@ end = struct
|
|||
|
||||
let highest x =
|
||||
if x<0 then min_int
|
||||
else if Sys.word_size > 40 && x > mask_40_
|
||||
then (* remove least significant 40 bits *)
|
||||
else if Sys.word_size > 40 && x > mask_40_ then (
|
||||
(* remove least significant 40 bits *)
|
||||
let x' = x land (lnot (mask_40_ -1)) in
|
||||
highest_bit_naive x' mask_40_
|
||||
else if x> mask_20_
|
||||
then (* small shortcut: remove least significant 20 bits *)
|
||||
) else if x> mask_20_ then (
|
||||
(* small shortcut: remove least significant 20 bits *)
|
||||
let x' = x land (lnot (mask_20_ -1)) in
|
||||
highest_bit_naive x' mask_20_
|
||||
else highest_bit_naive x 1
|
||||
) else (
|
||||
highest_bit_naive x 1
|
||||
)
|
||||
|
||||
let is_0 ~bit x = x land bit = 0
|
||||
let is_1 ~bit x = x land bit = bit
|
||||
let[@inline] is_0 ~bit x = x land bit = 0
|
||||
let[@inline] is_1 ~bit x = x land bit = bit
|
||||
|
||||
let mask ~mask x = (x lor (mask -1)) land (lnot mask)
|
||||
let[@inline] mask ~mask x = (x lor (mask -1)) land (lnot mask)
|
||||
(* low endian: let mask_ x ~mask = x land (mask - 1) *)
|
||||
|
||||
let gt a b = (b != min_int) && (a = min_int || a > b)
|
||||
let lt a b = gt b a
|
||||
let[@inline] gt a b = (b != min_int) && (a = min_int || a > b)
|
||||
let[@inline] lt a b = gt b a
|
||||
let equal_int = Pervasives.(=)
|
||||
end
|
||||
|
||||
(*$inject
|
||||
let highest2 x : int =
|
||||
let rec aux i =
|
||||
if i=0 then i
|
||||
else if 1 = (x lsr i) then 1 lsl i else aux (i-1)
|
||||
in
|
||||
if x<0 then min_int else aux (Sys.word_size-2)
|
||||
*)
|
||||
|
||||
(*$QR & ~count:1_000
|
||||
Q.int (fun x ->
|
||||
if Bit.equal_int (highest2 x) (Bit.highest x) then true
|
||||
else QCheck.Test.fail_reportf "x=%d, highest=%d, highest2=%d@." x
|
||||
(Bit.highest x :> int) (highest2 x))
|
||||
*)
|
||||
|
||||
type 'a t =
|
||||
| E (* empty *)
|
||||
| L of int * 'a (* leaf *)
|
||||
|
|
@ -60,10 +80,11 @@ type 'a t =
|
|||
|
||||
let empty = E
|
||||
|
||||
let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
|
||||
let[@inline] is_prefix_ ~prefix y ~bit =
|
||||
prefix = Bit.mask y ~mask:bit
|
||||
|
||||
(*$inject
|
||||
let _list_uniq = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b))
|
||||
let _list_uniq l = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b)) l
|
||||
*)
|
||||
|
||||
(*$Q
|
||||
|
|
@ -92,7 +113,7 @@ let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
|
|||
*)
|
||||
|
||||
(* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *)
|
||||
let branching_bit_ a b = Bit.highest (a lxor b)
|
||||
let[@inline] branching_bit_ a b = Bit.highest (a lxor b)
|
||||
|
||||
(* TODO use hint in branching_bit_ *)
|
||||
|
||||
|
|
@ -103,15 +124,14 @@ let check_invariants t =
|
|||
| L (k, _) ->
|
||||
List.for_all
|
||||
(fun (prefix, switch, side) ->
|
||||
is_prefix_ ~prefix k ~bit:switch
|
||||
&&
|
||||
match side with
|
||||
is_prefix_ ~prefix k ~bit:switch &&
|
||||
begin match side with
|
||||
| `Left -> Bit.is_0 k ~bit:switch
|
||||
| `Right -> Bit.is_1 k ~bit:switch
|
||||
) path
|
||||
end)
|
||||
path
|
||||
| N (prefix, switch, l, r) ->
|
||||
check_keys ((prefix, switch, `Left) :: path) l
|
||||
&&
|
||||
check_keys ((prefix, switch, `Left) :: path) l &&
|
||||
check_keys ((prefix, switch, `Right) :: path) r
|
||||
in
|
||||
check_keys [] t
|
||||
|
|
@ -126,11 +146,13 @@ let rec find_exn k t = match t with
|
|||
| L (k', v) when k = k' -> v
|
||||
| L _ -> raise Not_found
|
||||
| N (prefix, m, l, r) ->
|
||||
if is_prefix_ ~prefix k ~bit:m
|
||||
then if Bit.is_0 k ~bit:m
|
||||
if is_prefix_ ~prefix k ~bit:m then (
|
||||
if Bit.is_0 k ~bit:m
|
||||
then find_exn k l
|
||||
else find_exn k r
|
||||
else raise Not_found
|
||||
) else (
|
||||
raise Not_found
|
||||
)
|
||||
|
||||
(* XXX could test with lt_unsigned_? *)
|
||||
|
||||
|
|
@ -161,7 +183,7 @@ let mem k t =
|
|||
List.for_all (fun (k,_) -> mem k m) l)
|
||||
*)
|
||||
|
||||
let mk_node_ prefix switch l r = match l, r with
|
||||
let[@inline] mk_node_ prefix switch l r = match l, r with
|
||||
| E, o | o, E -> o
|
||||
| _ -> N (prefix, switch, l, r)
|
||||
|
||||
|
|
@ -170,8 +192,7 @@ let mk_node_ prefix switch l r = match l, r with
|
|||
let join_ t1 p1 t2 p2 =
|
||||
let switch = branching_bit_ p1 p2 in
|
||||
let prefix = Bit.mask p1 ~mask:switch in
|
||||
if Bit.is_0 p1 ~bit:switch
|
||||
then (
|
||||
if Bit.is_0 p1 ~bit:switch then (
|
||||
assert (Bit.is_1 p2 ~bit:switch);
|
||||
mk_node_ prefix switch t1 t2
|
||||
) else (
|
||||
|
|
@ -179,7 +200,7 @@ let join_ t1 p1 t2 p2 =
|
|||
mk_node_ prefix switch t2 t1
|
||||
)
|
||||
|
||||
let singleton k v = L (k, v)
|
||||
let[@inline] singleton k v = L (k, v)
|
||||
|
||||
(* c: conflict function *)
|
||||
let rec insert_ c k v t = match t with
|
||||
|
|
@ -189,11 +210,13 @@ let rec insert_ c k v t = match t with
|
|||
then L (k, c ~old:v' v)
|
||||
else join_ t k' (L (k, v)) k
|
||||
| N (prefix, switch, l, r) ->
|
||||
if is_prefix_ ~prefix k ~bit:switch
|
||||
then if Bit.is_0 k ~bit:switch
|
||||
if is_prefix_ ~prefix k ~bit:switch then (
|
||||
if Bit.is_0 k ~bit:switch
|
||||
then N(prefix, switch, insert_ c k v l, r)
|
||||
else N(prefix, switch, l, insert_ c k v r)
|
||||
else join_ (L(k,v)) k t prefix
|
||||
) else (
|
||||
join_ (L(k,v)) k t prefix
|
||||
)
|
||||
|
||||
let add k v t = insert_ (fun ~old:_ v -> v) k v t
|
||||
|
||||
|
|
@ -207,11 +230,13 @@ let rec remove k t = match t with
|
|||
| E -> E
|
||||
| L (k', _) -> if k=k' then E else t
|
||||
| N (prefix, switch, l, r) ->
|
||||
if is_prefix_ ~prefix k ~bit:switch
|
||||
then if Bit.is_0 k ~bit:switch
|
||||
if is_prefix_ ~prefix k ~bit:switch then (
|
||||
if Bit.is_0 k ~bit:switch
|
||||
then mk_node_ prefix switch (remove k l) r
|
||||
else mk_node_ prefix switch l (remove k r)
|
||||
else t (* not present *)
|
||||
) else (
|
||||
t (* not present *)
|
||||
)
|
||||
|
||||
(*$Q & ~count:20
|
||||
Q.(list (pair int int)) (fun l -> \
|
||||
|
|
@ -240,7 +265,9 @@ let update k f t =
|
|||
|
||||
let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2)
|
||||
|
||||
let rec equal ~eq a b = Pervasives.(==) a b || match a, b with
|
||||
let rec equal ~eq a b =
|
||||
Pervasives.(==) a b ||
|
||||
begin match a, b with
|
||||
| E, E -> true
|
||||
| L (ka, va), L (kb, vb) -> ka = kb && eq va vb
|
||||
| N (pa, sa, la, ra), N (pb, sb, lb, rb) ->
|
||||
|
|
@ -248,6 +275,7 @@ let rec equal ~eq a b = Pervasives.(==) a b || match a, b with
|
|||
| E, _
|
||||
| N _, _
|
||||
| L _, _ -> false
|
||||
end
|
||||
|
||||
(*$Q
|
||||
Q.(list (pair int bool)) ( fun l -> \
|
||||
|
|
@ -289,26 +317,29 @@ let choose t =
|
|||
try Some (choose_exn t)
|
||||
with Not_found -> None
|
||||
|
||||
(** {2 Whole-collection operations} *)
|
||||
|
||||
let rec union f t1 t2 =
|
||||
if Pervasives.(==) t1 t2 then t1
|
||||
else match t1, t2 with
|
||||
match t1, t2 with
|
||||
| E, o | o, E -> o
|
||||
| L (k, v), o
|
||||
| o, L (k, v) ->
|
||||
(* insert k, v into o *)
|
||||
insert_ (fun ~old v -> f k old v) k v o
|
||||
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
|
||||
if p1 = p2 && Bit.equal m1 m2
|
||||
then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
|
||||
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
|
||||
then if Bit.is_0 p2 ~bit:m1
|
||||
if p1 = p2 && Bit.equal m1 m2 then (
|
||||
mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
|
||||
) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
|
||||
if Bit.is_0 p2 ~bit:m1
|
||||
then N (p1, m1, union f l1 t2, r1)
|
||||
else N (p1, m1, l1, union f r1 t2)
|
||||
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
|
||||
then if Bit.is_0 p1 ~bit:m2
|
||||
) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
|
||||
if Bit.is_0 p1 ~bit:m2
|
||||
then N (p2, m2, union f t1 l2, r2)
|
||||
else N (p2, m2, l2, union f t1 r2)
|
||||
else join_ t1 p1 t2 p2
|
||||
) else (
|
||||
join_ t1 p1 t2 p2
|
||||
)
|
||||
|
||||
(*$Q & ~small:(fun (a,b) -> List.length a + List.length b)
|
||||
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
|
||||
|
|
@ -344,9 +375,30 @@ let rec union f t1 t2 =
|
|||
equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l)))
|
||||
*)
|
||||
|
||||
(*$inject
|
||||
let union_l l1 l2 =
|
||||
let l2' = List.filter (fun (x,_) -> not @@ List.mem_assoc x l1) l2 in
|
||||
_list_uniq (l1 @ l2')
|
||||
|
||||
let inter_l l1 l2 =
|
||||
let l2' = List.filter (fun (x,_) -> List.mem_assoc x l1) l2 in
|
||||
_list_uniq l2'
|
||||
*)
|
||||
|
||||
(*$QR
|
||||
Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit)))
|
||||
(fun (l1,l2) ->
|
||||
union_l l1 l2 = _list_uniq @@ to_list (union (fun _ _ _ ->())(of_list l1) (of_list l2)))
|
||||
*)
|
||||
|
||||
(*$QR
|
||||
Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit)))
|
||||
(fun (l1,l2) ->
|
||||
inter_l l1 l2 = _list_uniq @@ to_list (inter (fun _ _ _ ->()) (of_list l1) (of_list l2)))
|
||||
*)
|
||||
|
||||
let rec inter f a b =
|
||||
if Pervasives.(==) a b then a
|
||||
else match a, b with
|
||||
match a, b with
|
||||
| E, _ | _, E -> E
|
||||
| L (k, v), o
|
||||
| o, L (k, v) ->
|
||||
|
|
@ -356,17 +408,17 @@ let rec inter f a b =
|
|||
with Not_found -> E
|
||||
end
|
||||
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
|
||||
if p1 = p2 && Bit.equal m1 m2
|
||||
then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
|
||||
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
|
||||
then if Bit.is_0 p2 ~bit:m1
|
||||
if p1 = p2 && Bit.equal m1 m2 then (
|
||||
mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
|
||||
) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
|
||||
if Bit.is_0 p2 ~bit:m1
|
||||
then inter f l1 b
|
||||
else inter f r1 b
|
||||
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
|
||||
then if Bit.is_0 p1 ~bit:m2
|
||||
) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
|
||||
if Bit.is_0 p1 ~bit:m2
|
||||
then inter f a l2
|
||||
else inter f a r2
|
||||
else E
|
||||
) else E
|
||||
|
||||
(*$R
|
||||
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (pp CCString.pp))
|
||||
|
|
@ -465,11 +517,12 @@ let compare ~cmp a b =
|
|||
| Some _, None -> 1
|
||||
| None, Some _ -> -1
|
||||
| Some (ka, va), Some (kb, vb) ->
|
||||
if ka=kb
|
||||
then
|
||||
if ka=kb then (
|
||||
let c = cmp va vb in
|
||||
if c=0 then cmp_gen cmp a b else c
|
||||
else compare ka kb
|
||||
) else (
|
||||
compare ka kb
|
||||
)
|
||||
in
|
||||
cmp_gen cmp (to_gen a) (to_gen b)
|
||||
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ module Bit : sig
|
|||
type t = private int
|
||||
val min_int : t
|
||||
val highest : int -> t
|
||||
val equal_int : int -> t -> bool
|
||||
end
|
||||
val check_invariants : _ t -> bool
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue