mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-07 03:35:30 -05:00
add tests to CCIntMap, add type safety, and fix various bugs in {union,inter}
This commit is contained in:
parent
d7a58b2ef0
commit
132414ba9d
2 changed files with 190 additions and 54 deletions
|
|
@ -29,54 +29,113 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
(* "Fast Mergeable Integer Maps", Okasaki & Gill.
|
||||
We use big-endian trees. *)
|
||||
|
||||
(** Masks with exactly one bit active *)
|
||||
module Bit : sig
|
||||
type t = private int
|
||||
val highest : int -> t
|
||||
val min_int : t
|
||||
val is_0 : bit:t -> int -> bool
|
||||
val is_1 : bit:t -> int -> bool
|
||||
val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *)
|
||||
val lt : t -> t -> bool
|
||||
val gt : t -> t -> bool
|
||||
end = struct
|
||||
type t = int
|
||||
|
||||
let min_int = min_int
|
||||
|
||||
let rec highest_bit_naive x m =
|
||||
if x=m then m
|
||||
else highest_bit_naive (x land (lnot m)) (2*m)
|
||||
|
||||
let mask_20_ = 1 lsl 20
|
||||
let mask_40_ = 1 lsl 40
|
||||
|
||||
let highest x =
|
||||
if x<0 then min_int
|
||||
else if Sys.word_size > 40 && x > mask_40_
|
||||
then (* remove least significant 40 bits *)
|
||||
let x' = x land (lnot (mask_40_ -1)) in
|
||||
highest_bit_naive x' mask_40_
|
||||
else if x> mask_20_
|
||||
then (* small shortcut: remove least significant 20 bits *)
|
||||
let x' = x land (lnot (mask_20_ -1)) in
|
||||
highest_bit_naive x' mask_20_
|
||||
else highest_bit_naive x 1
|
||||
|
||||
let is_0 ~bit x = x land bit = 0
|
||||
let is_1 ~bit x = x land bit = bit
|
||||
|
||||
let mask ~mask x = (x lor (mask -1)) land (lnot mask)
|
||||
(* low endian: let mask_ x ~mask = x land (mask - 1) *)
|
||||
|
||||
let gt a b = (b != min_int) && (a = min_int || a > b)
|
||||
let lt a b = gt b a
|
||||
end
|
||||
|
||||
type 'a t =
|
||||
| E (* empty *)
|
||||
| L of int * 'a (* leaf *)
|
||||
| N of int (* common prefix *) * int (* bit switch *) * 'a t * 'a t
|
||||
| N of int (* common prefix *) * Bit.t (* bit switch *) * 'a t * 'a t
|
||||
|
||||
let empty = E
|
||||
|
||||
let bit_is_0_ x ~bit = x land bit = 0
|
||||
|
||||
let mask_ x ~mask = (x lor (mask -1)) land (lnot mask)
|
||||
(* low endian: let mask_ x ~mask = x land (mask - 1) *)
|
||||
|
||||
let is_prefix_ ~prefix y ~bit = prefix = mask_ y ~mask:bit
|
||||
|
||||
(* loop down until x=lowest_bit_ x *)
|
||||
let rec highest_bit_naive x m =
|
||||
if m = 0 then 0
|
||||
else if x land m = 0 then highest_bit_naive x (m lsr 1)
|
||||
else m
|
||||
|
||||
let highest_bit =
|
||||
(* the highest representable 2^n *)
|
||||
let max_log = 1 lsl (Sys.word_size - 2) in
|
||||
fun x ->
|
||||
if x > 1 lsl 20
|
||||
then (* small shortcut: remove least significant 20 bits *)
|
||||
let x' = x land (lnot ((1 lsl 20) -1)) in
|
||||
highest_bit_naive x' max_log
|
||||
else highest_bit_naive x max_log
|
||||
let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
|
||||
|
||||
(*$Q
|
||||
Q.int (fun i -> \
|
||||
let b = highest_bit i in \
|
||||
i < 0 || (b <= i && (i-b) < b))
|
||||
let b = Bit.highest i in \
|
||||
((b:>int) land i = (b:>int)) && (i < 0 || ((b:>int) <= i && (i-(b:>int)) < (b:>int))))
|
||||
Q.int (fun i -> (Bit.highest i = Bit.min_int) = (i < 0))
|
||||
Q.int (fun i -> ((Bit.highest i:>int) < 0) = (Bit.highest i = Bit.min_int))
|
||||
Q.int (fun i -> let j = (Bit.highest i :> int) in j land (j-1) = 0)
|
||||
*)
|
||||
|
||||
(*$T
|
||||
(Bit.highest min_int :> int) = min_int
|
||||
(Bit.highest 2 :> int) = 2
|
||||
(Bit.highest 17 :> int) = 16
|
||||
(Bit.highest 300 :> int) = 256
|
||||
*)
|
||||
|
||||
(* helper:
|
||||
|
||||
let b_of_i i =
|
||||
let rec f acc i =
|
||||
if i=0 then acc else let q, r = i/2, i mod 2
|
||||
if i=0 then acc else let q, r = i/2, abs (i mod 2)
|
||||
in
|
||||
f (r::acc) q in f [] i;;
|
||||
*)
|
||||
|
||||
(* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *)
|
||||
let branching_bit_ a b =
|
||||
highest_bit (a lxor b)
|
||||
let branching_bit_ a b = Bit.highest (a lxor b)
|
||||
|
||||
(* TODO use hint in branching_bit_ *)
|
||||
|
||||
let check_invariants t =
|
||||
(* check that keys are prefixed by every node in their path *)
|
||||
let rec check_keys path t = match t with
|
||||
| E -> true
|
||||
| L (k, _) ->
|
||||
List.for_all
|
||||
(fun (prefix, switch, side) ->
|
||||
is_prefix_ ~prefix k ~bit:switch
|
||||
&&
|
||||
match side with
|
||||
| `Left -> Bit.is_0 k ~bit:switch
|
||||
| `Right -> Bit.is_1 k ~bit:switch
|
||||
) path
|
||||
| N (prefix, switch, l, r) ->
|
||||
check_keys ((prefix, switch, `Left) :: path) l
|
||||
&&
|
||||
check_keys ((prefix, switch, `Right) :: path) r
|
||||
in
|
||||
check_keys [] t
|
||||
|
||||
(*$Q
|
||||
Q.(list (pair int bool)) (fun l -> \
|
||||
check_invariants (of_list l))
|
||||
*)
|
||||
|
||||
let rec find_exn k t = match t with
|
||||
| E -> raise Not_found
|
||||
|
|
@ -84,11 +143,13 @@ let rec find_exn k t = match t with
|
|||
| L _ -> raise Not_found
|
||||
| N (prefix, m, l, r) ->
|
||||
if is_prefix_ ~prefix k ~bit:m
|
||||
then if bit_is_0_ k ~bit:m
|
||||
then if Bit.is_0 k ~bit:m
|
||||
then find_exn k l
|
||||
else find_exn k r
|
||||
else raise Not_found
|
||||
|
||||
(* TODO test with lt_unsigned_ *)
|
||||
|
||||
(* FIXME: valid if k < 0?
|
||||
if k <= prefix (* search tree *)
|
||||
then find_exn k l
|
||||
|
|
@ -99,10 +160,23 @@ let find k t =
|
|||
try Some (find_exn k t)
|
||||
with Not_found -> None
|
||||
|
||||
(*$Q
|
||||
Q.(list (pair int int)) (fun l -> \
|
||||
let l = CCList.Set.uniq ~eq:(CCFun.compose_binop fst (=)) l in \
|
||||
let m = of_list l in \
|
||||
List.for_all (fun (k,v) -> find k m = Some v) l)
|
||||
*)
|
||||
|
||||
let mem k t =
|
||||
try ignore (find_exn k t); true
|
||||
with Not_found -> false
|
||||
|
||||
(*$Q
|
||||
Q.(list (pair int int)) (fun l -> \
|
||||
let m = of_list l in \
|
||||
List.for_all (fun (k,_) -> mem k m) l)
|
||||
*)
|
||||
|
||||
let mk_node_ prefix switch l r = match l, r with
|
||||
| E, o | o, E -> o
|
||||
| _ -> N (prefix, switch, l, r)
|
||||
|
|
@ -111,10 +185,15 @@ let mk_node_ prefix switch l r = match l, r with
|
|||
(p1 and p2 do not overlap) *)
|
||||
let join_ t1 p1 t2 p2 =
|
||||
let switch = branching_bit_ p1 p2 in
|
||||
let prefix = mask_ p1 ~mask:switch in
|
||||
if bit_is_0_ p1 ~bit:switch
|
||||
then mk_node_ prefix switch t1 t2
|
||||
else (assert (bit_is_0_ p2 ~bit:switch); mk_node_ prefix switch t2 t1)
|
||||
let prefix = Bit.mask p1 ~mask:switch in
|
||||
if Bit.is_0 p1 ~bit:switch
|
||||
then (
|
||||
assert (Bit.is_1 p2 ~bit:switch);
|
||||
mk_node_ prefix switch t1 t2
|
||||
) else (
|
||||
assert (Bit.is_0 p2 ~bit:switch);
|
||||
mk_node_ prefix switch t2 t1
|
||||
)
|
||||
|
||||
let singleton k v = L (k, v)
|
||||
|
||||
|
|
@ -127,7 +206,7 @@ let rec insert_ c k v t = match t with
|
|||
else join_ t k' (L (k, v)) k
|
||||
| N (prefix, switch, l, r) ->
|
||||
if is_prefix_ ~prefix k ~bit:switch
|
||||
then if bit_is_0_ k ~bit:switch
|
||||
then if Bit.is_0 k ~bit:switch
|
||||
then N(prefix, switch, insert_ c k v l, r)
|
||||
else N(prefix, switch, l, insert_ c k v r)
|
||||
else join_ (L(k,v)) k t prefix
|
||||
|
|
@ -145,11 +224,17 @@ let rec remove k t = match t with
|
|||
| L (k', _) -> if k=k' then E else t
|
||||
| N (prefix, switch, l, r) ->
|
||||
if is_prefix_ ~prefix k ~bit:switch
|
||||
then if bit_is_0_ k ~bit:switch
|
||||
then if Bit.is_0 k ~bit:switch
|
||||
then mk_node_ prefix switch (remove k l) r
|
||||
else mk_node_ prefix switch l (remove k r)
|
||||
else t (* not present *)
|
||||
|
||||
(*$Q & ~count:20
|
||||
Q.(list (pair int int)) (fun l -> \
|
||||
let l = CCList.Set.uniq l in let m = of_list l in \
|
||||
List.for_all (fun (k,_) -> mem k m && not (mem k (remove k m))) l)
|
||||
*)
|
||||
|
||||
let update k f t =
|
||||
try
|
||||
let v = find_exn k t in
|
||||
|
|
@ -162,6 +247,8 @@ let update k f t =
|
|||
| None -> t
|
||||
| Some v -> add k v t
|
||||
|
||||
(* TODO test *)
|
||||
|
||||
let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2)
|
||||
|
||||
let rec equal ~eq a b = match a, b with
|
||||
|
|
@ -201,7 +288,8 @@ let choose t =
|
|||
try Some (choose_exn t)
|
||||
with Not_found -> None
|
||||
|
||||
let rec union f a b = match a, b with
|
||||
(* TODO fix *)
|
||||
let rec union f t1 t2 = match t1, t2 with
|
||||
| E, o | o, E -> o
|
||||
| L (k, v), o
|
||||
| o, L (k, v) ->
|
||||
|
|
@ -210,16 +298,43 @@ let rec union f a b = match a, b with
|
|||
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
|
||||
if p1 = p2 && m1 = m2
|
||||
then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
|
||||
else if m1 < m2 && is_prefix_ ~prefix:p2 p1 ~bit:m1
|
||||
then if bit_is_0_ p2 ~bit:m1
|
||||
then N (p1, m1, union f l1 b, r1)
|
||||
else N (p1, m1, l1, union f r1 b)
|
||||
else if m1 > m2 && is_prefix_ ~prefix:p1 p2 ~bit:m2
|
||||
then if bit_is_0_ p1 ~bit:m2
|
||||
then N (p2, m2, union f l2 a, r2)
|
||||
else N (p2, m2, l2, union f r2 a)
|
||||
else join_ a p1 b p2
|
||||
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
|
||||
then if Bit.is_0 p2 ~bit:m1
|
||||
then N (p1, m1, union f l1 t2, r1)
|
||||
else N (p1, m1, l1, union f r1 t2)
|
||||
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
|
||||
then if Bit.is_0 p1 ~bit:m2
|
||||
then N (p2, m2, union f t1 l2, r2)
|
||||
else N (p2, m2, l2, union f t1 r2)
|
||||
else join_ t1 p1 t2 p2
|
||||
|
||||
(*$Q & ~small:(fun (a,b) -> List.length a + List.length b)
|
||||
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
|
||||
check_invariants (union (fun _ _ x -> x) (of_list l1) (of_list l2)))
|
||||
Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \
|
||||
check_invariants (inter (fun _ _ x -> x) (of_list l1) (of_list l2)))
|
||||
*)
|
||||
|
||||
(*$R
|
||||
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print))
|
||||
(of_list [1, "1"; 2, "2"; 3, "3"; 4, "4"])
|
||||
(union (fun _ a b -> a)
|
||||
(of_list [1, "1"; 3, "3"]) (of_list [2, "2"; 4, "4"]));
|
||||
*)
|
||||
|
||||
(*$R
|
||||
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print))
|
||||
(of_list [1, "1"; 2, "2"; 3, "3"; 4, "4"])
|
||||
(union (fun _ a b -> a)
|
||||
(of_list [1, "1"; 2, "2"; 3, "3"]) (of_list [2, "2"; 4, "4"]))
|
||||
*)
|
||||
|
||||
(*$Q
|
||||
Q.(list (pair int bool)) (fun l -> \
|
||||
equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l)))
|
||||
*)
|
||||
|
||||
(* TODO fix *)
|
||||
let rec inter f a b = match a, b with
|
||||
| E, _ | _, E -> E
|
||||
| L (k, v), o
|
||||
|
|
@ -232,16 +347,28 @@ let rec inter f a b = match a, b with
|
|||
| N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
|
||||
if p1 = p2 && m1 = m2
|
||||
then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
|
||||
else if m1 < m2 && is_prefix_ ~prefix:p2 p1 ~bit:m1
|
||||
then if bit_is_0_ p2 ~bit:m1
|
||||
else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1
|
||||
then if Bit.is_0 p2 ~bit:m1
|
||||
then inter f l1 b
|
||||
else inter f r1 b
|
||||
else if m1 > m2 && is_prefix_ ~prefix:p1 p2 ~bit:m2
|
||||
then if bit_is_0_ p1 ~bit:m2
|
||||
else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2
|
||||
then if Bit.is_0 p1 ~bit:m2
|
||||
then inter f l2 a
|
||||
else inter f r2 a
|
||||
else E
|
||||
|
||||
(*$R
|
||||
assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (print CCString.print))
|
||||
(singleton 2 "2")
|
||||
(inter (fun _ a b -> a)
|
||||
(of_list [1, "1"; 2, "2"; 3, "3"]) (of_list [2, "2"; 4, "4"]))
|
||||
*)
|
||||
|
||||
(*$Q
|
||||
Q.(list (pair int bool)) (fun l -> \
|
||||
equal ~eq:(=) (of_list l) (inter (fun _ a _ -> a) (of_list l)(of_list l)))
|
||||
*)
|
||||
|
||||
(* TODO: write tests *)
|
||||
|
||||
(** {2 Whole-collection operations} *)
|
||||
|
|
@ -375,7 +502,7 @@ let rec as_tree t () = match t with
|
|||
| E -> `Nil
|
||||
| L (k, v) -> `Node (`Leaf (k, v), [])
|
||||
| N (prefix, switch, l, r) ->
|
||||
`Node (`Node (prefix, switch), [as_tree l; as_tree r])
|
||||
`Node (`Node (prefix, (switch:>int)), [as_tree l; as_tree r])
|
||||
|
||||
(** {2 IO} *)
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ val compare : cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int
|
|||
val update : int -> ('a option -> 'a option) -> 'a t -> 'a t
|
||||
|
||||
val cardinal : _ t -> int
|
||||
(** Number of bindings in the map. Linear time *)
|
||||
|
||||
val iter : (int -> 'a -> unit) -> 'a t -> unit
|
||||
|
||||
|
|
@ -114,10 +115,6 @@ val of_klist : (int * 'a) klist -> 'a t
|
|||
val to_klist : 'a t -> (int * 'a) klist
|
||||
(** @since NEXT_RELEASE *)
|
||||
|
||||
(** Helpers *)
|
||||
|
||||
val highest_bit : int -> int
|
||||
|
||||
type 'a tree = unit -> [`Nil | `Node of 'a * 'a tree list]
|
||||
|
||||
val as_tree : 'a t -> [`Node of int * int | `Leaf of int * 'a ] tree
|
||||
|
|
@ -129,3 +126,15 @@ type 'a printer = Format.formatter -> 'a -> unit
|
|||
val print : 'a printer -> 'a t printer
|
||||
(** @since NEXT_RELEASE *)
|
||||
|
||||
(** Helpers *)
|
||||
|
||||
(**/**)
|
||||
|
||||
module Bit : sig
|
||||
type t = private int
|
||||
val min_int : t
|
||||
val highest : int -> t
|
||||
end
|
||||
val check_invariants : _ t -> bool
|
||||
|
||||
(**/**)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue