feat(BV): correct many bugs, clarify parts of the API

This commit is contained in:
Simon Cruanes 2022-07-04 21:47:39 -04:00
parent 75fe196d3a
commit 60b9ece69e
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 99 additions and 63 deletions

View file

@ -1,5 +1,3 @@
(** {2 Imperative Bitvectors} *)
let width_ = 8
(* Helper functions *)
@ -52,6 +50,12 @@ let[@inline] __popcount8 (b : int) : int =
let b = (b + (b lsr 4)) land m4 in
b land 0x7f
(*
invariants for [v:t]:
- [Bytes.length v.b >= bytes_length_of_size v.size] (enough storage; note that
  [div_ v.size] bytes alone would not cover a partially used last byte)
- all bits above [size] are 0 in [v.b]
*)
(* [b] is the backing byte buffer, [size] the number of bits in use. *)
type t = { mutable b: bytes; mutable size: int }
(* Number of bits currently held by the bitvector (its [size] field). *)
let length { size; _ } = size
@ -76,22 +80,43 @@ let create ~size default =
in
(* adjust last bits *)
let r = mod_ size in
if default && r <> 0 then
Bytes.unsafe_set b (n - 1) (Char.unsafe_chr (__lsb_mask r));
if default && r <> 0 then unsafe_set_ b (n - 1) (__lsb_mask r);
{ b; size }
)
(* [copy bv] returns a fresh bitvector with the same [size] as [bv].
   Only the bytes needed to hold [bv.size] bits are duplicated, so any spare
   capacity at the end of [bv.b] is not carried over to the copy. *)
let copy bv =
  let used = bytes_length_of_size bv.size in
  { b = Bytes.sub bv.b 0 used; size = bv.size }
(* Capacity of the current buffer: [mul_] applied to the byte length of [b]
   (presumably the number of bits the buffer can hold). *)
let[@inline] capacity { b; _ } = mul_ (Bytes.length b)
(* [iter_bytes_ b ~f] calls [f offset width byte] once per byte in use, in
   increasing order. [offset] is [mul_] of the byte index; [width] is [width_]
   for every full byte. If [mod_ b.size] is non-zero, a final call is made for
   the partially used last byte with [width = mod_ b.size] and the byte masked
   by [__lsb_mask]. *)
let iter_bytes_ (b : t) ~f : unit =
  let visit_full idx = f (mul_ idx) width_ (unsafe_get_ b.b idx) in
  for idx = 0 to div_ b.size - 1 do
    visit_full idx
  done;
  (* the size is read again here, after all full bytes were visited *)
  match mod_ b.size with
  | 0 -> ()
  | rem ->
    let idx = div_ b.size in
    f (mul_ idx) rem (__lsb_mask rem land unsafe_get_ b.b idx)
(* [map_bytes_ b ~f] overwrites each byte in use with [f] of its old value.
   For a partially used last byte, both the value passed to [f] and the value
   written back are masked with [__lsb_mask (mod_ b.size)], so bits above
   [b.size] are never set by this function. *)
let map_bytes_ (b : t) ~f : unit =
  for idx = 0 to div_ b.size - 1 do
    let old = unsafe_get_ b.b idx in
    unsafe_set_ b.b idx (f old)
  done;
  match mod_ b.size with
  | 0 -> ()
  | rem ->
    let idx = div_ b.size in
    let mask = __lsb_mask rem in
    let old = mask land unsafe_get_ b.b idx in
    unsafe_set_ b.b idx (mask land f old)
(* [cardinal bv] is the number of bits set in [bv].
   Each byte in use is visited exactly once through [iter_bytes_], which also
   masks the partially used last byte, so the count does not depend on
   whatever lies in the buffer beyond [bv.size]. *)
let cardinal bv =
  if bv.size = 0 then
    0
  else (
    let count = ref 0 in
    iter_bytes_ bv ~f:(fun _ _ byte -> count := !count + __popcount8 byte);
    !count
  )
@ -99,10 +124,19 @@ let really_resize_ bv ~desired ~current size =
bv.size <- size;
if desired <> current then (
let b = Bytes.make desired zero in
Bytes.blit bv.b 0 b 0 current;
Bytes.blit bv.b 0 b 0 (min desired current);
bv.b <- b
)
(* [clear_bits_above_ bv top] zeroes the storage above position [top]:
   every byte after the one containing [top], up to the bytes used by
   [bv.size], is filled with 0, and the byte containing [top] is masked with
   [__lsb_mask (mod_ top)].
   NOTE(review): the visible callers only use it with [top < bv.size];
   outside that range the fill length may become negative -- confirm before
   calling it elsewhere. *)
