wip: refactor(CC): use data oriented techniques (SoA)

This commit is contained in:
Simon Cruanes 2021-08-25 09:36:40 -04:00
parent 782afa4415
commit 0efef5b6ef
No known key found for this signature in database
GPG key ID: 4AC01D0849AA62B6
2 changed files with 109 additions and 70 deletions

View file

@ -33,6 +33,31 @@ module Make (A: CC_ARG)
module Term = T.Term module Term = T.Term
module Fun = T.Fun module Fun = T.Fun
(* nodes are represented as integer offsets *)
module Node0 : sig
include Int_id.S
module NVec : Vec_sig.S with type elt := t
module Set : CCSet.S with type elt = t
module Tbl : CCHashtbl.S with type key = t
end = struct
include Int_id.Make()
module NVec = VecI32
module Set = Util.Int_set
module Tbl = Util.Int_tbl
end
module NVec = Node0.NVec
type node = Node0.t
type repr = node (* a node that is representative of its class *)
(* we keep several bitvectors in the congruence closure,
each mapping nodes to a boolean.
An individual bitvector is represented as its offset in the list of
bitvectors. *)
module Bit_field : Int_id.S = Int_id.Make()
type bitfield = Bit_field.t
(* TODO: remove
module Bits : sig module Bits : sig
type t = private int type t = private int
type field type field
@ -62,22 +87,24 @@ module Make (A: CC_ARG)
let merge = (lor) let merge = (lor)
let equal : t -> t -> bool = CCEqual.poly let equal : t -> t -> bool = CCEqual.poly
end end
*)
(** A node of the congruence closure. (* TODO: sparse vec for n_sig0? *)
An equivalence class is represented by its "root" element,
the representative. *) (* the node store, holds data for all the nodes *)
type node = { type node_store = {
n_term: term; n_term: term Vec.t; (* term for the node *)
mutable n_sig0: signature option; (* initial signature *) n_sig0: signature Vec.t; (* initial signature, if any *)
mutable n_bits: Bits.t; (* bitfield for various properties *) n_parents: node Bag.t Vec.t; (* node -> parents(class(node)) *)
mutable n_parents: node Bag.t; (* parent terms of this node *) n_root: NVec.t; (* node -> repr(class(node)) *)
mutable n_root: node; (* representative of congruence class (itself if a representative) *) n_next: NVec.t; (* node -> next(class(node)) *)
mutable n_next: node; (* pointer to next element of congruence class *) n_size: VecI32.t; (* node -> size(class(node)) *)
mutable n_size: int; (* size of the class *) n_as_lit: lit Int_tbl.t; (* root -> literal, if any *)
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) n_expl: explanation_forest_link Vec.t; (* proof forest *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) n_bitfields: Bitvec.t Vec.t; (* bitfield idx -> atom -> bool *)
} }
(* TODO: use node array for 3rd param *)
and signature = (fun_, node, node list) view and signature = (fun_, node, node list) view
and explanation_forest_link = and explanation_forest_link =
@ -97,67 +124,71 @@ module Make (A: CC_ARG)
| E_and of explanation * explanation | E_and of explanation * explanation
| E_theory of explanation | E_theory of explanation
type repr = node
module N = struct module N = struct
type t = node include Node0
type store = node_store
let[@inline] equal (n1:t) n2 = n1 == n2 let[@inline] term self n = Vec.get self.n_term (n:t:>int)
let[@inline] hash n = Term.hash n.n_term let[@inline] pp self out n = Term.pp out (term self n)
let[@inline] term n = n.n_term let[@inline] as_lit self n = n.n_as_lit
let[@inline] pp out n = Term.pp out n.n_term
let[@inline] as_lit n = n.n_as_lit
let make (t:term) : t = let alloc (self:store) (t:term) : t =
let rec n = { let {
n_term=t; n_term; n_sig0; n_parents; n_root; n_next; n_size
n_sig0= None; } = self in
n_bits=Bits.empty; let n = Node0.of_int_unsafe (Vec.size n_term) in
n_parents=Bag.empty; Vec.push n_term t;
n_as_lit=None; (* TODO: provide a method to do it *) Vec.push n_sig0 (Opaque n); (* to be changed *)
n_root=n; Vec.push n_parents Bag.empty;
n_expl=FL_none; NVec.push n_root n;
n_next=n; NVec.push n_next n;
n_size=1; VecI32.push n_size 1;
} in
n n
let[@inline] is_root (n:node) : bool = n.n_root == n let[@inline] is_root (self:store) (n:node) : bool =
let n2 = NVec.get self.n_root (n:t:>int) in
equal n n2
(* traverse the equivalence class of [n] *) (* traverse the equivalence class of [n] *)
let iter_class_ (n:node) : node Iter.t = let iter_class_ (self:store) (n:t) : t Iter.t =
fun yield -> fun yield ->
let rec aux u = let rec aux u =
yield u; yield u;
if u.n_next != n then aux u.n_next let u2 = NVec.get self.n_next (u:t:>int) in
if not (equal n u2) then aux u2
in in
aux n aux n
let[@inline] iter_class n = let[@inline] iter_class self n =
assert (is_root n); assert (is_root self n);
iter_class_ n iter_class_ self n
let[@inline] iter_parents (n:node) : node Iter.t = let[@inline] iter_parents self (n:node) : node Iter.t =
assert (is_root n); assert (is_root self n);
Bag.to_iter n.n_parents Bag.to_iter (Vec.get self.n_parents (n:t:>int))
type bitfield = Bits.field (* TODO: use a vec of bitvec *)
let[@inline] get_field f t = Bits.get f t.n_bits type nonrec bitfield = bitfield
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
let[@inline] get_field (self:store) (f:bitfield) (n:t) =
let bv = Vec.get self.n_bitfields (f:>int) in
Bitvec.get bv (n:t:>int)
let[@inline] set_field (self:store) (f:bitfield) (n:t) b : unit =
let bv = Vec.get self.n_bitfields (f:>int) in
Bitvec.set bv (n:t:>int) b
(* non-recursive, inlinable function for [find] *)
let[@inline] find (self:store) (n:t) : repr =
let n2 = NVec.get self.n_root (n:t:>int) in
assert (is_root self n2);
n2
let[@inline] same_class self (n1:node) (n2:node): bool =
equal (find self n1) (find self n2)
end end
module N_tbl = CCHashtbl.Make(N) module N_tbl = Node0.Tbl
(* non-recursive, inlinable function for [find] *)
let[@inline] find_ (n:node) : repr =
let n2 = n.n_root in
assert (N.is_root n2);
n2
let[@inline] same_class (n1:node)(n2:node): bool =
N.equal (find_ n1) (find_ n2)
let[@inline] find _ n = find_ n
module Expl = struct module Expl = struct
type t = explanation type t = explanation

