mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-12 14:00:42 -05:00
wip: refactor(CC): use data oriented techniques (SoA)
This commit is contained in:
parent
782afa4415
commit
0efef5b6ef
2 changed files with 109 additions and 70 deletions
|
|
@ -33,6 +33,31 @@ module Make (A: CC_ARG)
|
||||||
module Term = T.Term
|
module Term = T.Term
|
||||||
module Fun = T.Fun
|
module Fun = T.Fun
|
||||||
|
|
||||||
|
(* nodes are represented as integer offsets *)
|
||||||
|
module Node0 : sig
|
||||||
|
include Int_id.S
|
||||||
|
module NVec : Vec_sig.S with type elt := t
|
||||||
|
module Set : CCSet.S with type elt = t
|
||||||
|
module Tbl : CCHashtbl.S with type key = t
|
||||||
|
end = struct
|
||||||
|
include Int_id.Make()
|
||||||
|
module NVec = VecI32
|
||||||
|
module Set = Util.Int_set
|
||||||
|
module Tbl = Util.Int_tbl
|
||||||
|
end
|
||||||
|
module NVec = Node0.NVec
|
||||||
|
|
||||||
|
type node = Node0.t
|
||||||
|
type repr = node (* a node that is representative of its class *)
|
||||||
|
|
||||||
|
(* we keep several bitvectors in the congruence closure,
|
||||||
|
each mapping nodes to a boolean.
|
||||||
|
An individual bitvector is represented as its offset in the list of
|
||||||
|
bitvectors. *)
|
||||||
|
module Bit_field : Int_id.S = Int_id.Make()
|
||||||
|
type bitfield = Bit_field.t
|
||||||
|
|
||||||
|
(* TODO: remove
|
||||||
module Bits : sig
|
module Bits : sig
|
||||||
type t = private int
|
type t = private int
|
||||||
type field
|
type field
|
||||||
|
|
@ -62,22 +87,24 @@ module Make (A: CC_ARG)
|
||||||
let merge = (lor)
|
let merge = (lor)
|
||||||
let equal : t -> t -> bool = CCEqual.poly
|
let equal : t -> t -> bool = CCEqual.poly
|
||||||
end
|
end
|
||||||
|
*)
|
||||||
|
|
||||||
(** A node of the congruence closure.
|
(* TODO: sparse vec for n_sig0? *)
|
||||||
An equivalence class is represented by its "root" element,
|
|
||||||
the representative. *)
|
(* the node store, holds data for all the nodes *)
|
||||||
type node = {
|
type node_store = {
|
||||||
n_term: term;
|
n_term: term Vec.t; (* term for the node *)
|
||||||
mutable n_sig0: signature option; (* initial signature *)
|
n_sig0: signature Vec.t; (* initial signature, if any *)
|
||||||
mutable n_bits: Bits.t; (* bitfield for various properties *)
|
n_parents: node Bag.t Vec.t; (* node -> parents(class(node)) *)
|
||||||
mutable n_parents: node Bag.t; (* parent terms of this node *)
|
n_root: NVec.t; (* node -> repr(class(node)) *)
|
||||||
mutable n_root: node; (* representative of congruence class (itself if a representative) *)
|
n_next: NVec.t; (* node -> next(class(node)) *)
|
||||||
mutable n_next: node; (* pointer to next element of congruence class *)
|
n_size: VecI32.t; (* node -> size(class(node)) *)
|
||||||
mutable n_size: int; (* size of the class *)
|
n_as_lit: lit Int_tbl.t; (* root -> literal, if any *)
|
||||||
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
|
n_expl: explanation_forest_link Vec.t; (* proof forest *)
|
||||||
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
|
n_bitfields: Bitvec.t Vec.t; (* bitfield idx -> atom -> bool *)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(* TODO: use node array for 3rd param *)
|
||||||
and signature = (fun_, node, node list) view
|
and signature = (fun_, node, node list) view
|
||||||
|
|
||||||
and explanation_forest_link =
|
and explanation_forest_link =
|
||||||
|
|
@ -97,67 +124,71 @@ module Make (A: CC_ARG)
|
||||||
| E_and of explanation * explanation
|
| E_and of explanation * explanation
|
||||||
| E_theory of explanation
|
| E_theory of explanation
|
||||||
|
|
||||||
type repr = node
|
|
||||||
|
|
||||||
module N = struct
|
module N = struct
|
||||||
type t = node
|
include Node0
|
||||||
|
type store = node_store
|
||||||
|
|
||||||
let[@inline] equal (n1:t) n2 = n1 == n2
|
let[@inline] term self n = Vec.get self.n_term (n:t:>int)
|
||||||
let[@inline] hash n = Term.hash n.n_term
|
let[@inline] pp self out n = Term.pp out (term self n)
|
||||||
let[@inline] term n = n.n_term
|
let[@inline] as_lit self n = n.n_as_lit
|
||||||
let[@inline] pp out n = Term.pp out n.n_term
|
|
||||||
let[@inline] as_lit n = n.n_as_lit
|
|
||||||
|
|
||||||
let make (t:term) : t =
|
let alloc (self:store) (t:term) : t =
|
||||||
let rec n = {
|
let {
|
||||||
n_term=t;
|
n_term; n_sig0; n_parents; n_root; n_next; n_size
|
||||||
n_sig0= None;
|
} = self in
|
||||||
n_bits=Bits.empty;
|
let n = Node0.of_int_unsafe (Vec.size n_term) in
|
||||||
n_parents=Bag.empty;
|
Vec.push n_term t;
|
||||||
n_as_lit=None; (* TODO: provide a method to do it *)
|
Vec.push n_sig0 (Opaque n); (* to be changed *)
|
||||||
n_root=n;
|
Vec.push n_parents Bag.empty;
|
||||||
n_expl=FL_none;
|
NVec.push n_root n;
|
||||||
n_next=n;
|
NVec.push n_next n;
|
||||||
n_size=1;
|
VecI32.push n_size 1;
|
||||||
} in
|
|
||||||
n
|
n
|
||||||
|
|
||||||
let[@inline] is_root (n:node) : bool = n.n_root == n
|
let[@inline] is_root (self:store) (n:node) : bool =
|
||||||
|
let n2 = NVec.get self.n_root (n:t:>int) in
|
||||||
|
equal n n2
|
||||||
|
|
||||||
(* traverse the equivalence class of [n] *)
|
(* traverse the equivalence class of [n] *)
|
||||||
let iter_class_ (n:node) : node Iter.t =
|
let iter_class_ (self:store) (n:t) : t Iter.t =
|
||||||
fun yield ->
|
fun yield ->
|
||||||
let rec aux u =
|
let rec aux u =
|
||||||
yield u;
|
yield u;
|
||||||
if u.n_next != n then aux u.n_next
|
let u2 = NVec.get self.n_next (u:t:>int) in
|
||||||
|
if not (equal n u2) then aux u2
|
||||||
in
|
in
|
||||||
aux n
|
aux n
|
||||||
|
|
||||||
let[@inline] iter_class n =
|
let[@inline] iter_class self n =
|
||||||
assert (is_root n);
|
assert (is_root self n);
|
||||||
iter_class_ n
|
iter_class_ self n
|
||||||
|
|
||||||
let[@inline] iter_parents (n:node) : node Iter.t =
|
let[@inline] iter_parents self (n:node) : node Iter.t =
|
||||||
assert (is_root n);
|
assert (is_root self n);
|
||||||
Bag.to_iter n.n_parents
|
Bag.to_iter (Vec.get self.n_parents (n:t:>int))
|
||||||
|
|
||||||
type bitfield = Bits.field
|
(* TODO: use a vec of bitvec *)
|
||||||
let[@inline] get_field f t = Bits.get f t.n_bits
|
type nonrec bitfield = bitfield
|
||||||
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
|
|
||||||
|
let[@inline] get_field (self:store) (f:bitfield) (n:t) =
|
||||||
|
let bv = Vec.get self.n_bitfields (f:>int) in
|
||||||
|
Bitvec.get bv (n:t:>int)
|
||||||
|
|
||||||
|
let[@inline] set_field (self:store) (f:bitfield) (n:t) b : unit =
|
||||||
|
let bv = Vec.get self.n_bitfields (f:>int) in
|
||||||
|
Bitvec.set bv (n:t:>int) b
|
||||||
|
|
||||||
|
(* non-recursive, inlinable function for [find] *)
|
||||||
|
let[@inline] find (self:store) (n:t) : repr =
|
||||||
|
let n2 = NVec.get self.n_root (n:t:>int) in
|
||||||
|
assert (is_root self n2);
|
||||||
|
n2
|
||||||
|
|
||||||
|
let[@inline] same_class self (n1:node) (n2:node): bool =
|
||||||
|
equal (find self n1) (find self n2)
|
||||||
end
|
end
|
||||||
|
|
||||||
module N_tbl = CCHashtbl.Make(N)
|
module N_tbl = Node0.Tbl
|
||||||
|
|
||||||
(* non-recursive, inlinable function for [find] *)
|
|
||||||
let[@inline] find_ (n:node) : repr =
|
|
||||||
let n2 = n.n_root in
|
|
||||||
assert (N.is_root n2);
|
|
||||||
n2
|
|
||||||
|
|
||||||
let[@inline] same_class (n1:node)(n2:node): bool =
|
|
||||||
N.equal (find_ n1) (find_ n2)
|
|
||||||
|
|
||||||
let[@inline] find _ n = find_ n
|
|
||||||
|
|
||||||
module Expl = struct
|
module Expl = struct
|
||||||
type t = explanation
|
type t = explanation
|
||||||
|
|
|
||||||
|
|
@ -383,6 +383,8 @@ module type CC_S = sig
|
||||||
merged, to detect conflicts and solve equations à la Shostak.
|
merged, to detect conflicts and solve equations à la Shostak.
|
||||||
*)
|
*)
|
||||||
module N : sig
|
module N : sig
|
||||||
|
type store
|
||||||
|
|
||||||
type t
|
type t
|
||||||
(** An equivalent class, containing terms that are proved
|
(** An equivalent class, containing terms that are proved
|
||||||
to be equal.
|
to be equal.
|
||||||
|
|
@ -390,7 +392,7 @@ module type CC_S = sig
|
||||||
A value of type [t] points to a particular term, but see
|
A value of type [t] points to a particular term, but see
|
||||||
{!find} to get the representative of the class. *)
|
{!find} to get the representative of the class. *)
|
||||||
|
|
||||||
val term : t -> term
|
val term : store -> t -> term
|
||||||
(** Term contained in this equivalence class.
|
(** Term contained in this equivalence class.
|
||||||
If [is_root n], then [term n] is the class' representative term. *)
|
If [is_root n], then [term n] is the class' representative term. *)
|
||||||
|
|
||||||
|
|
@ -402,19 +404,19 @@ module type CC_S = sig
|
||||||
val hash : t -> int
|
val hash : t -> int
|
||||||
(** An opaque hash of this node. *)
|
(** An opaque hash of this node. *)
|
||||||
|
|
||||||
val pp : t Fmt.printer
|
val pp : store -> t Fmt.printer
|
||||||
(** Unspecified printing of the node, for example its term,
|
(** Unspecified printing of the node, for example its term,
|
||||||
a unique ID, etc. *)
|
a unique ID, etc. *)
|
||||||
|
|
||||||
val is_root : t -> bool
|
val is_root : store -> t -> bool
|
||||||
(** Is the node a root (ie the representative of its class)?
|
(** Is the node a root (ie the representative of its class)?
|
||||||
See {!find} to get the root. *)
|
See {!find} to get the root. *)
|
||||||
|
|
||||||
val iter_class : t -> t Iter.t
|
val iter_class : store -> t -> t Iter.t
|
||||||
(** Traverse the congruence class.
|
(** Traverse the congruence class.
|
||||||
Precondition: [is_root n] (see {!find} below) *)
|
Precondition: [is_root n] (see {!find} below) *)
|
||||||
|
|
||||||
val iter_parents : t -> t Iter.t
|
val iter_parents : store -> t -> t Iter.t
|
||||||
(** Traverse the parents of the class.
|
(** Traverse the parents of the class.
|
||||||
Precondition: [is_root n] (see {!find} below) *)
|
Precondition: [is_root n] (see {!find} below) *)
|
||||||
|
|
||||||
|
|
@ -456,6 +458,9 @@ module type CC_S = sig
|
||||||
|
|
||||||
val term_store : t -> term_store
|
val term_store : t -> term_store
|
||||||
|
|
||||||
|
val n_store : t -> N.store
|
||||||
|
(** Store of nodes *)
|
||||||
|
|
||||||
val find : t -> node -> repr
|
val find : t -> node -> repr
|
||||||
(** Current representative *)
|
(** Current representative *)
|
||||||
|
|
||||||
|
|
@ -1213,13 +1218,14 @@ end = struct
|
||||||
else None
|
else None
|
||||||
|
|
||||||
let on_new_term self cc n (t:T.t) : unit =
|
let on_new_term self cc n (t:T.t) : unit =
|
||||||
Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);
|
let nstore = CC.n_store cc in
|
||||||
|
Log.debugf 50 (fun k->k "@[monoid[%s].on-new-term.try@ %a@])" M.name (N.pp nstore) n);
|
||||||
let maybe_m, l = M.of_term cc n t in
|
let maybe_m, l = M.of_term cc n t in
|
||||||
begin match maybe_m with
|
begin match maybe_m with
|
||||||
| Some v ->
|
| Some v ->
|
||||||
Log.debugf 20
|
Log.debugf 20
|
||||||
(fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])"
|
(fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])"
|
||||||
M.name N.pp n M.pp v);
|
M.name (N.pp nstore) n M.pp v);
|
||||||
SI.CC.set_bitfield cc self.field_has_value true n;
|
SI.CC.set_bitfield cc self.field_has_value true n;
|
||||||
N_tbl.add self.values n v
|
N_tbl.add self.values n v
|
||||||
| None -> ()
|
| None -> ()
|
||||||
|
|
@ -1228,25 +1234,25 @@ end = struct
|
||||||
(fun (n_u,m_u) ->
|
(fun (n_u,m_u) ->
|
||||||
Log.debugf 20
|
Log.debugf 20
|
||||||
(fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])"
|
(fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])"
|
||||||
M.name N.pp n N.pp n_u M.pp m_u);
|
M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u);
|
||||||
let n_u = CC.find cc n_u in
|
let n_u = CC.find cc n_u in
|
||||||
if CC.get_bitfield self.cc self.field_has_value n_u then (
|
if CC.get_bitfield self.cc self.field_has_value n_u then (
|
||||||
let m_u' =
|
let m_u' =
|
||||||
try N_tbl.find self.values n_u
|
try N_tbl.find self.values n_u
|
||||||
with Not_found ->
|
with Not_found ->
|
||||||
Error.errorf "node %a has bitfield but no value" N.pp n_u
|
Error.errorf "node %a has bitfield but no value" (N.pp nstore) n_u
|
||||||
in
|
in
|
||||||
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
|
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
|
||||||
| Error expl ->
|
| Error expl ->
|
||||||
Error.errorf
|
Error.errorf
|
||||||
"when merging@ @[for node %a@],@ \
|
"when merging@ @[for node %a@],@ \
|
||||||
values %a and %a:@ conflict %a"
|
values %a and %a:@ conflict %a"
|
||||||
N.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl
|
(N.pp nstore) n_u M.pp m_u M.pp m_u' CC.Expl.pp expl
|
||||||
| Ok m_u_merged ->
|
| Ok m_u_merged ->
|
||||||
Log.debugf 20
|
Log.debugf 20
|
||||||
(fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \
|
(fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \
|
||||||
:n %a@ :sub-t %a@ :value %a@])"
|
:n %a@ :sub-t %a@ :value %a@])"
|
||||||
M.name N.pp n N.pp n_u M.pp m_u_merged);
|
M.name (N.pp nstore) n (N.pp nstore) n_u M.pp m_u_merged);
|
||||||
N_tbl.add self.values n_u m_u_merged;
|
N_tbl.add self.values n_u m_u_merged;
|
||||||
) else (
|
) else (
|
||||||
(* just add to [n_u] *)
|
(* just add to [n_u] *)
|
||||||
|
|
@ -1261,12 +1267,13 @@ end = struct
|
||||||
N_tbl.to_iter self.values
|
N_tbl.to_iter self.values
|
||||||
|
|
||||||
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
|
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
|
||||||
|
let nstore = CC.n_store cc in
|
||||||
begin match get self n1, get self n2 with
|
begin match get self n1, get self n2 with
|
||||||
| Some v1, Some v2 ->
|
| Some v1, Some v2 ->
|
||||||
Log.debugf 5
|
Log.debugf 5
|
||||||
(fun k->k
|
(fun k->k
|
||||||
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ :val2 %a@])@])"
|
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ :val2 %a@])@])"
|
||||||
M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2);
|
M.name (N.pp nstore) n1 M.pp v1 (N.pp nstore) n2 M.pp v2);
|
||||||
begin match M.merge cc n1 v1 n2 v2 e_n1_n2 with
|
begin match M.merge cc n1 v1 n2 v2 e_n1_n2 with
|
||||||
| Ok v' ->
|
| Ok v' ->
|
||||||
N_tbl.remove self.values n2; (* only keep repr *)
|
N_tbl.remove self.values n2; (* only keep repr *)
|
||||||
|
|
@ -1282,7 +1289,8 @@ end = struct
|
||||||
end
|
end
|
||||||
|
|
||||||
let pp out (self:t) : unit =
|
let pp out (self:t) : unit =
|
||||||
let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" N.pp t M.pp v in
|
let nstore = CC.n_store self.cc in
|
||||||
|
let pp_e out (t,v) = Fmt.fprintf out "(@[%a@ :has %a@])" (N.pp nstore) t M.pp v in
|
||||||
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self)
|
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self)
|
||||||
|
|
||||||
let create_and_setup ?size (solver:SI.t) : t =
|
let create_and_setup ?size (solver:SI.t) : t =
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue