From 0efef5b6efffa2b016f7fb5dee4bbeef88d5f29f Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 25 Aug 2021 09:36:40 -0400 Subject: [PATCH] wip: refactor(CC): use data oriented techniques (SoA) --- src/cc/Sidekick_cc.ml | 145 +++++++++++++++++++++++--------------- src/core/Sidekick_core.ml | 34 +++++---- 2 files changed, 109 insertions(+), 70 deletions(-) (* NOTE(review): this patch replaces the per-node mutable record of the congruence closure by integer node ids that index parallel vectors held in a [node_store]; the functions of [N] changed below take that store as an extra leading argument. *) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 949e1230..89ef0beb 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -33,6 +33,31 @@ module Make (A: CC_ARG) module Term = T.Term module Fun = T.Fun + (* nodes are represented as integer offsets *) + module Node0 : sig + include Int_id.S + module NVec : Vec_sig.S with type elt := t + module Set : CCSet.S with type elt = t + module Tbl : CCHashtbl.S with type key = t + end = struct + include Int_id.Make() + module NVec = VecI32 + module Set = Util.Int_set + module Tbl = Util.Int_tbl + end + module NVec = Node0.NVec + + type node = Node0.t + type repr = node (* a node that is representative of its class *) + + (* we keep several bitvectors in the congruence closure, + each mapping nodes to a boolean. + An individual bitvector is represented as its offset in the list of + bitvectors. *) + module Bit_field : Int_id.S = Int_id.Make() + type bitfield = Bit_field.t (* used as an index into [n_bitfields] by [N.get_field] and [N.set_field] *) + + (* TODO: remove module Bits : sig type t = private int type field @@ -62,22 +87,24 @@ module Make (A: CC_ARG) let merge = (lor) let equal : t -> t -> bool = CCEqual.poly end + *) - (** A node of the congruence closure. - An equivalence class is represented by its "root" element, - the representative. 
*) - type node = { - n_term: term; - mutable n_sig0: signature option; (* initial signature *) - mutable n_bits: Bits.t; (* bitfield for various properties *) - mutable n_parents: node Bag.t; (* parent terms of this node *) - mutable n_root: node; (* representative of congruence class (itself if a representative) *) - mutable n_next: node; (* pointer to next element of congruence class *) - mutable n_size: int; (* size of the class *) - mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) - mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) + (* TODO: sparse vec for n_sig0? *) + + (* the node store, holds data for all the nodes *) + type node_store = { + n_term: term Vec.t; (* term for the node *) + n_sig0: signature Vec.t; (* initial signature; [N.alloc] stores a placeholder [Opaque n] *) + n_parents: node Bag.t Vec.t; (* node -> parents(class(node)) *) + n_root: NVec.t; (* node -> repr(class(node)) *) + n_next: NVec.t; (* node -> next(class(node)) *) + n_size: VecI32.t; (* node -> size(class(node)) *) + n_as_lit: lit Int_tbl.t; (* root -> literal, if any *) + n_expl: explanation_forest_link Vec.t; (* proof forest *) + n_bitfields: Bitvec.t Vec.t; (* bitfield idx -> node -> bool *) } + (* TODO: use node array for 3rd param *) and signature = (fun_, node, node list) view and explanation_forest_link = @@ -97,67 +124,71 @@ module Make (A: CC_ARG) | E_and of explanation * explanation | E_theory of explanation - type repr = node - module N = struct - type t = node + include Node0 + type store = node_store - let[@inline] equal (n1:t) n2 = n1 == n2 - let[@inline] hash n = Term.hash n.n_term - let[@inline] term n = n.n_term - let[@inline] pp out n = Term.pp out n.n_term - let[@inline] as_lit n = n.n_as_lit + let[@inline] term self n = Vec.get self.n_term (n:t:>int) + let[@inline] pp self out n = Term.pp out (term self n) + let[@inline] as_lit self n = n.n_as_lit (* FIXME(review): as written the field access appears to make [n] a [node_store] and leaves [self] unused; presumably this should look node [n] up in [self.n_as_lit], to be confirmed *) - let make (t:term) : t = - let rec n = { - n_term=t; - n_sig0= None; - n_bits=Bits.empty; 
- n_parents=Bag.empty; - n_as_lit=None; (* TODO: provide a method to do it *) - n_root=n; - n_expl=FL_none; - n_next=n; - n_size=1; - } in + let alloc (self:store) (t:term) : t = (* append a fresh node for [t]: its id is the current size of [n_term], and it starts as its own root and next element with class size 1 *) + let { + n_term; n_sig0; n_parents; n_root; n_next; n_size + } = self in (* NOTE(review): this pattern binds six of the nine [node_store] fields and is not closed with [; _] (OCaml warning 9); nothing is pushed onto [n_expl] and no vector of [n_bitfields] is visibly extended for the new node here: confirm this happens elsewhere *) + let n = Node0.of_int_unsafe (Vec.size n_term) in + Vec.push n_term t; + Vec.push n_sig0 (Opaque n); (* to be changed *) + Vec.push n_parents Bag.empty; + NVec.push n_root n; + NVec.push n_next n; + VecI32.push n_size 1; n - let[@inline] is_root (n:node) : bool = n.n_root == n + let[@inline] is_root (self:store) (n:node) : bool = + let n2 = NVec.get self.n_root (n:t:>int) in + equal n n2 (* [n] is a root iff [n_root] maps it to itself *) (* traverse the equivalence class of [n] *) - let iter_class_ (n:node) : node Iter.t = + let iter_class_ (self:store) (n:t) : t Iter.t = fun yield -> let rec aux u = yield u; - if u.n_next != n then aux u.n_next + let u2 = NVec.get self.n_next (u:t:>int) in + if not (equal n u2) then aux u2 (* follow [n_next] until it leads back to the starting node [n] *) in aux n - let[@inline] iter_class n = - assert (is_root n); - iter_class_ n + let[@inline] iter_class self n = + assert (is_root self n); + iter_class_ self n - let[@inline] iter_parents (n:node) : node Iter.t = - assert (is_root n); - Bag.to_iter n.n_parents + let[@inline] iter_parents self (n:node) : node Iter.t = + assert (is_root self n); + Bag.to_iter (Vec.get self.n_parents (n:t:>int)) - type bitfield = Bits.field - let[@inline] get_field f t = Bits.get f t.n_bits - let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits + (* TODO: use a vec of bitvec *) + type nonrec bitfield = bitfield + + let[@inline] get_field (self:store) (f:bitfield) (n:t) = + let bv = Vec.get self.n_bitfields (f:>int) in + Bitvec.get bv (n:t:>int) (* bit of node [n] in the bitvector selected by [f] *) + + let[@inline] set_field (self:store) (f:bitfield) (n:t) b : unit = + let bv = Vec.get self.n_bitfields (f:>int) in + Bitvec.set bv (n:t:>int) b (* set the bit of node [n] in the bitvector selected by [f] to [b] *) + + (* non-recursive, inlinable function for [find] *) + let[@inline] find (self:store) (n:t) : repr = + let n2 = NVec.get self.n_root (n:t:>int) in + assert (is_root 
self n2); + n2 + + let[@inline] same_class self (n1:node) (n2:node): bool = + equal (find self n1) (find self n2) (* true iff [find] returns the same root for both nodes *) end - module N_tbl = CCHashtbl.Make(N) - - (* non-recursive, inlinable function for [find] *) - let[@inline] find_ (n:node) : repr = - let n2 = n.n_root in - assert (N.is_root n2); - n2 - - let[@inline] same_class (n1:node)(n2:node): bool = - N.equal (find_ n1) (find_ n2) - - let[@inline] find _ n = find_ n + module N_tbl = Node0.Tbl module Expl = struct type t = explanation diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index ffcc279d..4442b0c5 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -383,6 +383,8 @@ module type CC_S = sig merged, to detect conflicts and solve equations à la Shostak. *) module N : sig + type store + type t (** An equivalent class, containing terms that are proved to be equal. @@ -390,7 +392,7 @@ module type CC_S = sig A value of type [t] points to a particular term, but see {!find} to get the representative of the class. *) - val term : t -> term + val term : store -> t -> term (** Term contained in this equivalence class. If [is_root n], then [term n] is the class' representative term. *) @@ -402,19 +404,19 @@ module type CC_S = sig val hash : t -> int (** An opaque hash of this node. *) - val pp : t Fmt.printer + val pp : store -> t Fmt.printer (** Unspecified printing of the node, for example its term, a unique ID, etc. *) - val is_root : t -> bool + val is_root : store -> t -> bool (** Is the node a root (ie the representative of its class)? See {!find} to get the root. *) - val iter_class : t -> t Iter.t + val iter_class : store -> t -> t Iter.t (** Traverse the congruence class. Precondition: [is_root n] (see {!find} below) *) - val iter_parents : t -> t Iter.t + val iter_parents : store -> t -> t Iter.t (** Traverse the parents of the class. 
Precondition: [is_root n] (see {!find} below) *) @@ -456,6 +458,9 @@ module type CC_S = sig val term_store : t -> term_store + val n_store : t -> N.store + (** Store of nodes *) + val find : t -> node -> repr (** Current representative *) @@ -1213,13 +1218,14 @@ end = struct else None let on_new_term self cc n (t:T.t) : unit = - Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n); + let nstore = CC.n_store cc in (* node store of [cc]; [N.pp] takes it as its first argument *) + Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name (N.pp nstore) n); let maybe_m, l = M.of_term cc n t in begin match maybe_m with | Some v -> Log.debugf 20 (fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])" - M.name N.pp n M.pp v); + M.name (N.pp nstore) n M.pp v); SI.CC.set_bitfield cc self.field_has_value true n; N_tbl.add self.values n v | None -> () @@ -1228,25 +1234,25 @@ end = struct (fun (n_u,m_u) -> Log.debugf 20 (fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])" - M.name N.pp n N.pp n_u M.pp m_u); + M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u); let n_u = CC.find cc n_u in if CC.get_bitfield self.cc self.field_has_value n_u then ( let m_u' = try N_tbl.find self.values n_u with Not_found -> - Error.errorf "node %a has bitfield but no value" N.pp n_u + Error.errorf "node %a has bitfield but no value" (N.pp nstore) n_u in match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with | Error expl -> Error.errorf "when merging@ @[for node %a@],@ \ values %a and %a:@ conflict %a" - N.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl + (N.pp nstore) n_u M.pp m_u M.pp m_u' CC.Expl.pp expl | Ok m_u_merged -> Log.debugf 20 (fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \ :n %a@ :sub-t %a@ :value %a@])" - M.name N.pp n N.pp n_u M.pp m_u_merged); + M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u_merged); N_tbl.add self.values n_u m_u_merged; ) else ( (* just add to [n_u] *) @@ -1261,12 +1267,13 @@ end = struct N_tbl.to_iter self.values let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : 
unit = + let nstore = CC.n_store cc in begin match get self n1, get self n2 with | Some v1, Some v2 -> Log.debugf 5 (fun k->k "(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ :val2 %a@])@])" - M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2); + M.name (N.pp nstore) n1 M.pp v1 (N.pp nstore) n2 M.pp v2); begin match M.merge cc n1 v1 n2 v2 e_n1_n2 with | Ok v' -> N_tbl.remove self.values n2; (* only keep repr *) @@ -1282,7 +1289,8 @@ end = struct end let pp out (self:t) : unit = - let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" N.pp t M.pp v in + let nstore = CC.n_store self.cc in (* [pp] receives no [cc] argument, so the store is taken from [self.cc] *) + let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" (N.pp nstore) t M.pp v in Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self) let create_and_setup ?size (solver:SI.t) : t =