let[@inline never] clear_bits_above_ bv top =
  let byte_idx = div_ top in
  let bit_idx = mod_ top in
  let first_dead = byte_idx + 1 in
  Bytes.fill bv.b first_dead
    (bytes_length_of_size bv.size - first_dead)
    '\000';
  unsafe_set_ bv.b byte_idx
    (unsafe_get_ bv.b byte_idx land __lsb_mask bit_idx)
let[@inline never] grow_to_at_least_real_ bv size =
(* beyond capacity *)
let current = Bytes.length bv.b in
@ -123,15 +157,20 @@ let grow_to_at_least_ bv size =
grow_to_at_least_real_ bv size
(* [shrink_ bv size] reduces [bv] to [size] bits and reallocates the buffer
   to the byte length required for [size]. Requires [size <= bv.size];
   nothing happens when the two are equal.

   The bits at positions above [size] are always cleared before resizing,
   not only when the buffer keeps its length: [really_resize_] copies whole
   bytes, so without this step the last kept byte could retain set bits above
   the new size and break the invariant that those bits are 0. The clearing
   must come first because [clear_bits_above_] relies on the old [bv.size]. *)
let shrink_ bv size =
  assert (size <= bv.size);
  if size < bv.size then (
    let desired = bytes_length_of_size size in
    let current = Bytes.length bv.b in
    clear_bits_above_ bv size;
    really_resize_ bv ~desired ~current size
  )
(* [resize bv size] sets the length of [bv] to [size] bits.
   - Shrinking keeps the current buffer and clears the dropped bits, so that
     no stale bit reappears if the vector grows again later.
   - Growing is delegated to [grow_to_at_least_].
   - An unchanged size is a no-op.
   @raise Invalid_argument if [size] is negative. *)
let resize bv size =
  if size < 0 then invalid_arg "resize: negative size";
  if size < bv.size then (
    clear_bits_above_ bv size;
    bv.size <- size
  ) else if size > bv.size then
    grow_to_at_least_ bv size
let resize_minimize_memory bv size =
@ -197,6 +236,10 @@ let flip bv i =
(* Fill the whole backing buffer with [zero]; the [size] field is untouched. *)
let clear { b; _ } = Bytes.fill b 0 (Bytes.length b) zero
(* Zero the whole backing buffer, then set the length to 0.
   The buffer itself is kept as is (no reallocation happens here). *)
let clear_and_shrink bv =
  Bytes.fill bv.b 0 (Bytes.length bv.b) zero;
  bv.size <- 0
let equal_bytes_ size b1 b2 =
try
for i = 0 to bytes_length_of_size size - 1 do
@ -208,31 +251,12 @@ let equal_bytes_ size b1 b2 =
(* Two bitvectors are equal when they have the same [size] and
   [equal_bytes_] accepts their buffers for that size. *)
let equal x y : bool =
  let n = x.size in
  n = y.size && equal_bytes_ n x.b y.b
(* [iter bv f] calls [f i bit] for every index [i] of [bv], in increasing
   order, where [bit] tells whether bit [i] is set.
   The per-byte traversal, including the shorter last byte, is handled by
   [iter_bytes_]; this function only unpacks each byte into its bits. *)
let iter bv f =
  iter_bytes_ bv ~f:(fun off width_n word_n ->
      for i = 0 to width_n - 1 do
        f (off + i) (word_n land (1 lsl i) <> 0)
      done)
