Merge pull request #117 from rleonid/add_negate

change `CCBV` to track size, with some breaking changes (thank to @rleonid)
This commit is contained in:
Simon Cruanes 2017-04-19 22:25:07 +02:00 committed by GitHub
commit 02f1d5c4a0
3 changed files with 303 additions and 130 deletions

View file

@ -20,3 +20,4 @@
- David Sheets (@dsheets)
- Glenn Slotte (glennsl)
- @LemonBoy
- Leonid Rozenberg (@rleonid)

View file

@ -1,79 +1,122 @@
(* This file is free software, part of containers. See file "license" for more details. *)
(** {2 Imperative Bitvectors} *)
let __width = Sys.word_size - 2
let width_ = Sys.word_size - 1
(* int with [n] ones *)
let rec __shift bv n =
if n = 0
then bv
else __shift ((bv lsl 1) lor 1) (n-1)
(** We use OCaml's ints to store the bits. We index them from the
least significant bit. We create masks to zero out the most significant
bits that aren't used to store values. This is necessary when we are
constructing or negating a bit vector. *)
let lsb_masks_ =
  (* [masks.(i)] has exactly the [i] lowest bits set, for [0 <= i <= width_];
     every entry is the previous one extended by one more bit. *)
  let masks = Array.make (width_ + 1) 0 in
  let rec fill i =
    if i <= width_ then (
      masks.(i) <- masks.(i-1) lor (1 lsl (i - 1));
      fill (i + 1)
    )
  in
  fill 1;
  masks
