add tests to CCIntMap, add type safety, and fix various bugs in {union,inter}

This commit is contained in:
Simon Cruanes 2015-09-02 11:59:33 +02:00
parent d7a58b2ef0
commit 132414ba9d
2 changed files with 190 additions and 54 deletions

View file

@ -29,54 +29,113 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
(* "Fast Mergeable Integer Maps", Okasaki & Gill.
We use big-endian trees. *)
(** Masks with exactly one bit active.

    The type is private: a [Bit.t] can be read as an [int] with a coercion,
    but outside of this module it can only be obtained from {!Bit.highest}
    or {!Bit.min_int}. *)
module Bit : sig
  type t = private int

  val highest : int -> t
  (** [highest x] is the most significant bit set in [x].
      It is {!min_int} (the sign bit) when [x < 0], and [0] when [x = 0]
      (the only result that has no bit active). *)

  val min_int : t
  (** The sign bit. It is the greatest mask with respect to {!gt}. *)

  val is_0 : bit:t -> int -> bool
  (** [is_0 ~bit x] is [true] iff the bit selected by [bit] is unset in [x]. *)

  val is_1 : bit:t -> int -> bool
  (** [is_1 ~bit x] is [true] iff the bit selected by [bit] is set in [x]. *)

  val mask : mask:t -> int -> int
  (** [mask ~mask x] zeroes the bit [mask] in [x] and puts all lower bits
      to 1; bits above [mask] are kept unchanged. *)

  val lt : t -> t -> bool
  (** [lt a b] is [gt b a]. *)

  val gt : t -> t -> bool
  (** [gt a b] orders masks by bit position, the sign bit {!min_int} being
      strictly greater than every other mask. *)
end = struct
  type t = int

  let min_int = min_int

  (* [highest_bit_naive x m] walks [m] upwards one bit at a time, clearing
     bit [m] from [x] on the way, until [x] is reduced to the single bit [m].
     Callers pass an [x] whose bits strictly below [m] are already zero.
     When [x = 0] the loop only stops once [2*m] has overflowed to [0],
     and the result is then [0]. *)
  let rec highest_bit_naive x m =
    if x = m then m
    else highest_bit_naive (x land (lnot m)) (2 * m)

  let mask_20_ = 1 lsl 20

  (* only meaningful on 64-bit platforms; every use is guarded by a test
     on [Sys.word_size] *)
  let mask_40_ = 1 lsl 40

  let highest x =
    if x < 0 then min_int
    else if Sys.word_size > 40 && x > mask_40_ then (
      (* remove least significant 40 bits, then start the scan at bit 40 *)
      let x' = x land (lnot (mask_40_ - 1)) in
      highest_bit_naive x' mask_40_
    ) else if x > mask_20_ then (
      (* small shortcut: remove least significant 20 bits *)
      let x' = x land (lnot (mask_20_ - 1)) in
      highest_bit_naive x' mask_20_
    ) else highest_bit_naive x 1

  let is_0 ~bit x = x land bit = 0

  let is_1 ~bit x = x land bit = bit

  (* big-endian mask; the little-endian variant would be
     [x land (mask - 1)] *)
  let mask ~mask x = (x lor (mask - 1)) land (lnot mask)

  (* Signed comparison would put [min_int] below every other mask, so the
     sign bit is special-cased. Structural [<>] is used rather than the
     physical [!=]: both agree on immediate integers, but [<>] states the
     intent and stays correct if the representation ever changes. *)
  let gt a b = b <> min_int && (a = min_int || a > b)

  let lt a b = gt b a
