perf: restore previous handling of bitfields

This commit is contained in:
Simon Cruanes 2021-08-26 00:10:43 -04:00
parent 4fd291b117
commit 4e0eae6d9e
No known key found for this signature in database
GPG key ID: 4AC01D0849AA62B6

View file

@ -50,12 +50,37 @@ module Make (A: CC_ARG)
type node = Node0.t
type repr = node (* a node that is representative of its class *)
(* we keep several bitvectors in the congruence closure,
each mapping nodes to a boolean.
An individual bitvector is represented as its offset in the list of
bitvectors. *)
module Bit_field : Int_id.S = Int_id.Make()
type bitfield = Bit_field.t
(* we associate bitfields to each node in the congruence closure.
They're packed into an integer. *)
module Bits : sig
type t = private int
type field
type bitfield_gen
val empty : t
val equal : t -> t -> bool
val mk_field : bitfield_gen -> field
val mk_gen : unit -> bitfield_gen
val get : field -> t -> bool
val set : field -> t -> bool -> t
val merge : t -> t -> t
end = struct
type bitfield_gen = int ref
let max_width = Sys.word_size - 2
let mk_gen() = ref 0
type t = int
type field = int
let empty : t = 0
let mk_field (gen:bitfield_gen) : field =
let n = !gen in
if n > max_width then Error.errorf "maximum number of CC bitfields reached";
incr gen;
1 lsl n
let[@inline] get field x = (x land field) <> 0
let[@inline] set field x b =
if b then x lor field else x land (lnot field)
let merge = (lor)
let equal : t -> t -> bool = CCEqual.poly
end
(* TODO: sparse vec for n_sig0? *)
(* the node store, holds data for all the nodes *)
@ -66,9 +91,9 @@ module Make (A: CC_ARG)
n_root: NVec.t; (* node -> repr(class(node)) *)
n_next: NVec.t; (* node -> next(class(node)) *)
n_size: VecI32.t; (* node -> size(class(node)) *)
n_as_lit: lit Node0.Tbl.t; (* root -> literal, if any *)
n_as_lit: lit option Vec.t; (* root -> literal, if any *)
n_expl: explanation_forest_link Vec.t; (* proof forest *)
n_bitfields: Bitvec.t Vec.t; (* bitfield idx -> atom -> bool *)
n_bitfields: Bits.t Vec.t; (* node -> bitfields *)
}
(* TODO: use node array for 3rd param *)
@ -112,14 +137,14 @@ module Make (A: CC_ARG)
let[@inline] set_parents self n b = Vec.set self.n_parents (n:t:>int) b
let[@inline] upd_parents ~f self n = set_parents self n (f (parents self n))
let[@inline] as_lit self n = Tbl.get self.n_as_lit n
let[@inline] set_as_lit self n lit = Tbl.replace self.n_as_lit n lit
let[@inline] clear_as_lit self n = Tbl.remove self.n_as_lit n
let[@inline] as_lit self n = Vec.get self.n_as_lit (n:t:>int)
let[@inline] set_as_lit self n lit = Vec.set self.n_as_lit (n:t:>int) (Some lit)
let[@inline] clear_as_lit self n = Vec.set self.n_as_lit (n:t:>int) None
let alloc (self:store) (t:term) : t =
let {
n_term; n_sig0; n_parents; n_root; n_next; n_size; n_expl;
n_as_lit=_; n_bitfields;
n_as_lit; n_bitfields;
} = self in
let n = Node0.of_int_unsafe (Vec.size n_term) in
Vec.push n_term t;
@ -129,7 +154,8 @@ module Make (A: CC_ARG)
NVec.push n_root n;
NVec.push n_next n;
VecI32.push n_size 1;
Vec.iter (fun bv -> Bitvec.ensure_size bv ((n:t:>int)+1)) n_bitfields;
Vec.push n_as_lit None;
Vec.push n_bitfields Bits.empty;
assert (term self n == t);
n
@ -138,7 +164,7 @@ module Make (A: CC_ARG)
assert ((n:>int) + 1 = Vec.size self.n_term);
let {
n_term; n_sig0; n_parents; n_root; n_next; n_size; n_expl;
n_as_lit=_; n_bitfields=_;
n_as_lit; n_bitfields;
} = self in
ignore (Vec.pop_exn n_term : term);
ignore (Vec.pop_exn n_sig0 : signature);
@ -147,6 +173,8 @@ module Make (A: CC_ARG)
ignore (NVec.pop n_next : t);
ignore (Vec.pop_exn n_expl : explanation_forest_link);
ignore (VecI32.pop n_size : int);
ignore (Vec.pop_exn n_as_lit : _ option);
ignore (Vec.pop_exn n_bitfields : Bits.t);
()
let[@inline] is_root (self:store) (n:node) : bool =
@ -171,15 +199,7 @@ module Make (A: CC_ARG)
assert (is_root self n);
Bag.to_iter (Vec.get self.n_parents (n:t:>int))
type nonrec bitfield = bitfield
let alloc_bitfield ~descr (self:store) : bitfield =
Log.debugf 5 (fun k->k "(@[cc.allocate-bit-field@ :descr %s@])" descr);
let field = Bit_field.of_int_unsafe (Vec.size self.n_bitfields) in
let bv = Bitvec.create() in
Bitvec.ensure_size bv (Vec.size self.n_term);
Vec.push self.n_bitfields bv;
field
type bitfield = Bits.field
let create () : store = {
n_term=Vec.create ();
@ -189,17 +209,16 @@ module Make (A: CC_ARG)
n_parents=Vec.create ();
n_size=VecI32.create ~cap:1024 ();
n_expl=Vec.create ();
n_as_lit=Tbl.create 256;
n_as_lit=Vec.create ();
n_bitfields=Vec.create();
}
let[@inline] get_field (self:store) (f:bitfield) (n:t) =
let bv = Vec.get self.n_bitfields (f:>int) in
Bitvec.get bv (n:t:>int)
let[@inline] set_field (self:store) (f:bitfield) (n:t) b : unit =
let bv = Vec.get self.n_bitfields (f:>int) in
Bitvec.set bv (n:t:>int) b
let[@inline] bitfields self n = Vec.get self.n_bitfields (n:t:>int)
let[@inline] set_bitfields self n f = Vec.set self.n_bitfields (n:t:>int) f
let[@inline] get_field self f n = Bits.get f (Vec.get self.n_bitfields (n:t:>int))
let[@inline] set_field self f n b =
let cur_v = Vec.get self.n_bitfields (n:t:>int) in
Vec.set self.n_bitfields (n:t:>int) (Bits.set f cur_v b)
(* non-recursive, inlinable function for [find] *)
let[@inline] find (self:store) (n:t) : repr =
@ -324,6 +343,7 @@ module Make (A: CC_ARG)
mutable on_propagate: ev_on_propagate list;
mutable on_is_subterm : ev_on_is_subterm list;
mutable new_merges: bool; (* true if >=1 class was modified since last check *)
bitgen: Bits.bitfield_gen;
field_marked_explain: N.bitfield; (* used to mark traversed nodes when looking for a common ancestor *)
true_ : node lazy_t;
false_ : node lazy_t;
@ -351,14 +371,9 @@ module Make (A: CC_ARG)
let[@inline] term_store (cc:t) = cc.tst
let[@inline] n_store (cc:t) = cc.nstore
(* new bitfield *)
let allocate_bitfield ~descr self = N.alloc_bitfield ~descr self.nstore
(* iterate on existing bitfields *)
let[@inline] iter_bitfields (self:t) ~(f:Bit_field.t -> unit) : unit =
for i=0 to Vec.size self.nstore.n_bitfields - 1 do
f (Bit_field.of_int_unsafe i)
done
let allocate_bitfield ~descr self =
Log.debugf 5 (fun k->k "(@[cc.allocate-bit-field@ :descr %s@])" descr);
Bits.mk_field self.bitgen
let[@inline] on_backtrack cc f : unit =
Backtrack_stack.push_if_nonzero_level cc.undo f
@ -807,24 +822,14 @@ module Make (A: CC_ARG)
let r_into_old_next = N.next nstore r_into in
let r_from_old_next = N.next nstore r_from in
let r_into_old_parents = N.parents nstore r_into in
let r_into_old_bits = N.bitfields nstore r_into in
(* swap [into.next] and [from.next], merging the classes *)
N.set_next nstore r_into r_from_old_next;
N.set_next nstore r_from r_into_old_next;
N.upd_parents nstore r_into ~f:(fun p -> Bag.append p (N.parents nstore r_from));
N.set_size nstore r_into (N.size nstore r_into + N.size nstore r_from);
(* merge bitfields, and backtrack changes *)
iter_bitfields self ~f:(fun field ->
let b_into = N.get_field nstore field r_into in
if not b_into then (
let b_from = N.get_field nstore field r_from in
if b_from then (
(* we modify the field of [r_into], remember to undo that *)
on_backtrack self (fun () -> N.set_field nstore field r_into false);
N.set_field nstore field r_into true;
);
));
N.set_bitfields nstore r_into
(Bits.merge (N.bitfields nstore r_into) (N.bitfields nstore r_from));
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
on_backtrack self
@ -835,6 +840,7 @@ module Make (A: CC_ARG)
N.set_next nstore r_into r_into_old_next;
N.set_next nstore r_from r_from_old_next;
N.set_parents nstore r_into r_into_old_parents;
N.set_bitfields nstore r_into r_into_old_bits;
(* NOTE: this must come after the restoration of [next] pointers,
otherwise we'd iterate on too big a class *)
N.iter_class_ nstore r_from (fun u -> N.set_root nstore u r_from);
@ -996,14 +1002,15 @@ module Make (A: CC_ARG)
?(size=`Big)
(tst:term_store) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let bitgen = Bits.mk_gen () in
let nstore = N.create() in
let field_marked_explain = N.alloc_bitfield ~descr:"mark-explain" nstore in
assert ((field_marked_explain :> int) = 0);
let field_marked_explain = Bits.mk_field bitgen in
let rec cc = {
tst;
nstore;
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
bitgen;
on_pre_merge;
on_post_merge;
on_new_term;