(* only ones *)
let __all_ones = __shift 0 __width
let all_ones_ = lsb_masks_.(width_)
(* count the 1 bits in [n]. See https://en.wikipedia.org/wiki/Hamming_weight *)
let count_bits_ n =
  (* [x land (x - 1)] clears the lowest set bit of [x], so the loop body
     runs once per set bit (this also terminates for negative [n]). *)
  let remaining = ref n and ones = ref 0 in
  while !remaining <> 0 do
    remaining := !remaining land (!remaining - 1);
    incr ones
  done;
  !ones
(* Can I access the "private" members in testing? $Q
(Q.int_bound (Sys.word_size - 1)) (fun i -> count_bits_ lsb_masks_.(i) = i)
*)
type t = {
mutable a : int array;
mutable size : int;
}
let empty () = { a = [| |] }
let length t = t.size
let empty () = { a = [| |] ; size = 0 }
(* Number of [width_]-bit words needed to hold [size] bits: the quotient,
   plus one extra word when there is a remainder. *)
let array_length_of_size size =
  let full_words = size / width_ in
  match size mod width_ with
  | 0 -> full_words
  | _ -> full_words + 1
let create ~size default =
if size = 0 then { a = [| |] }
else begin
let n = if size mod __width = 0 then size / __width else (size / __width) + 1 in
let arr = if default
then Array.make n __all_ones
if size = 0 then { a = [| |]; size }
else (
let n = array_length_of_size size in
let a = if default
then Array.make n all_ones_
else Array.make n 0
in
(* adjust last bits *)
if default && (size mod __width) <> 0
then arr.(n-1) <- __shift 0 (size - (n-1) * __width);
{ a = arr }
end
let r = size mod width_ in
if default && r <> 0 then (
Array.unsafe_set a (n-1) lsb_masks_.(r);
);
{ a; size }
)
(*$Q
(Q.pair Q.small_int Q.bool) (fun (size, b) -> create ~size b |> length = size)
*)
(*$T
create ~size:17 true |> cardinal = 17
create ~size:32 true |> cardinal= 32
create ~size:32 true |> cardinal = 32
create ~size:132 true |> cardinal = 132
create ~size:200 false |> cardinal = 0
create ~size:29 true |> to_sorted_list = CCList.range 0 28
*)
let copy bv = { a=Array.copy bv.a; }
let copy bv = { bv with a = Array.copy bv.a }
(*$Q
(Q.list Q.small_int) (fun l -> \
let bv = of_list l in to_list bv = to_list (copy bv))
*)
let length bv = Array.length bv.a
let resize bv len =
if len > Array.length bv.a
then begin
let a' = Array.make len 0 in
Array.blit bv.a 0 a' 0 (Array.length bv.a);
bv.a <- a'
end
(* count the 1 bits in [n]. See https://en.wikipedia.org/wiki/Hamming_weight *)
let __count_bits n =
let rec recurse count n =
if n = 0 then count else recurse (count+1) (n land (n-1))
in
if n < 0
then recurse 1 (n lsr 1) (* only on unsigned *)
else recurse 0 n
(* Number of bits that fit in the words currently allocated for [bv]. *)
let capacity bv =
  let words = Array.length bv.a in
  words * width_
let cardinal bv =
let n = ref 0 in
for i = 0 to length bv - 1 do
n := !n + __count_bits bv.a.(i)
done;
!n
if bv.size = 0 then 0
else (
let n = ref 0 in
for i = 0 to Array.length bv.a - 1 do
n := !n + count_bits_ bv.a.(i) (* MSB of last element are all 0 *)
done;
!n
)
(*$Q
Q.small_int (fun size -> create ~size true |> cardinal = size)
*)
(* Replace the storage of [bv] by a fresh zero-filled array of [desired]
   words, keeping the first [current] words of the old storage, and record
   [size] as the new length in bits. *)
let really_resize_ bv ~desired ~current size =
  let fresh = Array.make desired 0 in
  Array.blit bv.a 0 fresh 0 current;
  bv.size <- size;
  bv.a <- fresh
(* Set the length of [bv] to [size] bits, reallocating only when the
   current words cannot hold that many bits. *)
let grow_ bv size =
  if size > capacity bv then
    (* beyond capacity: move every existing word into a larger array *)
    really_resize_ bv
      ~desired:(array_length_of_size size)
      ~current:(Array.length bv.a)
      size
  else
    (* within capacity: only the recorded length changes *)
    bv.size <- size
(* Shrink [bv] to [size] bits ([size] is expected to be at most [bv.size]).
   Only the [desired] words that survive are copied: asking
   [really_resize_] to blit all the old words into the smaller array would
   make [Array.blit] raise [Invalid_argument]. Afterwards the bits of the
   last word that lie above the new size are cleared, so that code counting
   or testing whole words never sees stale bits. *)
let shrink_ bv size =
  let desired = array_length_of_size size in
  really_resize_ bv ~desired ~current:desired size;
  let r = size mod width_ in
  (* [r <> 0] implies [desired >= 1], so the index below is valid *)
  if r <> 0 then
    bv.a.(desired - 1) <- bv.a.(desired - 1) land lsb_masks_.(r)
(* Give [bv] exactly [size] bits, shrinking or growing as required;
   a no-op when the length is already [size].
   @raise Invalid_argument if [size] is negative. *)
let resize bv size =
  if size < 0 then invalid_arg "resize: negative size";
  if size < bv.size then shrink_ bv size
  else if size > bv.size then grow_ bv size
(*$R
let bv1 = CCBV.create ~size:87 true in
@ -87,18 +130,18 @@ let cardinal bv =
let is_empty bv =
try
for i = 0 to Array.length bv.a - 1 do
if bv.a.(i) <> 0 then raise Exit
if bv.a.(i) <> 0 then raise Exit (* MSB of last element are all 0 *)
done;
true
with Exit ->
false
let get bv i =
let n = i / __width in
if i < 0 then invalid_arg "get: negative index";
let n = i / width_ in
let i = i mod width_ in
if n < Array.length bv.a
then
let i = i - n * __width in
bv.a.(n) land (1 lsl i) <> 0
then (Array.unsafe_get bv.a n) land (1 lsl i) <> 0
else false
(*$R
@ -118,11 +161,13 @@ let get bv i =
*)
let set bv i =
let n = i / __width in
if n >= Array.length bv.a
then resize bv (n+1);
let i = i - n * __width in
bv.a.(n) <- bv.a.(n) lor (1 lsl i)
if i < 0 then invalid_arg "set: negative index"
else (
let n = i / width_ in
let j = i mod width_ in
if i >= bv.size then grow_ bv (i+1);
Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) lor (1 lsl j))
)
(*$T
let bv = create ~size:3 false in set bv 0; get bv 0
@ -130,40 +175,44 @@ let set bv i =
*)
let reset bv i =
let n = i / __width in
if n >= Array.length bv.a
then resize bv (n+1);
let i = i - n * __width in
bv.a.(n) <- bv.a.(n) land (lnot (1 lsl i))
if i < 0 then invalid_arg "reset: negative index"
else (
let n = i / width_ in
let j = i mod width_ in
if i >= bv.size then grow_ bv (i+1);
Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) land (lnot (1 lsl j)))
)
(*$T
let bv = create ~size:3 false in set bv 0; reset bv 0; not (get bv 0)
*)
let flip bv i =
let n = i / __width in
if n >= Array.length bv.a
then resize bv (n+1);
let i = i - n * __width in
bv.a.(n) <- bv.a.(n) lxor (1 lsl i)
(* reject negative indices; the message names this function (it used to
   say "reset", copied from the function above) *)