end
type 'a t =
| E (* empty *)
| L of int * 'a (* leaf *)
| N of int (* common prefix *) * int (* bit switch *) * 'a t * 'a t
| N of int (* common prefix *) * Bit.t (* bit switch *) * 'a t * 'a t
let empty = E
let bit_is_0_ x ~bit = x land bit = 0
let mask_ x ~mask = (x lor (mask -1)) land (lnot mask)
(* low endian: let mask_ x ~mask = x land (mask - 1) *)
let is_prefix_ ~prefix y ~bit = prefix = mask_ y ~mask:bit
(* loop down until x=lowest_bit_ x *)
let rec highest_bit_naive x m =
if m = 0 then 0
else if x land m = 0 then highest_bit_naive x (m lsr 1)
else m
let highest_bit =
(* the highest representable 2^n *)
let max_log = 1 lsl (Sys.word_size - 2) in
fun x ->
if x > 1 lsl 20
then (* small shortcut: remove least significant 20 bits *)
let x' = x land (lnot ((1 lsl 20) -1)) in
highest_bit_naive x' max_log
else highest_bit_naive x max_log
let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
(*$Q
Q.int (fun i -> \
let b = highest_bit i in \
i < 0 || (b <= i && (i-b) < b))
let b = Bit.highest i in \
((b:>int) land i = (b:>int)) && (i < 0 || ((b:>int) <= i && (i-(b:>int)) < (b:>int))))
Q.int (fun i -> (Bit.highest i = Bit.min_int) = (i < 0))
Q.int (fun i -> ((Bit.highest i:>int) < 0) = (Bit.highest i = Bit.min_int))
Q.int (fun i -> let j = (Bit.highest i :> int) in j land (j-1) = 0)
*)
(*$T
(Bit.highest min_int :> int) = min_int
(Bit.highest 2 :> int) = 2
(Bit.highest 17 :> int) = 16
(Bit.highest 300 :> int) = 256
*)
(* helper:
let b_of_i i =
let rec f acc i =
if i=0 then acc else let q, r = i/2, i mod 2
if i=0 then acc else let q, r = i/2, abs (i mod 2)
in
f (r::acc) q in f [] i;;
*)
(* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *)
let branching_bit_ a b =
highest_bit (a lxor b)
let branching_bit_ a b = Bit.highest (a lxor b)
(* TODO use hint in branching_bit_ *)
(** [check_invariants t] is [true] iff every key stored in a leaf of [t]
    is consistent with all the inner nodes found on the way from the root
    to that leaf: the key matches the node prefix, and its switch bit is 0
    below a left branch and 1 below a right branch. *)
let check_invariants t =
  (* does key [k] agree with one ancestor node, given the side taken? *)
  let agrees_with k (prefix, switch, side) =
    is_prefix_ ~prefix k ~bit:switch
    && (match side with
        | `Left -> Bit.is_0 k ~bit:switch
        | `Right -> Bit.is_1 k ~bit:switch)
  in
  (* [ancestors] lists the nodes traversed so far, innermost first *)
  let rec walk ancestors = function
    | E -> true
    | L (k, _) -> List.for_all (agrees_with k) ancestors
    | N (prefix, switch, l, r) ->
      walk ((prefix, switch, `Left) :: ancestors) l
      && walk ((prefix, switch, `Right) :: ancestors) r
  in
  walk [] t
(*$Q
Q.(list (pair int bool)) (fun l -> \
check_invariants (of_list l))
*)
let rec find_exn k t = match t with
| E -> raise Not_found
@ -84,11 +143,13 @@ let rec find_exn k t = match t with
| L _ -> raise Not_found
| N (prefix, m, l, r) ->
if is_prefix_ ~prefix k ~bit:m
then if bit_is_0_ k ~bit:m
then if Bit.is_0 k ~bit:m
then find_exn k l
else find_exn k r
else raise Not_found
(* TODO test with lt_unsigned_ *)
(* FIXME: valid if k < 0?
if k <= prefix (* search tree *)
then find_exn k l
@ -99,10 +160,23 @@ let find k t =
try Some (find_exn k t)
with Not_found -> None
(*$Q
Q.(list (pair int int)) (fun l -> \
let l = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) l in \
let m = of_list l in \
List.for_all (fun (k,v) -> find k m = Some v) l)
*)
(** [mem k t] is [true] iff [find_exn k t] succeeds, i.e. [k] is bound
    in [t]. *)
let mem k t =
  try
    (* only the absence of [Not_found] matters, the value is dropped *)
    let _found = find_exn k t in
    true
  with Not_found -> false
(*$Q
Q.(list (pair int int)) (fun l -> \
let m = of_list l in \
List.for_all (fun (k,_) -> mem k m) l)
*)
(* Smart constructor for inner nodes: when one side is empty the other
   side is returned as is, so that no [N] node ever carries an [E] child
   built through this function. *)
let mk_node_ prefix switch l r =
  match l, r with
  | E, other -> other
  | other, E -> other
  | (L _ | N _), (L _ | N _) -> N (prefix, switch, l, r)
@ -111,10 +185,15 @@ let mk_node_ prefix switch l r = match l, r with
(p1 and p2 do not overlap) *)
let join_ t1 p1 t2 p2 =
let switch = branching_bit_ p1 p2 in
let prefix = mask_ p1 ~mask:switch in
if bit_is_0_ p1 ~bit:switch
then mk_node_ prefix switch t1 t2
else (assert (bit_is_0_ p2 ~bit:switch); mk_node_ prefix switch t2 t1)
let prefix = Bit.mask p1 ~mask:switch in
if Bit.is_0 p1 ~bit:switch
then (
assert (Bit.is_1 p2 ~bit:switch);
mk_node_ prefix switch t1 t2
) else (
assert (Bit.is_0 p2 ~bit:switch);
mk_node_ prefix switch t2 t1
)
(** [singleton k v] is the map made of a single leaf binding [k] to [v]. *)
let singleton k v = L (k, v)
@ -127,7 +206,7 @@ let rec insert_ c k v t = match t with
else join_ t k' (L (k, v)) k
| N (prefix, switch, l, r) ->
if is_prefix_ ~prefix k ~bit:switch
then if bit_is_0_ k ~bit:switch
then if Bit.is_0 k ~bit:switch
then N(prefix, switch, insert_ c k v l, r)
else N(prefix, switch, l, insert_ c k v r)
else join_ (L(k,v)) k t prefix
@ -145,11 +224,17 @@ let rec remove k t = match t with
| L (k', _) -> if k=k' then E else t
| N (prefix, switch, l, r) ->
if is_prefix_ ~prefix k ~bit:switch
then if bit_is_0_ k ~bit:switch
then if Bit.is_0 k ~bit:switch
then mk_node_ prefix switch (remove k l) r
else mk_node_ prefix switch l (remove k r)
else t (* not present *)
(*$Q & ~count:20
Q.(list (pair int int)) (fun l -> \
let l = CCList.Set.uniq l in let m = of_list l in \
List.for_all (fun (k,_) -> mem k m && not (mem k (remove k m))) l)
*)
let update k f t =
try
let v = find_exn k t in
@ -162,6 +247,8 @@ let update k f t =
| None -> t
| Some v -> add k v t
(* TODO test *)
(** [doubleton k1 v1 k2 v2] adds the binding of [k1] to [v1] to the
    singleton map binding [k2] to [v2].
    NOTE(review): when [k1 = k2] the result depends on how [add] treats an
    existing key -- confirm against [add]. *)
let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2)
let rec equal ~eq a b = match a, b with
@ -201,7 +288,8 @@ let choose t =
try Some (choose_exn t)
with Not_found -> None
let rec union f a b = match a, b with
(* TODO fix *)
let rec union f t1 t2 = match t1, t2 with
| E, o | o, E -> o
| L (k, v), o
| o, L (k, v) ->
@ -210,16 +298,43 @@ let rec union f a b = match a, b with
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && m1 = m2
then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
else if m1 < m2 && is_prefix_ ~prefix:p2 p1 ~bit:m1
then if bit_is_0_ p2 ~bit:m1
then N (p1, m1, union f l1 b, r1)
else N (p1, m1, l1, union f r1 b)
else if m1 > m2 && is_prefix_ ~prefix:p1 p2 ~bit:m2
then if bit_is_0_ p1 ~bit:m2
then N (p2, m2, union f l2 a, r2)
else N (p2, m2, l2, union f r2 a)
else join_ a p1 b p2
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
then if Bit.is_0 p2 ~bit:m1
then N (p1, m1, union f l1 t2, r1)
else N (p1, m1, l1, union f r1 t2)
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
then if Bit.is_0 p1 ~bit:m2
then N (p2, m2, union f t1 l2, r2)
else N (p2, m2, l2, union f t1 r2)
else join_ t1 p1 t2 p2
(*$Q & ~small:(fun (a,b) -> List.length a + List.length b)
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
check_invariants (union (fun _ _ x -> x) (of_list l1) (of_list l2)))
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
check_invariants (inter (fun _ _ x -> x) (of_list l1) (of_list l2)))
*)
(*$R
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print))
(of_list [1, "1"; 2, "2"; 3, "3"; 4, "4"])
(union (fun _ a b -> a)
(of_list [1, "1"; 3, "3"]) (of_list [2, "2"; 4, "4"]));
*)
(*$R
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print))
(of_list [1, "1"; 2, "2"; 3, "3"; 4, "4"])
(union (fun _ a b -> a)
(of_list [1, "1"; 2, "2"; 3, "3"]) (of_list [2, "2"; 4, "4"]))
*)
(*$Q
Q.(list (pair int bool)) (fun l -> \
equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l)))
*)
(* TODO fix *)
let rec inter f a b = match a, b with
| E, _ | _, E -> E
| L (k, v), o
@ -232,16 +347,28 @@ let rec inter f a b = match a, b with
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
if p1 = p2 && m1 = m2
then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
else if m1 < m2 && is_prefix_ ~prefix:p2 p1 ~bit:m1
then if bit_is_0_ p2 ~bit:m1
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
then if Bit.is_0 p2 ~bit:m1
then inter f l1 b
else inter f r1 b
else if m1 > m2 && is_prefix_ ~prefix:p1 p2 ~bit:m2
then if bit_is_0_ p1 ~bit:m2
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
then if Bit.is_0 p1 ~bit:m2
then inter f l2 a
else inter f r2 a
else E
(*$R
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print))
(singleton 2 "2")
(inter (fun _ a b -> a)
(of_list [1, "1"; 2, "2"; 3, "3"]) (of_list [2, "2"; 4, "4"]))
*)
(*$Q
Q.(list (pair int bool)) (fun l -> \
equal ~eq:(=) (of_list l) (inter (fun _ a _ -> a) (of_list l)(of_list l)))
*)
(* TODO: write tests *)
(** {2 Whole-collection operations} *)
@ -375,7 +502,7 @@ let rec as_tree t () = match t with
| E -> `Nil
| L (k, v) -> `Node (`Leaf (k, v), [])
| N (prefix, switch, l, r) ->
`Node (`Node (prefix, switch), [as_tree l; as_tree r])
`Node (`Node (prefix, (switch:>int)), [as_tree l; as_tree r])
(** {2 IO} *)

View file

@ -61,6 +61,7 @@ val compare : cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int
val update : int -> ('a option -> 'a option) -> 'a t -> 'a t
val cardinal : _ t -> int
(** Number of bindings in the map. Linear time *)
val iter : (int -> 'a -> unit) -> 'a t -> unit
@ -114,10 +115,6 @@ val of_klist : (int * 'a) klist -> 'a t
val to_klist : 'a t -> (int * 'a) klist
(** @since NEXT_RELEASE *)
(** Helpers *)
val highest_bit : int -> int
type 'a tree = unit -> [`Nil | `Node of 'a * 'a tree list]
val as_tree : 'a t -> [`Node of int * int | `Leaf of int * 'a ] tree
@ -129,3 +126,15 @@ type 'a printer = Format.formatter -> 'a -> unit
val print : 'a printer -> 'a t printer
(** @since NEXT_RELEASE *)
(** Helpers *)
(**/**)
(* Bit masks used as the switch of inner nodes. Declared between the
   documentation stop markers, so it is left out of the generated docs. *)
module Bit : sig
(* readable as an [int] through a coercion, but not constructible from an
   arbitrary [int] *)
type t = private int
(* the sign bit *)
val min_int : t
(* [highest x] is the most significant bit set in [x]; it is [min_int]
   when [x < 0] *)
val highest : int -> t
end
(* [check_invariants t] checks that every key of [t] is consistent with the
   prefix and switch bit of each node above it. *)
val check_invariants : _ t -> bool
(**/**)