refactor: change signature of field access in CC

This commit is contained in:
Simon Cruanes 2021-07-04 00:25:59 -04:00
parent 51ac678ccd
commit 71360ad1f8
2 changed files with 61 additions and 14 deletions

View file

@ -298,6 +298,7 @@ module Make (A: CC_ARG)
let[@inline] on_backtrack cc f : unit =
Backtrack_stack.push_if_nonzero_level cc.undo f
let[@inline] get_bitfield _cc field n = N.get_field field n
let set_bitfield cc field b n =
let old = N.get_field field n in
if old <> b then (

View file

@ -308,8 +308,25 @@ module type CC_ARG = sig
(** View the term through the lens of the congruence closure *)
end
(** Signature of the congruence closure *)
(** Main congruence closure.
The congruence closure handles the theory QF_UF (uninterpreted
function symbols).
It is also responsible for {i theory combination}, and provides
a general framework for equality reasoning that other
theories piggyback on.
For example, the theory of datatypes relies on the congruence closure
to do most of the work, and "only" adds injectivity/disjointness/acyclicity
lemmas when needed.
Similarly, a theory of arrays would hook into the congruence closure and
assert (dis)equalities as needed.
*)
module type CC_S = sig
(** first, some aliases. *)
module T : TERM
module P : PROOF with type term = T.Term.t
module Lit : LIT with module T = T
@ -322,9 +339,13 @@ module type CC_S = sig
type actions = Actions.t
type t
(** State of the congruence closure *)
(** The congruence closure object.
It contains a fair amount of state and is mutable
and backtrackable. *)
(** An equivalence class is a set of terms that are currently equal
(** Equivalence classes.
An equivalence class is a set of terms that are currently equal
in the partial model built by the solver.
The class is represented by a collection of nodes, one of which is
distinguished and is called the "representative".
@ -349,9 +370,20 @@ module type CC_S = sig
{!find} to get the representative of the class. *)
val term : t -> term
(** Term contained in this equivalence class.
If [is_root n], then [term n] is the class' representative term. *)
val equal : t -> t -> bool
(** Are two classes {b physically} equal? To check for
logical equality, use [CC.N.equal (CC.find cc n1) (CC.find cc n2)]
which checks for equality of representatives. *)
val hash : t -> int
(** An opaque hash of this node. *)
val pp : t Fmt.printer
(** Unspecified printing of the node, for example its term,
a unique ID, etc. *)
val is_root : t -> bool
(** Is the node a root (ie the representative of its class)?
@ -369,11 +401,11 @@ module type CC_S = sig
(** A field in the bitfield of this node. This should only be
allocated when a theory is initialized.
Bitfields are accessed using preallocated keys.
See {!CC_S.allocate_bitfield}.
All fields are initially 0, are backtracked automatically,
and are merged automatically when classes are merged. *)
val get_field : bitfield -> t -> bool
(** Access the bit field *)
end
(** Explanations
@ -468,8 +500,21 @@ module type CC_S = sig
as well. *)
val allocate_bitfield : descr:string -> t -> N.bitfield
(** Allocate a new bitfield for the nodes.
See {!N.bitfield}. *)
(** Allocate a new node field (see {!N.bitfield}).
This field descriptor is henceforth reserved for all nodes
in this congruence closure, and can be set using {!set_bitfield}
for each node individually.
This can be used to efficiently store some metadata on nodes
(e.g. "is there a numeric value in the class"
or "is there a constructor term in the class").
There may be restrictions on how many distinct fields are allocated
for a given congruence closure (e.g. at most {!Sys.int_size} fields).
*)
val get_bitfield : t -> N.bitfield -> N.t -> bool
(** Access the bit field of the given node *)
val set_bitfield : t -> N.bitfield -> bool -> N.t -> unit
(** Set the bitfield for the node. This will be backtracked.
@ -1152,6 +1197,7 @@ end = struct
module Expl = SI.CC.Expl
type t = {
cc: CC.t;
values: M.t N_tbl.t; (* repr -> value for the class *)
field_has_value: N.bitfield; (* bit in CC to filter out quickly classes without value *)
}
@ -1160,12 +1206,12 @@ end = struct
let pop_levels self n = N_tbl.pop_levels self.values n
let mem self n =
let res = N.get_field self.field_has_value n in
let res = CC.get_bitfield self.cc self.field_has_value n in
assert (if res then N_tbl.mem self.values n else true);
res
let get self n =
if N.get_field self.field_has_value n
if CC.get_bitfield self.cc self.field_has_value n
then N_tbl.get self.values n
else None
@ -1187,7 +1233,7 @@ end = struct
(fun k->k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])"
M.name N.pp n N.pp n_u M.pp m_u);
let n_u = CC.find cc n_u in
if N.get_field self.field_has_value n_u then (
if CC.get_bitfield self.cc self.field_has_value n_u then (
let m_u' =
try N_tbl.find self.values n_u
with Not_found ->
@ -1243,10 +1289,10 @@ end = struct
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) (iter_all self)
let create_and_setup ?size (solver:SI.t) : t =
let cc = SI.cc solver in
let field_has_value =
SI.CC.allocate_bitfield ~descr:("monoid."^M.name^".has-value")
(SI.cc solver) in
let self = { values=N_tbl.create ?size (); field_has_value; } in
SI.CC.allocate_bitfield ~descr:("monoid."^M.name^".has-value") cc in
let self = { cc; values=N_tbl.create ?size (); field_has_value; } in
SI.on_cc_new_term solver (on_new_term self);
SI.on_cc_pre_merge solver (on_pre_merge self);
self