if i < 0 then invalid_arg "flip: negative index"
else (
  let n = i / width_ in    (* index of the word that holds bit [i] *)
  let j = i mod width_ in  (* position of the bit inside that word *)
  (* an index past the current length first extends the vector *)
  if i >= bv.size then grow_ bv (i+1);
  Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) lxor (1 lsl j))
)
(*$R
let bv = of_list [1;10; 11; 30] in
flip bv 10;
assert_equal [1;11;30] (to_sorted_list bv);
assert_equal false (get bv 10);
assert_equal ~printer:Q.Print.(list int) [1;11;30] (to_sorted_list bv);
assert_equal ~printer:Q.Print.bool false (get bv 10);
flip bv 10;
assert_equal true (get bv 10);
assert_equal ~printer:Q.Print.bool true (get bv 10);
flip bv 5;
assert_equal [1;5;10;11;30] (to_sorted_list bv);
assert_equal true (get bv 5);
assert_equal ~printer:Q.Print.(list int) [1;5;10;11;30] (to_sorted_list bv);
assert_equal ~printer:Q.Print.bool true (get bv 5);
flip bv 100;
assert_equal [1;5;10;11;30;100] (to_sorted_list bv);
assert_equal true (get bv 100);
assert_equal ~printer:Q.Print.(list int) [1;5;10;11;30;100] (to_sorted_list bv);
assert_equal ~printer:Q.Print.bool true (get bv 100);
*)
let clear bv =
Array.iteri (fun i _ -> bv.a.(i) <- 0) bv.a
Array.fill bv.a 0 (Array.length bv.a) 0
(*$T
let bv = create ~size:37 true in cardinal bv = 37 && (clear bv; cardinal bv= 0)
@ -179,11 +228,17 @@ let clear bv =
let iter bv f =
let len = Array.length bv.a in
for n = 0 to len - 1 do
let j = __width * n in
for i = 0 to __width - 1 do
for n = 0 to len - 2 do
let j = width_ * n in
for i = 0 to width_ - 1 do
f (j+i) (bv.a.(n) land (1 lsl i) <> 0)
done
done;
(* The last word may be only partially used, so it is visited separately.
   It must be skipped when the vector has no words at all: otherwise
   [len - 1 = -1] and [bv.a.(len - 1)] raises [Invalid_argument]. *)
if len > 0 then (
  let j = width_ * (len - 1) in
  let r = bv.size mod width_ in
  (* a remainder of 0 means the last word is completely used *)
  let final_length = if r = 0 then width_ else r in
  for i = 0 to final_length - 1 do
    f (j + i) (bv.a.(len - 1) land (1 lsl i) <> 0)
  done
)
(*$R
@ -195,14 +250,7 @@ let iter bv f =
*)
let iter_true bv f =
let len = Array.length bv.a in
for n = 0 to len - 1 do
let j = __width * n in
for i = 0 to __width - 1 do
if bv.a.(n) land (1 lsl i) <> 0
then f (j+i)
done
done
iter bv (fun i b -> if b then f i else ())
(*$T
of_list [1;5;7] |> iter_true |> Sequence.to_list |> List.sort CCOrd.compare = [1;5;7]
@ -242,8 +290,9 @@ let to_list bv =
let to_sorted_list bv =
List.rev (to_list bv)
(* Interpret these as indices. *)
let of_list l =
let size = List.fold_left max 0 l in
let size = (List.fold_left max 0 l) + 1 in
let bv = create ~size false in
List.iter (fun i -> set bv i) l;
bv
@ -256,15 +305,19 @@ let of_list l =
exception FoundFirst of int
let first bv =
(* Index passed to the first callback made by [iter_true] on [bv].
   @raise Not_found if [iter_true] never calls back. *)
let first_exn bv =
try
(* stop the traversal at the first callback by raising [FoundFirst] *)
iter_true bv (fun i -> raise (FoundFirst i));
(* reached only when the traversal finished without any callback *)
raise Not_found
with FoundFirst i ->
i
(* Option-returning variant of [first_exn]: [Some i] when it returns [i],
   [None] when it raises [Not_found]. *)
let first bv =
try Some (first_exn bv)
with Not_found -> None
(*$T
of_list [50; 10; 17; 22; 3; 12] |> first = 3
of_list [50; 10; 17; 22; 3; 12] |> first = Some 3
*)
let filter bv p =
@ -276,18 +329,61 @@ let filter bv p =
to_sorted_list bv = [2;4;6]
*)
(* Complement every bit of [b] in place. *)
let negate_self b =
  Array.iteri (fun n w -> Array.unsafe_set b.a n (lnot w)) b.a;
  (* complementing also set the bits of the last word that lie above
     [b.size]; mask them back to 0 *)
  match b.size mod width_ with
  | 0 -> ()
  | r ->
    let last = Array.length b.a - 1 in
    Array.unsafe_set b.a last (lsb_masks_.(r) land (Array.unsafe_get b.a last))
(*$T
let v = of_list [1;2;5;7;] in negate_self v; \
cardinal v = (List.length [0;3;4;6])
*)
(* Fresh bitvector of the same size as [b] with every bit complemented;
   [b] itself is left untouched. *)
let negate b =
  let flipped = Array.map (fun w -> lnot w) b.a in
  let r = b.size mod width_ in
  (* clear the bits of the last word that lie above [b.size] *)
  if r <> 0 then (
    let last = Array.length b.a - 1 in
    Array.unsafe_set flipped last
      (lsb_masks_.(r) land (Array.unsafe_get flipped last))
  );
  { a = flipped; size = b.size }
(*$Q
Q.small_int (fun size -> create ~size false |> negate |> cardinal = size)
*)
(* Underlying size grows for union. *)
let union_into ~into bv =
if length into < length bv
then resize into (length bv);
let len = Array.length bv.a in
for i = 0 to len - 1 do
into.a.(i) <- into.a.(i) lor bv.a.(i)
(* make [into] at least as long as [bv] before merging *)
if into.size < bv.size
then grow_ into bv.size;
(* Only the words that [bv] actually has can contribute bits. Bounding the
   loop by [into.a] alone would read past the end of [bv.a] (with
   [unsafe_get]) whenever [into] has more words than [bv]. *)
for i = 0 to (min (Array.length into.a) (Array.length bv.a)) - 1 do
  Array.unsafe_set into.a i
    ((Array.unsafe_get into.a i) lor (Array.unsafe_get bv.a i))
done
let union bv1 bv2 =
let bv = copy bv1 in
union_into ~into:bv bv2;
bv
(* To avoid potentially 2 passes, figure out what we need to copy. *)
(* [union b1 b2] is a fresh bitvector, as long as the longer argument,
   whose set bits are those set in [b1] or in [b2].
   The longer vector is copied and the words of the shorter one are OR-ed
   into the copy. The loop is bounded by the shorter array: the previous
   code iterated over [b1.a] in both branches, reading past the end of
   [b2.a] with [unsafe_get] whenever [b1] was the longer vector. *)
let union b1 b2 =
  let small, large = if b1.size <= b2.size then b1, b2 else b2, b1 in
  let into = copy large in
  let n = min (Array.length small.a) (Array.length into.a) in
  for i = 0 to n - 1 do
    Array.unsafe_set into.a i
      ((Array.unsafe_get into.a i) lor (Array.unsafe_get small.a i))
  done;
  into
(*$R
let bv1 = CCBV.of_list [1;2;3;4] in
@ -302,22 +398,32 @@ let union bv1 bv2 =
union (of_list [1;2;3;4;5]) (of_list [7;3;5;6]) |> to_sorted_list = CCList.range 1 7
*)
(* Underlying size shrinks for inter. *)
let inter_into ~into bv =
let n = min (length into) (length bv) in
for i = 0 to n - 1 do
into.a.(i) <- into.a.(i) land bv.a.(i)
if into.size > bv.size
then shrink_ into bv.size;
for i = 0 to (Array.length into.a) - 1 do
Array.unsafe_set into.a i
((Array.unsafe_get into.a i) land (Array.unsafe_get bv.a i))
done
let inter bv1 bv2 =
if length bv1 < length bv2
then
let bv = copy bv1 in
let () = inter_into ~into:bv bv2 in
bv
else
let bv = copy bv2 in
let () = inter_into ~into:bv bv1 in
bv
(* [inter b1 b2] is a fresh bitvector, as long as the shorter argument
   ([b1] on ties), whose set bits are those set in both [b1] and [b2].
   The shorter vector is copied and each of its words is AND-ed with the
   word at the same index in the longer one. *)
let inter b1 b2 =
  let small, large = if b1.size <= b2.size then b1, b2 else b2, b1 in
  let into = copy small in
  Array.iteri
    (fun i w -> Array.unsafe_set into.a i (w land (Array.unsafe_get large.a i)))
    into.a;
  into
(*$T
inter (of_list [1;2;3;4]) (of_list [2;4;6;1]) |> to_sorted_list = [1;2;4]
@ -331,6 +437,28 @@ let inter bv1 bv2 =
assert_equal [3;4] l;
*)
(* Underlying size depends on the 'in_' set for diff, so we don't change
its size! *)
(* Clear in [into] every bit that is set in [bv]. Only the words present
   in both arrays are touched and the size of [into] is left unchanged. *)
let diff_into ~into bv =
  let last = min (Array.length into.a) (Array.length bv.a) - 1 in
  let rec clear_from i =
    if i <= last then (
      let kept = (Array.unsafe_get into.a i) land (lnot (Array.unsafe_get bv.a i)) in
      Array.unsafe_set into.a i kept;
      clear_from (i + 1)
    )
  in
  clear_from 0
(* Fresh bitvector holding the bits set in [in_] but not in [not_in];
   neither argument is modified. *)
let diff in_ not_in =
  let result = copy in_ in
  let () = diff_into ~into:result not_in in
  result
(*$T
diff (of_list [1;2;3]) (of_list [1;2;3]) |> to_list = [];
diff (of_list [1;2;3]) (of_list [1;2;3;4]) |> to_list = [];
diff (of_list [1;2;3;4]) (of_list [1;2;3]) |> to_list = [4];
diff (of_list [1;2;3]) (of_list [1;2;3;400]) |> to_list = [];
diff (of_list [1;2;3;400]) (of_list [1;2;3]) |> to_list = [400];
*)
let select bv arr =
let l = ref [] in
begin try
@ -369,10 +497,10 @@ let selecti bv arr =
assert_equal [("b",1); ("c",2); ("f",5)] l;
*)
(*$T
selecti (of_list [1;4;3]) [| 0;1;2;3;4;5;6;7;8 |] \
|> List.sort CCOrd.compare = [1, 1; 3,3; 4,4]
*)
(*$= & ~printer:Q.Print.(list (pair int int))
[1,1; 3,3; 4,4] (selecti (of_list [1;4;3]) [| 0;1;2;3;4;5;6;7;8 |] \
|> List.sort CCOrd.compare)
*)
type 'a sequence = ('a -> unit) -> unit