View file

@ -383,6 +383,8 @@ module type CC_S = sig
merged, to detect conflicts and solve equations à la Shostak. merged, to detect conflicts and solve equations à la Shostak.
*) *)
module N : sig module N : sig
type store
type t type t
(** An equivalent class, containing terms that are proved (** An equivalent class, containing terms that are proved
to be equal. to be equal.
@ -390,7 +392,7 @@ module type CC_S = sig
A value of type [t] points to a particular term, but see A value of type [t] points to a particular term, but see
{!find} to get the representative of the class. *) {!find} to get the representative of the class. *)
val term : t -> term val term : store -> t -> term
(** Term contained in this equivalence class. (** Term contained in this equivalence class.
If [is_root n], then [term n] is the class' representative term. *) If [is_root n], then [term n] is the class' representative term. *)
@ -402,19 +404,19 @@ module type CC_S = sig
val hash : t -> int val hash : t -> int
(** An opaque hash of this node. *) (** An opaque hash of this node. *)
val pp : t Fmt.printer val pp : store -> t Fmt.printer
(** Unspecified printing of the node, for example its term, (** Unspecified printing of the node, for example its term,
a unique ID, etc. *) a unique ID, etc. *)
val is_root : t -> bool val is_root : store -> t -> bool
(** Is the node a root (ie the representative of its class)? (** Is the node a root (ie the representative of its class)?
See {!find} to get the root. *) See {!find} to get the root. *)
val iter_class : t -> t Iter.t val iter_class : store -> t -> t Iter.t
(** Traverse the congruence class. (** Traverse the congruence class.
Precondition: [is_root n] (see {!find} below) *) Precondition: [is_root n] (see {!find} below) *)
val iter_parents : t -> t Iter.t val iter_parents : store -> t -> t Iter.t
(** Traverse the parents of the class. (** Traverse the parents of the class.
Precondition: [is_root n] (see {!find} below) *) Precondition: [is_root n] (see {!find} below) *)
@ -456,6 +458,9 @@ module type CC_S = sig
val term_store : t -> term_store val term_store : t -> term_store
val n_store : t -> N.store
(** Store of nodes *)
val find : t -> node -> repr val find : t -> node -> repr
(** Current representative *) (** Current representative *)
@ -1213,13 +1218,14 @@ end = struct
else None else None
let on_new_term self cc n (t:T.t) : unit = let on_new_term self cc n (t:T.t) : unit =
Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n); let nstore = CC.n_store cc in
Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name (N.pp nstore) n);
let maybe_m, l = M.of_term cc n t in let maybe_m, l = M.of_term cc n t in
begin match maybe_m with begin match maybe_m with
| Some v -> | Some v ->
Log.debugf 20 Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])" (fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])"
M.name N.pp n M.pp v); M.name (N.pp nstore) n M.pp v);
SI.CC.set_bitfield cc self.field_has_value true n; SI.CC.set_bitfield cc self.field_has_value true n;
N_tbl.add self.values n v N_tbl.add self.values n v
| None -> () | None -> ()
@ -1228,25 +1234,25 @@ end = struct
(fun (n_u,m_u) -> (fun (n_u,m_u) ->
Log.debugf 20 Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])" (fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])"
M.name N.pp n N.pp n_u M.pp m_u); M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u);
let n_u = CC.find cc n_u in let n_u = CC.find cc n_u in
if CC.get_bitfield self.cc self.field_has_value n_u then ( if CC.get_bitfield self.cc self.field_has_value n_u then (
let m_u' = let m_u' =
try N_tbl.find self.values n_u try N_tbl.find self.values n_u
with Not_found -> with Not_found ->
Error.errorf "node %a has bitfield but no value" N.pp n_u Error.errorf "node %a has bitfield but no value" (N.pp nstore) n_u
in in
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
| Error expl -> | Error expl ->
Error.errorf Error.errorf
"when merging@ @[for node %a@],@ \ "when merging@ @[for node %a@],@ \
values %a and %a:@ conflict %a" values %a and %a:@ conflict %a"
N.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl (N.pp nstore) n_u M.pp m_u M.pp m_u' CC.Expl.pp expl
| Ok m_u_merged -> | Ok m_u_merged ->
Log.debugf 20 Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \ (fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \
:n %a@ :sub-t %a@ :value %a@])" :n %a@ :sub-t %a@ :value %a@])"
M.name N.pp n N.pp n_u M.pp m_u_merged); M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u_merged);
N_tbl.add self.values n_u m_u_merged; N_tbl.add self.values n_u m_u_merged;
) else ( ) else (
(* just add to [n_u] *) (* just add to [n_u] *)
@ -1261,12 +1267,13 @@ end = struct
N_tbl.to_iter self.values N_tbl.to_iter self.values
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit = let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
let nstore = CC.n_store cc in
begin match get self n1, get self n2 with begin match get self n1, get self n2 with
| Some v1, Some v2 -> | Some v1, Some v2 ->
Log.debugf 5 Log.debugf 5
(fun k->k (fun k->k
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ :val2 %a@])@])" "(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ :val2 %a@])@])"
M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2); M.name (N.pp nstore) n1 M.pp v1 (N.pp nstore) n2 M.pp v2);
begin match M.merge cc n1 v1 n2 v2 e_n1_n2 with begin match M.merge cc n1 v1 n2 v2 e_n1_n2 with
| Ok v' -> | Ok v' ->
N_tbl.remove self.values n2; (* only keep repr *) N_tbl.remove self.values n2; (* only keep repr *)
@ -1282,7 +1289,8 @@ end = struct
end end
let pp out (self:t) : unit = let pp out (self:t) : unit =
let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" N.pp t M.pp v in let nstore = CC.n_store self.cc in
let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" (N.pp nstore) t M.pp v in
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self) Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self)
let create_and_setup ?size (solver:SI.t) : t = let create_and_setup ?size (solver:SI.t) : t =