mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-06 03:05:28 -05:00
feat(BV): correct many bugs, clarify parts of the API
This commit is contained in:
parent
75fe196d3a
commit
60b9ece69e
2 changed files with 99 additions and 63 deletions
148
src/data/CCBV.ml
148
src/data/CCBV.ml
|
|
@ -1,5 +1,3 @@
|
|||
(** {2 Imperative Bitvectors} *)
|
||||
|
||||
let width_ = 8
|
||||
|
||||
(* Helper functions *)
|
||||
|
|
@ -52,6 +50,12 @@ let[@inline] __popcount8 (b : int) : int =
|
|||
let b = (b + (b lsr 4)) land m4 in
|
||||
b land 0x7f
|
||||
|
||||
(*
|
||||
invariants for [v:t]:
|
||||
|
||||
- [Bytes.length v.b >= bytes_length_of_size v.size] (enough storage)
|
||||
- all bits at index [size] or above are 0 in [v.b]
|
||||
*)
|
||||
type t = { mutable b: bytes; mutable size: int }
|
||||
|
||||
let length t = t.size
|
||||
|
|
@ -76,22 +80,43 @@ let create ~size default =
|
|||
in
|
||||
(* adjust last bits *)
|
||||
let r = mod_ size in
|
||||
if default && r <> 0 then
|
||||
Bytes.unsafe_set b (n - 1) (Char.unsafe_chr (__lsb_mask r));
|
||||
if default && r <> 0 then unsafe_set_ b (n - 1) (__lsb_mask r);
|
||||
{ b; size }
|
||||
)
|
||||
|
||||
let[@inline] copy bv = { bv with b = Bytes.copy bv.b }
|
||||
let copy bv = { bv with b = Bytes.sub bv.b 0 (bytes_length_of_size bv.size) }
|
||||
let[@inline] capacity bv = mul_ (Bytes.length bv.b)
|
||||
|
||||
(* call [f (mul_ i) width(byte[i]) (byte[i])] on each byte, where [mul_ i] is the offset of the byte's first bit.
|
||||
The last byte might have a width of less than 8. *)
|
||||
let iter_bytes_ (b : t) ~f : unit =
|
||||
for n = 0 to div_ b.size - 1 do
|
||||
f (mul_ n) width_ (unsafe_get_ b.b n)
|
||||
done;
|
||||
let r = mod_ b.size in
|
||||
if r <> 0 then (
|
||||
let last = div_ b.size in
|
||||
f (mul_ last) r (__lsb_mask r land unsafe_get_ b.b last)
|
||||
)
|
||||
|
||||
(* set [byte[i]] to [f(byte[i])] *)
|
||||
let map_bytes_ (b : t) ~f : unit =
|
||||
for n = 0 to div_ b.size - 1 do
|
||||
unsafe_set_ b.b n (f (unsafe_get_ b.b n))
|
||||
done;
|
||||
let r = mod_ b.size in
|
||||
if r <> 0 then (
|
||||
let last = div_ b.size in
|
||||
let mask = __lsb_mask r in
|
||||
unsafe_set_ b.b last (mask land f (mask land unsafe_get_ b.b last))
|
||||
)
|
||||
|
||||
let cardinal bv =
|
||||
if bv.size = 0 then
|
||||
0
|
||||
else (
|
||||
let n = ref 0 in
|
||||
for i = 0 to Bytes.length bv.b - 1 do
|
||||
n := !n + __popcount8 (get_ bv.b i) (* MSB of last element are all 0 *)
|
||||
done;
|
||||
iter_bytes_ bv ~f:(fun _ _ b -> n := !n + __popcount8 b);
|
||||
!n
|
||||
)
|
||||
|
||||
|
|
@ -99,10 +124,19 @@ let really_resize_ bv ~desired ~current size =
|
|||
bv.size <- size;
|
||||
if desired <> current then (
|
||||
let b = Bytes.make desired zero in
|
||||
Bytes.blit bv.b 0 b 0 current;
|
||||
Bytes.blit bv.b 0 b 0 (min desired current);
|
||||
bv.b <- b
|
||||
)
|
||||
|
||||
(* set bits at index [top] and above to 0 *)
|
||||
let[@inline never] clear_bits_above_ bv top =
|
||||
let n = div_ top in
|
||||
let j = mod_ top in
|
||||
Bytes.fill bv.b (n + 1)
|
||||
(bytes_length_of_size bv.size - n - 1)
|
||||
(Char.unsafe_chr 0);
|
||||
unsafe_set_ bv.b n (unsafe_get_ bv.b n land __lsb_mask j)
|
||||
|
||||
let[@inline never] grow_to_at_least_real_ bv size =
|
||||
(* beyond capacity *)
|
||||
let current = Bytes.length bv.b in
|
||||
|
|
@ -123,15 +157,20 @@ let grow_to_at_least_ bv size =
|
|||
grow_to_at_least_real_ bv size
|
||||
|
||||
let shrink_ bv size =
|
||||
let desired = bytes_length_of_size size in
|
||||
let current = Bytes.length bv.b in
|
||||
really_resize_ bv ~desired ~current size
|
||||
assert (size <= bv.size);
|
||||
if size < bv.size then (
|
||||
let desired = bytes_length_of_size size in
|
||||
let current = Bytes.length bv.b in
|
||||
if desired = current then clear_bits_above_ bv size;
|
||||
really_resize_ bv ~desired ~current size
|
||||
)
|
||||
|
||||
let resize bv size =
|
||||
if size < 0 then invalid_arg "resize: negative size";
|
||||
if size < bv.size then
|
||||
if size < bv.size then (
|
||||
clear_bits_above_ bv size;
|
||||
bv.size <- size
|
||||
else if size > bv.size then
|
||||
) else if size > bv.size then
|
||||
grow_to_at_least_ bv size
|
||||
|
||||
let resize_minimize_memory bv size =
|
||||
|
|
@ -197,6 +236,10 @@ let flip bv i =
|
|||
|
||||
let clear bv = Bytes.fill bv.b 0 (Bytes.length bv.b) zero
|
||||
|
||||
let clear_and_shrink bv =
|
||||
clear bv;
|
||||
bv.size <- 0
|
||||
|
||||
let equal_bytes_ size b1 b2 =
|
||||
try
|
||||
for i = 0 to bytes_length_of_size size - 1 do
|
||||
|
|
@ -208,31 +251,12 @@ let equal_bytes_ size b1 b2 =
|
|||
let equal x y : bool = x.size = y.size && equal_bytes_ x.size x.b y.b
|
||||
|
||||
let iter bv f =
|
||||
let len = bytes_length_of_size bv.size in
|
||||
assert (len <= Bytes.length bv.b);
|
||||
for n = 0 to len - 2 do
|
||||
let j = mul_ n in
|
||||
let word_n = unsafe_get_ bv.b n in
|
||||
for i = 0 to width_ - 1 do
|
||||
f (j + i) (word_n land (1 lsl i) <> 0)
|
||||
done
|
||||
done;
|
||||
if bv.size > 0 then (
|
||||
let j = mul_ (len - 1) in
|
||||
let r = mod_ bv.size in
|
||||
let final_length =
|
||||
if r = 0 then
|
||||
width_
|
||||
else
|
||||
r
|
||||
in
|
||||
let final_word = unsafe_get_ bv.b (len - 1) in
|
||||
for i = 0 to final_length - 1 do
|
||||
f (j + i) (final_word land (1 lsl i) <> 0)
|
||||
done
|
||||
)
|
||||
iter_bytes_ bv ~f:(fun off width_n word_n ->
|
||||
for i = 0 to width_n - 1 do
|
||||
f (off + i) (word_n land (1 lsl i) <> 0)
|
||||
done)
|
||||
|
||||
let[@inline] iter_true bv f =
|
||||
let iter_true bv f =
|
||||
iter bv (fun i b ->
|
||||
if b then
|
||||
f i
|
||||
|
|
@ -248,7 +272,11 @@ let to_sorted_list bv = List.rev (to_list bv)
|
|||
|
||||
(* Interpret these as indices. *)
|
||||
let of_list l =
|
||||
let size = List.fold_left max 0 l + 1 in
|
||||
let size =
|
||||
match l with
|
||||
| [] -> 0
|
||||
| _ -> List.fold_left max 0 l + 1
|
||||
in
|
||||
let bv = create ~size false in
|
||||
List.iter (fun i -> set bv i) l;
|
||||
bv
|
||||
|
|
@ -263,30 +291,16 @@ let first_exn bv =
|
|||
|
||||
let first bv = try Some (first_exn bv) with Not_found -> None
|
||||
let filter bv p = iter_true bv (fun i -> if not (p i) then reset bv i)
|
||||
|
||||
let negate_self b =
|
||||
let len = Bytes.length b.b in
|
||||
for n = 0 to len - 1 do
|
||||
unsafe_set_ b.b n (lnot (unsafe_get_ b.b n))
|
||||
done;
|
||||
let r = mod_ b.size in
|
||||
if r <> 0 then (
|
||||
let l = Bytes.length b.b - 1 in
|
||||
unsafe_set_ b.b l (__lsb_mask r land unsafe_get_ b.b l)
|
||||
)
|
||||
let negate_self bv = map_bytes_ bv ~f:(fun b -> lnot b)
|
||||
|
||||
let negate a =
|
||||
let b = Bytes.map (fun c -> Char.unsafe_chr (lnot (Char.code c))) a.b in
|
||||
let r = mod_ a.size in
|
||||
if r <> 0 then (
|
||||
let l = Bytes.length a.b - 1 in
|
||||
unsafe_set_ b l (__lsb_mask r land unsafe_get_ b l)
|
||||
);
|
||||
{ b; size = a.size }
|
||||
let b = copy a in
|
||||
negate_self b;
|
||||
b
|
||||
|
||||
let union_into_no_resize_ ~into bv =
|
||||
assert (Bytes.length into.b >= Bytes.length bv.b);
|
||||
for i = 0 to Bytes.length bv.b - 1 do
|
||||
assert (Bytes.length into.b >= bytes_length_of_size bv.size);
|
||||
for i = 0 to bytes_length_of_size bv.size - 1 do
|
||||
unsafe_set_ into.b i (unsafe_get_ into.b i lor unsafe_get_ bv.b i)
|
||||
done
|
||||
|
||||
|
|
@ -308,8 +322,8 @@ let union b1 b2 =
|
|||
)
|
||||
|
||||
let inter_into_no_resize_ ~into bv =
|
||||
assert (Bytes.length into.b <= Bytes.length bv.b);
|
||||
for i = 0 to Bytes.length into.b - 1 do
|
||||
assert (into.size <= bv.size);
|
||||
for i = 0 to bytes_length_of_size into.size - 1 do
|
||||
unsafe_set_ into.b i (unsafe_get_ into.b i land unsafe_get_ bv.b i)
|
||||
done
|
||||
|
||||
|
|
@ -395,4 +409,16 @@ module Internal_ = struct
|
|||
|
||||
let __popcount8 = __popcount8
|
||||
let __lsb_mask = __lsb_mask
|
||||
|
||||
let __check_invariant self =
|
||||
let n = div_ self.size in
|
||||
let j = mod_ self.size in
|
||||
assert (Bytes.length self.b >= n);
|
||||
if j > 0 then
|
||||
assert (
|
||||
let c = get_ self.b n in
|
||||
c land __lsb_mask j = c);
|
||||
for i = n + 1 to Bytes.length self.b - 1 do
|
||||
assert (get_ self.b i = 0)
|
||||
done
|
||||
end
|
||||
|
|
|
|||
|
|
@ -71,7 +71,11 @@ val flip : t -> int -> unit
|
|||
(** Flip i-th bit, extending the bitvector if needed. *)
|
||||
|
||||
val clear : t -> unit
|
||||
(** Set every bit to 0. *)
|
||||
(** Set every bit to 0. Does not change the length. *)
|
||||
|
||||
val clear_and_shrink : t -> unit
|
||||
(** Set every bit to 0, and set length to 0.
|
||||
@since NEXT_RELEASE *)
|
||||
|
||||
val iter : t -> (int -> bool -> unit) -> unit
|
||||
(** Iterate on all bits. *)
|
||||
|
|
@ -120,7 +124,12 @@ val union_into : into:t -> t -> unit
|
|||
|
||||
val inter_into : into:t -> t -> unit
|
||||
(** [inter_into ~into bv] sets [into] to the intersection of itself and [bv].
|
||||
Also updates the length of [into] to be at most [length bv]. *)
|
||||
Also updates the length of [into] to be at most [length bv].
|
||||
|
||||
After executing (writing [into'] for the new state of [into]):
|
||||
- [length into' = min (length into) (length bv)].
|
||||
- [for all i: get into' i ==> get into i /\ get bv i]
|
||||
*)
|
||||
|
||||
val union : t -> t -> t
|
||||
(** [union bv1 bv2] returns the union of the two sets. *)
|
||||
|
|
@ -166,6 +175,7 @@ module Internal_ : sig
|
|||
val __to_word_l : t -> char list
|
||||
val __popcount8 : int -> int
|
||||
val __lsb_mask : int -> int
|
||||
val __check_invariant : t -> unit
|
||||
end
|
||||
|
||||
(**/**)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue