wip: refactor(CC): use data oriented techniques (SoA)

This commit is contained in:
Simon Cruanes 2021-08-25 09:36:40 -04:00
parent 782afa4415
commit 0efef5b6ef
No known key found for this signature in database
GPG key ID: 4AC01D0849AA62B6
2 changed files with 109 additions and 70 deletions

View file

@ -33,6 +33,31 @@ module Make (A: CC_ARG)
module Term = T.Term
module Fun = T.Fun
(* nodes are represented as integer offsets *)
module Node0 : sig
include Int_id.S
module NVec : Vec_sig.S with type elt := t
module Set : CCSet.S with type elt = t
module Tbl : CCHashtbl.S with type key = t
end = struct
include Int_id.Make()
module NVec = VecI32
module Set = Util.Int_set
module Tbl = Util.Int_tbl
end
module NVec = Node0.NVec
type node = Node0.t
type repr = node (* a node that is representative of its class *)
(* we keep several bitvectors in the congruence closure,
each mapping nodes to a boolean.
An individual bitvector is represented as its offset in the list of
bitvectors. *)
module Bit_field : Int_id.S = Int_id.Make()
type bitfield = Bit_field.t
(* TODO: remove
module Bits : sig
type t = private int
type field
@ -62,22 +87,24 @@ module Make (A: CC_ARG)
let merge = (lor)
let equal : t -> t -> bool = CCEqual.poly
end
*)
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
the representative. *)
type node = {
n_term: term;
mutable n_sig0: signature option; (* initial signature *)
mutable n_bits: Bits.t; (* bitfield for various properties *)
mutable n_parents: node Bag.t; (* parent terms of this node *)
mutable n_root: node; (* representative of congruence class (itself if a representative) *)
mutable n_next: node; (* pointer to next element of congruence class *)
mutable n_size: int; (* size of the class *)
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
(* TODO: sparse vec for n_sig0? *)
(* the node store, holds data for all the nodes *)
type node_store = {
n_term: term Vec.t; (* term for the node *)
n_sig0: signature Vec.t; (* initial signature, if any *)
n_parents: node Bag.t Vec.t; (* node -> parents(class(node)) *)
n_root: NVec.t; (* node -> repr(class(node)) *)
n_next: NVec.t; (* node -> next(class(node)) *)
n_size: VecI32.t; (* node -> size(class(node)) *)
n_as_lit: lit Int_tbl.t; (* root -> literal, if any *)
n_expl: explanation_forest_link Vec.t; (* proof forest *)
n_bitfields: Bitvec.t Vec.t; (* bitfield idx -> atom -> bool *)
}
(* TODO: use node array for 3rd param *)
and signature = (fun_, node, node list) view
and explanation_forest_link =
@ -97,67 +124,71 @@ module Make (A: CC_ARG)
| E_and of explanation * explanation
| E_theory of explanation
type repr = node
module N = struct
type t = node
include Node0
type store = node_store
let[@inline] equal (n1:t) n2 = n1 == n2
let[@inline] hash n = Term.hash n.n_term
let[@inline] term n = n.n_term
let[@inline] pp out n = Term.pp out n.n_term
let[@inline] as_lit n = n.n_as_lit
let[@inline] term self n = Vec.get self.n_term (n:t:>int)
let[@inline] pp self out n = Term.pp out (term self n)
let[@inline] as_lit self n = n.n_as_lit
let make (t:term) : t =
let rec n = {
n_term=t;
n_sig0= None;
n_bits=Bits.empty;
n_parents=Bag.empty;
n_as_lit=None; (* TODO: provide a method to do it *)
n_root=n;
n_expl=FL_none;
n_next=n;
n_size=1;
} in
let alloc (self:store) (t:term) : t =
let {
n_term; n_sig0; n_parents; n_root; n_next; n_size
} = self in
let n = Node0.of_int_unsafe (Vec.size n_term) in
Vec.push n_term t;
Vec.push n_sig0 (Opaque n); (* to be changed *)
Vec.push n_parents Bag.empty;
NVec.push n_root n;
NVec.push n_next n;
VecI32.push n_size 1;
n
let[@inline] is_root (n:node) : bool = n.n_root == n
let[@inline] is_root (self:store) (n:node) : bool =
let n2 = NVec.get self.n_root (n:t:>int) in
equal n n2
(* traverse the equivalence class of [n] *)
let iter_class_ (n:node) : node Iter.t =
let iter_class_ (self:store) (n:t) : t Iter.t =
fun yield ->
let rec aux u =
yield u;
if u.n_next != n then aux u.n_next
let u2 = NVec.get self.n_next (u:t:>int) in
if not (equal n u2) then aux u2
in
aux n
let[@inline] iter_class n =
assert (is_root n);
iter_class_ n
let[@inline] iter_class self n =
assert (is_root self n);
iter_class_ self n
let[@inline] iter_parents (n:node) : node Iter.t =
assert (is_root n);
Bag.to_iter n.n_parents
let[@inline] iter_parents self (n:node) : node Iter.t =
assert (is_root self n);
Bag.to_iter (Vec.get self.n_parents (n:t:>int))
type bitfield = Bits.field
let[@inline] get_field f t = Bits.get f t.n_bits
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
end
(* TODO: use a vec of bitvec *)
type nonrec bitfield = bitfield
module N_tbl = CCHashtbl.Make(N)
let[@inline] get_field (self:store) (f:bitfield) (n:t) =
let bv = Vec.get self.n_bitfields (f:>int) in
Bitvec.get bv (n:t:>int)
let[@inline] set_field (self:store) (f:bitfield) (n:t) b : unit =
let bv = Vec.get self.n_bitfields (f:>int) in
Bitvec.set bv (n:t:>int) b
(* non-recursive, inlinable function for [find] *)
let[@inline] find_ (n:node) : repr =
let n2 = n.n_root in
assert (N.is_root n2);
let[@inline] find (self:store) (n:t) : repr =
let n2 = NVec.get self.n_root (n:t:>int) in
assert (is_root self n2);
n2
let[@inline] same_class (n1:node)(n2:node): bool =
N.equal (find_ n1) (find_ n2)
let[@inline] same_class self (n1:node) (n2:node): bool =
equal (find self n1) (find self n2)
end
let[@inline] find _ n = find_ n
module N_tbl = Node0.Tbl
module Expl = struct
type t = explanation

View file

@ -383,6 +383,8 @@ module type CC_S = sig
merged, to detect conflicts and solve equations à la Shostak.
*)
module N : sig
type store
type t
(** An equivalent class, containing terms that are proved
to be equal.
@ -390,7 +392,7 @@ module type CC_S = sig
A value of type [t] points to a particular term, but see
{!find} to get the representative of the class. *)
val term : t -> term
val term : store -> t -> term
(** Term contained in this equivalence class.
If [is_root n], then [term n] is the class' representative term. *)
@ -402,19 +404,19 @@ module type CC_S = sig
val hash : t -> int
(** An opaque hash of this node. *)
val pp : t Fmt.printer
val pp : store -> t Fmt.printer
(** Unspecified printing of the node, for example its term,
a unique ID, etc. *)
val is_root : t -> bool
val is_root : store -> t -> bool
(** Is the node a root (ie the representative of its class)?
See {!find} to get the root. *)
val iter_class : t -> t Iter.t
val iter_class : store -> t -> t Iter.t
(** Traverse the congruence class.
Precondition: [is_root n] (see {!find} below) *)
val iter_parents : t -> t Iter.t
val iter_parents : store -> t -> t Iter.t
(** Traverse the parents of the class.
Precondition: [is_root n] (see {!find} below) *)
@ -456,6 +458,9 @@ module type CC_S = sig
val term_store : t -> term_store
val n_store : t -> N.store
(** Store of nodes *)
val find : t -> node -> repr
(** Current representative *)
@ -1213,13 +1218,14 @@ end = struct
else None
let on_new_term self cc n (t:T.t) : unit =
Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);
let nstore = CC.n_store cc in
Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name (N.pp nstore) n);
let maybe_m, l = M.of_term cc n t in
begin match maybe_m with
| Some v ->
Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])"
M.name N.pp n M.pp v);
M.name (N.pp nstore) n M.pp v);
SI.CC.set_bitfield cc self.field_has_value true n;
N_tbl.add self.values n v
| None -> ()
@ -1228,25 +1234,25 @@ end = struct
(fun (n_u,m_u) ->
Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])"
M.name N.pp n N.pp n_u M.pp m_u);
M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u);
let n_u = CC.find cc n_u in
if CC.get_bitfield self.cc self.field_has_value n_u then (
let m_u' =
try N_tbl.find self.values n_u
with Not_found ->
Error.errorf "node %a has bitfield but no value" N.pp n_u
Error.errorf "node %a has bitfield but no value" (N.pp nstore) n_u
in
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
| Error expl ->
Error.errorf
"when merging@ @[for node %a@],@ \
values %a and %a:@ conflict %a"
N.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl
(N.pp nstore) n_u M.pp m_u M.pp m_u' CC.Expl.pp expl
| Ok m_u_merged ->
Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \
:n %a@ :sub-t %a@ :value %a@])"
M.name N.pp n N.pp n_u M.pp m_u_merged);
M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u_merged);
N_tbl.add self.values n_u m_u_merged;
) else (
(* just add to [n_u] *)
@ -1261,12 +1267,13 @@ end = struct
N_tbl.to_iter self.values
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
let nstore = CC.n_store cc in
begin match get self n1, get self n2 with
| Some v1, Some v2 ->
Log.debugf 5
(fun k->k
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ :val2 %a@])@])"
M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2);
M.name (N.pp nstore) n1 M.pp v1 (N.pp nstore) n2 M.pp v2);
begin match M.merge cc n1 v1 n2 v2 e_n1_n2 with
| Ok v' ->
N_tbl.remove self.values n2; (* only keep repr *)
@ -1282,7 +1289,8 @@ end = struct
end
let pp out (self:t) : unit =
let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" N.pp t M.pp v in
let nstore = CC.n_store self.cc in
let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" (N.pp nstore) t M.pp v in
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self)
let create_and_setup ?size (solver:SI.t) : t =