let[@inline] iter_true bv f =
let iter_true bv f =
iter bv (fun i b ->
if b then
f i
@ -248,7 +272,11 @@ let to_sorted_list bv = List.rev (to_list bv)
(* [of_list l] builds a bitvector in which exactly the indices listed in [l]
   are set. The length is one more than the largest index, and 0 for the
   empty list (rather than 1, which the unconditional [max ... + 1] formula
   would give). *)
let of_list l =
  let size =
    match l with
    | [] -> 0
    | _ :: _ -> List.fold_left max 0 l + 1
  in
  let bv = create ~size false in
  List.iter (fun i -> set bv i) l;
  bv
@ -263,30 +291,16 @@ let first_exn bv =
(* Option-returning variant of [first_exn]: [None] when it raises [Not_found]. *)
let first bv =
  match first_exn bv with
  | i -> Some i
  | exception Not_found -> None
(* [filter bv p] resets, in place, every set bit whose index is rejected
   by [p]. *)
let filter bv p =
  let drop_if_rejected i = if p i then () else reset bv i in
  iter_true bv drop_if_rejected
(* [negate_self bv] flips every bit of [bv] in place.
   [map_bytes_] only touches the bytes in use and masks the partially used
   last byte, so bits above [bv.size] stay 0. *)
let negate_self bv = map_bytes_ bv ~f:(fun b -> lnot b)
(* [negate a] returns a new bitvector of the same length as [a] in which
   every bit is flipped; [a] itself is left unchanged. *)
let negate a =
  let b = copy a in
  negate_self b;
  b
(* [union_into_no_resize_ ~into bv] ORs the bytes used by [bv] into [into],
   in place. Only the bytes needed for [bv.size] are read, so spare capacity
   in [bv.b] is ignored. The caller must ensure the buffer of [into] is at
   least that long (checked by the assertion). *)
let union_into_no_resize_ ~into bv =
  let n = bytes_length_of_size bv.size in
  assert (Bytes.length into.b >= n);
  for i = 0 to n - 1 do
    unsafe_set_ into.b i (unsafe_get_ into.b i lor unsafe_get_ bv.b i)
  done
@ -308,8 +322,8 @@ let union b1 b2 =
)
(* [inter_into_no_resize_ ~into bv] ANDs the bytes of [bv] into the bytes
   used by [into], in place. Requires [into.size <= bv.size] (checked by the
   assertion), which is stated on logical sizes rather than buffer lengths. *)
let inter_into_no_resize_ ~into bv =
  assert (into.size <= bv.size);
  for i = 0 to bytes_length_of_size into.size - 1 do
    unsafe_set_ into.b i (unsafe_get_ into.b i land unsafe_get_ bv.b i)
  done
@ -395,4 +409,16 @@ module Internal_ = struct
let __popcount8 = __popcount8
let __lsb_mask = __lsb_mask
(* [__check_invariant self] asserts the representation invariants:
   - the buffer is long enough for [self.size] bits, counting the partially
     used last byte when [mod_ self.size > 0];
   - in that last byte, no bit outside [__lsb_mask (mod_ self.size)] is set;
   - every later byte of the buffer is 0.
   Fails with [Assert_failure] when assertions are enabled. *)
let __check_invariant self =
  let n = div_ self.size in
  let j = mod_ self.size in
  (* byte [n] is read below when [j > 0], so it must exist *)
  let needed = if j > 0 then n + 1 else n in
  assert (Bytes.length self.b >= needed);
  if j > 0 then
    assert (
      let c = get_ self.b n in
      c land __lsb_mask j = c);
  for i = n + 1 to Bytes.length self.b - 1 do
    assert (get_ self.b i = 0)
  done
end

View file

@ -71,7 +71,11 @@ val flip : t -> int -> unit
(** Flip i-th bit, extending the bitvector if needed. *)
val clear : t -> unit
(** Set every bit to 0. *)
(** Set every bit to 0. Does not change the length. *)
val clear_and_shrink : t -> unit
(** Set every bit to 0, and set length to 0.
@since NEXT_RELEASE *)
val iter : t -> (int -> bool -> unit) -> unit
(** Iterate on all bits. *)
@ -120,7 +124,12 @@ val union_into : into:t -> t -> unit
val inter_into : into:t -> t -> unit
(** [inter_into ~into bv] sets [into] to the intersection of itself and [bv].
Also updates the length of [into] to be at most [length bv]. *)
Also updates the length of [into] to be at most [length bv].
After executing, writing [into'] for the state of [into] after the call:
- [length into' = min (length into) (length bv)].
- [for all i: get into' i ==> get into i /\ get bv i]
*)
val union : t -> t -> t
(** [union bv1 bv2] returns the union of the two sets. *)
@ -166,6 +175,7 @@ module Internal_ : sig
val __to_word_l : t -> char list
val __popcount8 : int -> int
val __lsb_mask : int -> int
val __check_invariant : t -> unit
end
(**/**)