lra: refactor theory combination (have CC tell us what terms are subterms)

This commit is contained in:
Simon Cruanes 2021-02-22 12:04:52 -05:00
parent a7afce3af4
commit a8e2764e92
4 changed files with 27 additions and 10 deletions

View file

@ -239,8 +239,6 @@ module Make(A : ARG) : S with module A = A = struct
| None, _ ->
(* non trivial linexp, give it a fresh name in the simplex *)
let proxy = var_encoding_comb self ~pre:"_le" le_comb in
T.Tbl.replace self.needs_th_combination proxy ();
let new_t =
match pred with
| Eq | Neq -> assert false (* unreachable *)
@ -251,8 +249,6 @@ module Make(A : ARG) : S with module A = A = struct
in
Log.debugf 10 (fun k->k "lra.preprocess:@ %a@ :into %a" T.pp t T.pp new_t);
T.Tbl.add self.needs_th_combination new_t ();
Some new_t
| Some (coeff, v), pred ->
@ -271,8 +267,6 @@ module Make(A : ARG) : S with module A = A = struct
let new_t = A.mk_lra tst (LRA_simplex_pred (v, op, q)) in
Log.debugf 10 (fun k->k "lra.preprocess@ :%a@ :into %a" T.pp t T.pp new_t);
T.Tbl.add self.needs_th_combination new_t ();
Some new_t
end
@ -314,9 +308,7 @@ module Make(A : ARG) : S with module A = A = struct
Some proxy
)
| LRA_other t when A.has_ty_real t ->
T.Tbl.replace self.needs_th_combination t ();
None
| LRA_other t when A.has_ty_real t -> None
| LRA_const _ | LRA_simplex_pred _ | LRA_simplex_var _ | LRA_other _ -> None
module Q_map = CCMap.Make(Q)
@ -478,6 +470,14 @@ module Make(A : ARG) : S with module A = A = struct
do_th_combination self si acts model;
()
(* look for subterms of type Real, for they will need theory combination *)
let on_subterm (self:state) _ (t:T.t) : unit =
if A.has_ty_real t &&
not (T.Tbl.mem self.needs_th_combination t) then (
Log.debugf 5 (fun k->k "(@[lra.needs-th-combination@ %a@])" T.pp t);
T.Tbl.add self.needs_th_combination t ()
)
let create_and_setup si =
Log.debug 2 "(th-lra.setup)";
let stat = SI.stats si in
@ -486,6 +486,7 @@ module Make(A : ARG) : S with module A = A = struct
SI.add_preprocess si (preproc_lra st);
SI.on_final_check si (final_check_ st);
SI.on_partial_check si (partial_check_ st);
SI.on_cc_is_subterm si (on_subterm st);
SI.on_cc_post_merge si
(fun _ _ n1 n2 ->
if A.has_ty_real (N.term n1) then (

View file

@ -260,6 +260,7 @@ module Make (A: CC_ARG)
mutable on_new_term: ev_on_new_term list;
mutable on_conflict: ev_on_conflict list;
mutable on_propagate: ev_on_propagate list;
mutable on_is_subterm : ev_on_is_subterm list;
mutable new_merges: bool;
bitgen: Bits.bitfield_gen;
field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *)
@ -281,6 +282,7 @@ module Make (A: CC_ARG)
and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> th:bool -> lit list -> unit
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
and ev_on_is_subterm = N.t -> term -> unit
let[@inline] size_ (r:repr) = r.n_size
let[@inline] n_true cc = Lazy.force cc.true_
@ -535,6 +537,7 @@ module Make (A: CC_ARG)
on_backtrack self (fun () -> sub_r.n_parents <- old_parents);
sub_r.n_parents <- Bag.cons n sub_r.n_parents;
end;
List.iter (fun f -> f sub u) self.on_is_subterm;
sub
in
let[@inline] return x = Some x in
@ -847,9 +850,11 @@ module Make (A: CC_ARG)
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate
let on_is_subterm cc f = cc.on_is_subterm <- f :: cc.on_is_subterm
let create ?(stat=Stat.global)
?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[])
?(on_conflict=[]) ?(on_propagate=[]) ?(on_is_subterm=[])
?(size=`Big)
(tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
@ -865,6 +870,7 @@ module Make (A: CC_ARG)
on_new_term;
on_conflict;
on_propagate;
on_is_subterm;
pending=Vec.create();
combine=Vec.create();
undo=Backtrack_stack.create();

View file

@ -224,6 +224,7 @@ module type CC_S = sig
type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> th:bool -> lit list -> unit
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
type ev_on_is_subterm = N.t -> term -> unit
val create :
?stat:Stat.t ->
@ -232,6 +233,7 @@ module type CC_S = sig
?on_new_term:ev_on_new_term list ->
?on_conflict:ev_on_conflict list ->
?on_propagate:ev_on_propagate list ->
?on_is_subterm:ev_on_is_subterm list ->
?size:[`Small | `Big] ->
term_state ->
t
@ -261,6 +263,9 @@ module type CC_S = sig
val on_propagate : t -> ev_on_propagate -> unit
(** Called when the congruence closure propagates a literal *)
val on_is_subterm : t -> ev_on_is_subterm -> unit
(** Called on terms that are subterms of function symbols *)
val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *)
@ -475,6 +480,10 @@ module type SOLVER_INTERNAL = sig
(** Callback to add data on terms when they are added to the congruence
closure *)
val on_cc_is_subterm : t -> (CC.N.t -> term -> unit) -> unit
(** Callback for when a term is a subterm of another term in the
congruence closure *)
val on_cc_conflict : t -> (CC.t -> th:bool -> lit list -> unit) -> unit
(** Callback called on every CC conflict *)

View file

@ -280,6 +280,7 @@ module Make(A : ARG)
let on_cc_post_merge self f = CC.on_post_merge (cc self) f
let on_cc_conflict self f = CC.on_conflict (cc self) f
let on_cc_propagate self f = CC.on_propagate (cc self) f
let on_cc_is_subterm self f = CC.on_is_subterm (cc self) f
let cc_add_term self t = CC.add_term (cc self) t
let cc_mem_term self t = CC.mem_term (cc self) t