View file

@ -3,9 +3,13 @@
(** {2 Imperative Bitvectors}
The size of the bitvector is rounded up to the multiple of 30 or 62.
In other words some functions such as {!iter} might iterate on more
bits than what was originally asked for.
{b BREAKING CHANGES} since NEXT_RELEASE:
size is now stored along with the bitvector. Some functions have
a new signature.
The size of the bitvector used to be rounded up to the multiple of 30 or 62.
In other words some functions such as {!iter} would iterate on more
bits than what was originally asked for. This is not the case anymore.
*)
type t
@ -21,29 +25,39 @@ val copy : t -> t
(** Copy of bitvector *)
val cardinal : t -> int
(** Number of bits set *)
(** Number of bits set to one, seen as a set of bits. *)
val length : t -> int
(** Length of underlying array *)
(** Size of underlying bitvector.
This is not related to the underlying implementation.
Changed at NEXT_RELEASE
*)
val capacity : t -> int
(** The number of bits this bitvector can store without resizing.
@since NEXT_RELEASE *)
val resize : t -> int -> unit
(** Resize the BV so that it has at least the given physical length
[resize bv n] should make [bv] able to store [(Sys.word_size - 2)* n] bits *)
(** Resize the BV so that it has the specified length. This can grow or shrink
the underlying bitvector.
@raise Invalid_argument on negative sizes. *)
val is_empty : t -> bool
(** Any bit set? *)
(** Are there any true bits? *)
val set : t -> int -> unit
(** Set i-th bit. *)
(** Set i-th bit, extending the bitvector if needed. *)
val get : t -> int -> bool
(** Is the i-th bit true? Returns false if the index is too high*)
val reset : t -> int -> unit
(** Set i-th bit to 0 *)
(** Set i-th bit to 0, extending the bitvector if needed. *)
val flip : t -> int -> unit
(** Flip i-th bit *)
(** Flip i-th bit, extending the bitvector if needed. *)
val clear : t -> unit
(** Set every bit to 0 *)
@ -62,21 +76,41 @@ val to_sorted_list : t -> int list
increasing order *)
val of_list : int list -> t
(** From a list of true bits *)
(** From a list of true bits.
val first : t -> int
(** First set bit, or
@raise Not_found if all bits are 0 *)
The bits are interpreted as indices into the returned bitvector, so the final
bitvector will have [length t] equal to 1 more than max of list indices. *)
val first : t -> int option
(** First set bit, or return None.
changed type at NEXT_RELEASE *)
val first_exn : t -> int
(** First set bit, or
@raise Not_found if all bits are 0
@since NEXT_RELEASE *)
val filter : t -> (int -> bool) -> unit
(** [filter bv p] only keeps the true bits of [bv] whose [index]
satisfies [p index] *)
val negate_self : t -> unit
(** [negate_self t] flips all of the bits in [t].
@since NEXT_RELEASE *)
val negate : t -> t
(** [negate t] returns a copy of [t] with all of the bits flipped. *)
val union_into : into:t -> t -> unit
(** [union ~into bv] sets [into] to the union of itself and [bv]. *)
(** [union_into ~into bv] sets [into] to the union of itself and [bv].
Also updates the length of [into] to be at least [length bv]. *)
val inter_into : into:t -> t -> unit
(** [inter ~into bv] sets [into] to the intersection of itself and [bv] *)
(** [inter_into ~into bv] sets [into] to the intersection of itself and [bv].
Also updates the length of [into] to be at most [length bv]. *)
val union : t -> t -> t
(** [union bv1 bv2] returns the union of the two sets *)
@ -84,6 +118,16 @@ val union : t -> t -> t
val inter : t -> t -> t
(** [inter bv1 bv2] returns the intersection of the two sets *)
val diff_into : into:t -> t -> unit
(** [diff_into ~into t] modifies [into], keeping only the bits that are set in [into] but not in [t].
@since NEXT_RELEASE *)
val diff : t -> t -> t
(** [diff t1 t2] returns those bits found in [t1] but not in [t2].
@since NEXT_RELEASE *)
val select : t -> 'a array -> 'a list
(** [select arr bv] selects the elements of [arr] whose index
corresponds to a true bit in [bv]. If [bv] is too short, elements of [arr]