Incorporate reviewer feedback.

Also added style elements from PR#116.
This commit is contained in:
Leonid Rozenberg 2017-04-19 12:26:47 -04:00
parent d8a55a98b9
commit f90f73f671
2 changed files with 160 additions and 110 deletions

View file

@ -1,29 +1,30 @@
(** {2 Imperative Bitvectors} *) (** {2 Imperative Bitvectors} *)
let __width = Sys.word_size - 1 let width_ = Sys.word_size - 1
(** We use OCaml's ints to store the bits. We index them from the (** We use OCaml's ints to store the bits. We index them from the
least significant bit. We create masks to zero out the most significant least significant bit. We create masks to zero out the most significant
bits that aren't used to store values. *) bits that aren't used to store values. This is necessary when we are
let __lsb_masks = constructing or negating a bit vector. *)
let a = Array.make (__width + 1) 0 in let lsb_masks_ =
for i = 1 to __width do let a = Array.make (width_ + 1) 0 in
for i = 1 to width_ do
a.(i) <- a.(i-1) lor (1 lsl (i - 1)) a.(i) <- a.(i-1) lor (1 lsl (i - 1))
done; done;
a a
let __all_ones = __lsb_masks.(__width) let all_ones_ = lsb_masks_.(width_)
(* count the 1 bits in [n]. See https://en.wikipedia.org/wiki/Hamming_weight *) (* count the 1 bits in [n]. See https://en.wikipedia.org/wiki/Hamming_weight *)
let __count_bits n = let count_bits_ n =
let rec recurse count n = let rec recurse count n =
if n = 0 then count else recurse (count+1) (n land (n-1)) if n = 0 then count else recurse (count+1) (n land (n-1))
in in
recurse 0 n recurse 0 n
(* Can I access the "private" members in testing? $Q (* Can I access the "private" members in testing? $Q
(Q.int_bound (Sys.word_size - 1)) (fun i -> __count_bits __lsb_masks.(i) = i) (Q.int_bound (Sys.word_size - 1)) (fun i -> count_bits_ lsb_masks_.(i) = i)
*) *)
type t = { type t = {
@ -35,23 +36,24 @@ let length t = t.size
let empty () = { a = [| |] ; size = 0 } let empty () = { a = [| |] ; size = 0 }
let __to_array_legnth size = let array_length_of_size size =
if size mod __width = 0 then size / __width else (size / __width) + 1 if size mod width_ = 0 then size / width_ else (size / width_) + 1
let create ~size default = let create ~size default =
if size = 0 then { a = [| |]; size } if size = 0 then { a = [| |]; size }
else begin else (
let n = __to_array_legnth size in let n = capa_of_size size in
let arr = if default let a = if default
then Array.make n __all_ones then Array.make n all_ones_
else Array.make n 0 else Array.make n 0
in in
(* adjust last bits *) (* adjust last bits *)
let r = size mod __width in let r = size mod width_ in
if default && r <> 0 if default && r <> 0 then (
then Array.unsafe_set arr (n-1) __lsb_masks.(r); Array.unsafe_set a (n-1) lsb_masks_.(r);
{ a = arr; size } );
end { a; size }
)
(*$Q (*$Q
(Q.pair Q.small_int Q.bool) (fun (size, b) -> create ~size b |> length = size) (Q.pair Q.small_int Q.bool) (fun (size, b) -> create ~size b |> length = size)
@ -65,52 +67,69 @@ let create ~size default =
create ~size:29 true |> to_sorted_list = CCList.range 0 28 create ~size:29 true |> to_sorted_list = CCList.range 0 28
*) *)
let copy bv = { a = Array.copy bv.a ; size = bv.size } let copy bv = { bv with a = Array.copy bv.a }
(*$Q (*$Q
(Q.list Q.small_int) (fun l -> \ (Q.list Q.small_int) (fun l -> \
let bv = of_list l in to_list bv = to_list (copy bv)) let bv = of_list l in to_list bv = to_list (copy bv))
*) *)
let capacity bv = __width * Array.length bv.a let capacity bv = width_ * Array.length bv.a
(* iterate on words of width (at most) [width_] *)
let iter_words ~f bv: unit =
if bv.size = 0 then ()
else (
let len = array_length_of_size bv.size in
assert (len>0);
for i = 0 to len-1 do
let word = Array.unsafe_get a i in
f i word
done;
if r <> 0 then f (len-1) (Array.unsafe_get a (len-1) land lsb_masks_.(r));
)
let cardinal bv = let cardinal bv =
if bv.size = 0 then 0
else (
let n = ref 0 in let n = ref 0 in
for i = 0 to Array.length bv.a - 1 do for i = 0 to Array.length bv.a - 1 do
n := !n + __count_bits bv.a.(i) n := !n + count_bits_ bv.a.(i) (* MSB of last element are all 0 *)
done; done;
!n !n
)
(*$Q (*$Q
Q.small_int (fun size -> create ~size true |> cardinal = size) Q.small_int (fun size -> create ~size true |> cardinal = size)
*) *)
let __really_resize bv ~desired ~current size = let really_resize_ bv ~desired ~current size =
let a' = Array.make desired 0 in let a' = Array.make desired 0 in
Array.blit bv.a 0 a' 0 current; Array.blit bv.a 0 a' 0 current;
bv.a <- a'; bv.a <- a';
bv.size <- size bv.size <- size
let __grow bv size = let grow_ bv size =
if size <= capacity bv (* within capacity *) if size <= capacity bv (* within capacity *)
then bv.size <- size then bv.size <- size
else (* beyond capacity *) else ( (* beyond capacity *)
let desired = __to_array_legnth size in let desired = array_length_of_size size in
let current = Array.length bv.a in let current = Array.length bv.a in
__really_resize bv ~desired ~current size really_resize_ bv ~desired ~current size
)
let __shrink bv size = let shrink_ bv size =
let desired = __to_array_legnth size in let desired = array_length_of_size size in
let current = Array.length bv.a in let current = Array.length bv.a in
__really_resize bv ~desired ~current size really_resize_ bv ~desired ~current size
let resize bv size = let resize bv size =
if size < 0 then invalid_arg "resize: negative size" else if size < 0 then invalid_arg "resize: negative size" else
if size < bv.size (* shrink *) if size < bv.size (* shrink *)
then __shrink bv size then shrink_ bv size
else if size = bv.size else if size = bv.size
then () then ()
else __grow bv size else grow_ bv size
(*$R (*$R
let bv1 = CCBV.create ~size:87 true in let bv1 = CCBV.create ~size:87 true in
@ -124,16 +143,16 @@ let resize bv size =
let is_empty bv = let is_empty bv =
try try
for i = 0 to Array.length bv.a - 1 do for i = 0 to Array.length bv.a - 1 do
if bv.a.(i) <> 0 then raise Exit if bv.a.(i) <> 0 then raise Exit (* MSB of last element are all 0 *)
done; done;
true true
with Exit -> with Exit ->
false false
let get bv i = let get bv i =
if i < 0 then invalid_arg "get: negative index" else if i < 0 then invalid_arg "get: negative index";
let n = i / __width in let n = i / width_ in
let i = i mod __width in let i = i mod width_ in
if n < Array.length bv.a if n < Array.length bv.a
then (Array.unsafe_get bv.a n) land (1 lsl i) <> 0 then (Array.unsafe_get bv.a n) land (1 lsl i) <> 0
else false else false
@ -155,11 +174,13 @@ let get bv i =
*) *)
let set bv i = let set bv i =
if i < 0 then invalid_arg "set: negative index" else if i < 0 then invalid_arg "set: negative index"
let n = i / __width in else (
let j = i mod __width in let n = i / width_ in
if i >= bv.size then __grow bv i; let j = i mod width_ in
if i >= bv.size then grow_ bv i;
Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) lor (1 lsl j)) Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) lor (1 lsl j))
)
(*$T (*$T
let bv = create ~size:3 false in set bv 0; get bv 0 let bv = create ~size:3 false in set bv 0; get bv 0
@ -167,36 +188,40 @@ let set bv i =
*) *)
let reset bv i = let reset bv i =
if i < 0 then invalid_arg "reset: negative index" else if i < 0 then invalid_arg "reset: negative index"
let n = i / __width in else (
let j = i mod __width in let n = i / width_ in
if i >= bv.size then __grow bv i; let j = i mod width_ in
if i >= bv.size then grow_ bv i;
Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) land (lnot (1 lsl j))) Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) land (lnot (1 lsl j)))
)
(*$T (*$T
let bv = create ~size:3 false in set bv 0; reset bv 0; not (get bv 0) let bv = create ~size:3 false in set bv 0; reset bv 0; not (get bv 0)
*) *)
let flip bv i = let flip bv i =
if i < 0 then invalid_arg "reset: negative index" else if i < 0 then invalid_arg "reset: negative index"
let n = i / __width in else (
let j = i mod __width in let n = i / width_ in
if i >= bv.size then __grow bv i; let j = i mod width_ in
if i >= bv.size then grow_ bv i;
Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) lxor (1 lsl j)) Array.unsafe_set bv.a n ((Array.unsafe_get bv.a n) lxor (1 lsl j))
)
(*$R (*$R
let bv = of_list [1;10; 11; 30] in let bv = of_list [1;10; 11; 30] in
flip bv 10; flip bv 10;
assert_equal [1;11;30] (to_sorted_list bv); assert_equal ~printer:Q.Print.(list int) [1;11;30] (to_sorted_list bv);
assert_equal false (get bv 10); assert_equal ~printer:Q.Print.bool false (get bv 10);
flip bv 10; flip bv 10;
assert_equal true (get bv 10); assert_equal ~printer:Q.Print.bool true (get bv 10);
flip bv 5; flip bv 5;
assert_equal [1;5;10;11;30] (to_sorted_list bv); assert_equal ~printer:Q.Print.(list int) [1;5;10;11;30] (to_sorted_list bv);
assert_equal true (get bv 5); assert_equal ~printer:Q.Print.bool true (get bv 5);
flip bv 100; flip bv 100;
assert_equal [1;5;10;11;30;100] (to_sorted_list bv); assert_equal ~printer:Q.Print.(list int) [1;5;10;11;30;100] (to_sorted_list bv);
assert_equal true (get bv 100); assert_equal ~printer:Q.Print.bool true (get bv 100);
*) *)
let clear bv = let clear bv =
@ -216,11 +241,17 @@ let clear bv =
let iter bv f = let iter bv f =
let len = Array.length bv.a in let len = Array.length bv.a in
for n = 0 to len - 1 do for n = 0 to len - 2 do
let j = __width * n in let j = width_ * n in
for i = 0 to __width - 1 do for i = 0 to width_ - 1 do
f (j+i) (bv.a.(n) land (1 lsl i) <> 0) f (j+i) (bv.a.(n) land (1 lsl i) <> 0)
done done
done;
let j = max 0 (width_ * (len - 2)) in
let r = size mod width_ in
let final_length = if r = 0 then width_ else r in
for i = 0 to final_length - 1 do
f (j + i) (bv.a.(len - 1) land (i lsl i) <> 0)
done done
(*$R (*$R
@ -232,14 +263,7 @@ let iter bv f =
*) *)
let iter_true bv f = let iter_true bv f =
let len = Array.length bv.a in iter bv (fun i b -> if b then f i else ())
for n = 0 to len - 1 do
let j = __width * n in
for i = 0 to __width - 1 do
if bv.a.(n) land (1 lsl i) <> 0
then f (j+i)
done
done
(*$T (*$T
of_list [1;5;7] |> iter_true |> Sequence.to_list |> List.sort CCOrd.compare = [1;5;7] of_list [1;5;7] |> iter_true |> Sequence.to_list |> List.sort CCOrd.compare = [1;5;7]
@ -294,15 +318,19 @@ let of_list l =
exception FoundFirst of int exception FoundFirst of int
let first bv = let first_exn bv =
try try
iter_true bv (fun i -> raise (FoundFirst i)); iter_true bv (fun i -> raise (FoundFirst i));
raise Not_found raise Not_found
with FoundFirst i -> with FoundFirst i ->
i i
let first bv =
try Some (first_exn bv)
with Not_found -> None
(*$T (*$T
of_list [50; 10; 17; 22; 3; 12] |> first = 3 of_list [50; 10; 17; 22; 3; 12] |> first = Some 3
*) *)
let filter bv p = let filter bv p =
@ -319,10 +347,10 @@ let negate_self b =
for n = 0 to len - 1 do for n = 0 to len - 1 do
Array.unsafe_set b.a n (lnot (Array.unsafe_get b.a n)) Array.unsafe_set b.a n (lnot (Array.unsafe_get b.a n))
done; done;
let r = b.size mod __width in let r = b.size mod width_ in
if r <> 0 then if r <> 0 then
let l = Array.length b.a - 1 in let l = Array.length b.a - 1 in
Array.unsafe_set b.a l (__lsb_masks.(r) land (Array.unsafe_get b.a l)) Array.unsafe_set b.a l (lsb_masks_.(r) land (Array.unsafe_get b.a l))
(*$T (*$T
let v = of_list [1;2;5;7;] in negate_self v; \ let v = of_list [1;2;5;7;] in negate_self v; \
@ -331,10 +359,10 @@ let negate_self b =
let negate b = let negate b =
let a = Array.map (lnot) b.a in let a = Array.map (lnot) b.a in
let r = b.size mod __width in let r = b.size mod width_ in
if r <> 0 then begin if r <> 0 then begin
let l = Array.length b.a - 1 in let l = Array.length b.a - 1 in
Array.unsafe_set a l (__lsb_masks.(r) land (Array.unsafe_get a l)) Array.unsafe_set a l (lsb_masks_.(r) land (Array.unsafe_get a l))
end; end;
{ a ; size = b.size } { a ; size = b.size }
@ -345,7 +373,7 @@ let negate b =
(* Underlying size grows for union. *) (* Underlying size grows for union. *)
let union_into ~into bv = let union_into ~into bv =
if into.size < bv.size if into.size < bv.size
then __grow into bv.size; then grow_ into bv.size;
for i = 0 to (Array.length into.a) - 1 do for i = 0 to (Array.length into.a) - 1 do
Array.unsafe_set into.a i Array.unsafe_set into.a i
((Array.unsafe_get into.a i) lor (Array.unsafe_get bv.a i)) ((Array.unsafe_get into.a i) lor (Array.unsafe_get bv.a i))
@ -354,21 +382,21 @@ let union_into ~into bv =
(* To avoid potentially 2 passes, figure out what we need to copy. *) (* To avoid potentially 2 passes, figure out what we need to copy. *)
let union b1 b2 = let union b1 b2 =
if b1.size <= b2.size if b1.size <= b2.size
then begin then (
let into = copy b2 in let into = copy b2 in
for i = 0 to (Array.length b1.a) - 1 do for i = 0 to (Array.length b1.a) - 1 do
Array.unsafe_set into.a i Array.unsafe_set into.a i
((Array.unsafe_get into.a i) lor (Array.unsafe_get b1.a i)) ((Array.unsafe_get into.a i) lor (Array.unsafe_get b1.a i))
done; done;
into into
end else begin ) else (
let into = copy b1 in let into = copy b1 in
for i = 0 to (Array.length b1.a) - 1 do for i = 0 to (Array.length b1.a) - 1 do
Array.unsafe_set into.a i Array.unsafe_set into.a i
((Array.unsafe_get into.a i) lor (Array.unsafe_get b2.a i)) ((Array.unsafe_get into.a i) lor (Array.unsafe_get b2.a i))
done; done;
into into
end )
(*$R (*$R
let bv1 = CCBV.of_list [1;2;3;4] in let bv1 = CCBV.of_list [1;2;3;4] in
@ -386,7 +414,7 @@ let union b1 b2 =
(* Underlying size shrinks for inter. *) (* Underlying size shrinks for inter. *)
let inter_into ~into bv = let inter_into ~into bv =
if into.size > bv.size if into.size > bv.size
then __shrink into bv.size; then shrink_ into bv.size;
for i = 0 to (Array.length into.a) - 1 do for i = 0 to (Array.length into.a) - 1 do
Array.unsafe_set into.a i Array.unsafe_set into.a i
((Array.unsafe_get into.a i) land (Array.unsafe_get bv.a i)) ((Array.unsafe_get into.a i) land (Array.unsafe_get bv.a i))
@ -394,21 +422,21 @@ let inter_into ~into bv =
let inter b1 b2 = let inter b1 b2 =
if b1.size <= b2.size if b1.size <= b2.size
then begin then (
let into = copy b1 in let into = copy b1 in
for i = 0 to (Array.length b1.a) - 1 do for i = 0 to (Array.length b1.a) - 1 do
Array.unsafe_set into.a i Array.unsafe_set into.a i
((Array.unsafe_get into.a i) land (Array.unsafe_get b2.a i)) ((Array.unsafe_get into.a i) land (Array.unsafe_get b2.a i))
done; done;
into into
end else begin ) else (
let into = copy b2 in let into = copy b2 in
for i = 0 to (Array.length b2.a) - 1 do for i = 0 to (Array.length b2.a) - 1 do
Array.unsafe_set into.a i Array.unsafe_set into.a i
((Array.unsafe_get into.a i) land (Array.unsafe_get b1.a i)) ((Array.unsafe_get into.a i) land (Array.unsafe_get b1.a i))
done; done;
into into
end )
(*$T (*$T
inter (of_list [1;2;3;4]) (of_list [2;4;6;1]) |> to_sorted_list = [1;2;4] inter (of_list [1;2;3;4]) (of_list [2;4;6;1]) |> to_sorted_list = [1;2;4]
@ -431,7 +459,7 @@ let diff_into ~into bv =
((Array.unsafe_get into.a i) land (lnot (Array.unsafe_get bv.a i))) ((Array.unsafe_get into.a i) land (lnot (Array.unsafe_get bv.a i)))
done done
let diff ~in_ not_in = let diff in_ not_in =
let into = copy in_ in let into = copy in_ in
diff_into ~into not_in; diff_into ~into not_in;
into into
@ -474,9 +502,11 @@ let selecti bv arr =
assert_equal [("b",1); ("c",2); ("f",5)] l; assert_equal [("b",1); ("c",2); ("f",5)] l;
*) *)
(*$T (*$= & ~printer:Q.Print.(list (pair int int))
selecti (of_list [1;4;3]) [| 0;1;2;3;4;5;6;7;8 |] \ selecti (of_list [1;4;3]) [| 0;1;2;3;4;5;6;7;8 |] \
[1,1; 3,3; 4,4] (selecti (of_list [1;4;3]) [| 0;1;2;3;4;5;6;7;8 |] \
|> List.sort CCOrd.compare = [1, 1; 3,3; 4,4] |> List.sort CCOrd.compare = [1, 1; 3,3; 4,4]
|> List.sort CCOrd.compare)
*) *)
type 'a sequence = ('a -> unit) -> unit type 'a sequence = ('a -> unit) -> unit

View file

@ -3,9 +3,13 @@
(** {2 Imperative Bitvectors} (** {2 Imperative Bitvectors}
The size of the bitvector is rounded up to the multiple of 30 or 62. {b BREAKING CHANGES} since NEXT_RELEASE:
In other words some functions such as {!iter} might iterate on more size is now stored along with the bitvector. Some functions have
bits than what was originally asked for. a new signature.
The size of the bitvector used to be rounded up to the multiple of 30 or 62.
In other words some functions such as {!iter} would iterate on more
bits than what was originally asked for. This is not the case anymore.
*) *)
type t type t
@ -21,13 +25,18 @@ val copy : t -> t
(** Copy of bitvector *) (** Copy of bitvector *)
val cardinal : t -> int val cardinal : t -> int
(** Number of set bits. *) (** Number of bits set to one, seen as a set of bits. *)
val length : t -> int val length : t -> int
(** Length of underlying bitvector. *) (** Size of underlying bitvector.
This is not related to the underlying implementation.
Changed at NEXT_RELEASE
*)
val capacity : t -> int val capacity : t -> int
(** The number of bits this bitvector can store without resizing. *) (** The number of bits this bitvector can store without resizing.
@since NEXT_RELEASE *)
val resize : t -> int -> unit val resize : t -> int -> unit
(** Resize the BV so that it has the specified length. This can grow or shrink (** Resize the BV so that it has the specified length. This can grow or shrink
@ -36,19 +45,19 @@ val resize : t -> int -> unit
@raise Invalid_arg on negative sizes. *) @raise Invalid_arg on negative sizes. *)
val is_empty : t -> bool val is_empty : t -> bool
(** Any bit set? *) (** Are all bits unset? Returns [true] iff no bit is set to one. *)
val set : t -> int -> unit val set : t -> int -> unit
(** Set i-th bit. *) (** Set i-th bit, extending the bitvector if needed. *)
val get : t -> int -> bool val get : t -> int -> bool
(** Is the i-th bit true? Returns false if the index is too high. *) (** Is the i-th bit true? Returns false if the index is too high. *)
val reset : t -> int -> unit val reset : t -> int -> unit
(** Set i-th bit to 0 *) (** Set i-th bit to 0, extending the bitvector if needed. *)
val flip : t -> int -> unit val flip : t -> int -> unit
(** Flip i-th bit *) (** Flip i-th bit, extending the bitvector if needed. *)
val clear : t -> unit val clear : t -> unit
(** Set every bit to 0 *) (** Set every bit to 0 *)
@ -72,16 +81,23 @@ val of_list : int list -> t
The bits are interpreted as indices into the returned bitvector, so the final The bits are interpreted as indices into the returned bitvector, so the final
bitvector will have [length t] equal to 1 more than max of list indices. *) bitvector will have [length t] equal to 1 more than max of list indices. *)
val first : t -> int val first : t -> int option
(** First set bit, or return None.
changed type at NEXT_RELEASE *)
val first_exn : t -> int
(** First set bit, or (** First set bit, or
@raise Not_found if all bits are 0 *) @raise Not_found if all bits are 0
@since NEXT_RELEASE *)
val filter : t -> (int -> bool) -> unit val filter : t -> (int -> bool) -> unit
(** [filter bv p] only keeps the true bits of [bv] whose [index] (** [filter bv p] only keeps the true bits of [bv] whose [index]
satisfies [p index] *) satisfies [p index] *)
val negate_self : t -> unit val negate_self : t -> unit
(** [negate_self t] flips all of the bits in [t]. *) (** [negate_self t] flips all of the bits in [t].
@since NEXT_RELEASE *)
val negate : t -> t val negate : t -> t
(** [negate t] returns a copy of [t] with all of the bits flipped. *) (** [negate t] returns a copy of [t] with all of the bits flipped. *)
@ -89,12 +105,12 @@ val negate : t -> t
val union_into : into:t -> t -> unit val union_into : into:t -> t -> unit
(** [union_into ~into bv] sets [into] to the union of itself and [bv]. (** [union_into ~into bv] sets [into] to the union of itself and [bv].
Note that [into] will grow to accammodate the union. *) Also updates the length of [into] to be at least [length bv]. *)
val inter_into : into:t -> t -> unit val inter_into : into:t -> t -> unit
(** [inter_into ~into bv] sets [into] to the intersection of itself and [bv]. (** [inter_into ~into bv] sets [into] to the intersection of itself and [bv].
Note that [into] will shrink to accammodate the union. *) Also updates the length of [into] to be at most [length bv]. *)
val union : t -> t -> t val union : t -> t -> t
(** [union bv1 bv2] returns the union of the two sets *) (** [union bv1 bv2] returns the union of the two sets *)
@ -103,10 +119,14 @@ val inter : t -> t -> t
(** [inter bv1 bv2] returns the intersection of the two sets *) (** [inter bv1 bv2] returns the intersection of the two sets *)
val diff_into : into:t -> t -> unit val diff_into : into:t -> t -> unit
(** [diff_into ~into t] Modify [into] to keep only the bits that are set in [into] but not in [t]. *) (** [diff_into ~into t] Modify [into] to keep only the bits that are set in [into] but not in [t].
val diff : in_:t -> t -> t @since NEXT_RELEASE *)
(** [diff ~in_ t] Return those bits found [in_] but not in [t]. *)
val diff : t -> t -> t
(** [diff t1 t2] Return those bits found in [t1] but not in [t2].
@since NEXT_RELEASE *)
val select : t -> 'a array -> 'a list val select : t -> 'a array -> 'a list
(** [select arr bv] selects the elements of [arr] whose index (** [select arr bv] selects the elements of [arr] whose index