From a8e2764e9222a7e97208906e5e895f8758d64a19 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 22 Feb 2021 12:04:52 -0500 Subject: [PATCH] lra: refactor theory combination (have CC tell us what terms are subterms) --- src/arith/lra/sidekick_arith_lra.ml | 19 ++++++++++--------- src/cc/Sidekick_cc.ml | 8 +++++++- src/core/Sidekick_core.ml | 9 +++++++++ src/msat-solver/Sidekick_msat_solver.ml | 1 + 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/arith/lra/sidekick_arith_lra.ml b/src/arith/lra/sidekick_arith_lra.ml index c8772cf5..51bbdd1e 100644 --- a/src/arith/lra/sidekick_arith_lra.ml +++ b/src/arith/lra/sidekick_arith_lra.ml @@ -239,8 +239,6 @@ module Make(A : ARG) : S with module A = A = struct | None, _ -> (* non trivial linexp, give it a fresh name in the simplex *) let proxy = var_encoding_comb self ~pre:"_le" le_comb in - T.Tbl.replace self.needs_th_combination proxy (); - let new_t = match pred with | Eq | Neq -> assert false (* unreachable *) @@ -251,8 +249,6 @@ module Make(A : ARG) : S with module A = A = struct in Log.debugf 10 (fun k->k "lra.preprocess:@ %a@ :into %a" T.pp t T.pp new_t); - - T.Tbl.add self.needs_th_combination new_t (); Some new_t | Some (coeff, v), pred -> @@ -271,8 +267,6 @@ module Make(A : ARG) : S with module A = A = struct let new_t = A.mk_lra tst (LRA_simplex_pred (v, op, q)) in Log.debugf 10 (fun k->k "lra.preprocess@ :%a@ :into %a" T.pp t T.pp new_t); - - T.Tbl.add self.needs_th_combination new_t (); Some new_t end @@ -314,9 +308,7 @@ module Make(A : ARG) : S with module A = A = struct Some proxy ) - | LRA_other t when A.has_ty_real t -> - T.Tbl.replace self.needs_th_combination t (); - None + | LRA_other t when A.has_ty_real t -> None | LRA_const _ | LRA_simplex_pred _ | LRA_simplex_var _ | LRA_other _ -> None module Q_map = CCMap.Make(Q) @@ -478,6 +470,14 @@ module Make(A : ARG) : S with module A = A = struct do_th_combination self si acts model; () + (* look for subterms of type Real, for they will need theory combination *) + let on_subterm (self:state) _ (t:T.t) : unit = + if A.has_ty_real t && + not (T.Tbl.mem self.needs_th_combination t) then ( + Log.debugf 5 (fun k->k "(@[lra.needs-th-combination@ %a@])" T.pp t); + T.Tbl.add self.needs_th_combination t () + ) + let create_and_setup si = Log.debug 2 "(th-lra.setup)"; let stat = SI.stats si in @@ -486,6 +486,7 @@ module Make(A : ARG) : S with module A = A = struct SI.add_preprocess si (preproc_lra st); SI.on_final_check si (final_check_ st); SI.on_partial_check si (partial_check_ st); + SI.on_cc_is_subterm si (on_subterm st); SI.on_cc_post_merge si (fun _ _ n1 n2 -> if A.has_ty_real (N.term n1) then ( diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index a00e3871..94683b2a 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -260,6 +260,7 @@ module Make (A: CC_ARG) mutable on_new_term: ev_on_new_term list; mutable on_conflict: ev_on_conflict list; mutable on_propagate: ev_on_propagate list; + mutable on_is_subterm : ev_on_is_subterm list; mutable new_merges: bool; bitgen: Bits.bitfield_gen; field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) @@ -281,6 +282,7 @@ module Make (A: CC_ARG) and ev_on_new_term = t -> N.t -> term -> unit and ev_on_conflict = t -> th:bool -> lit list -> unit and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit + and ev_on_is_subterm = N.t -> term -> unit let[@inline] size_ (r:repr) = r.n_size let[@inline] n_true cc = Lazy.force cc.true_ @@ -535,6 +537,7 @@ module Make (A: CC_ARG) on_backtrack self (fun () -> sub_r.n_parents <- old_parents); sub_r.n_parents <- Bag.cons n sub_r.n_parents; end; + List.iter (fun f -> f sub u) self.on_is_subterm; sub in let[@inline] return x = Some x in @@ -847,9 +850,11 @@ module Make (A: CC_ARG) let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate + let on_is_subterm cc f = cc.on_is_subterm <- f :: cc.on_is_subterm let create ?(stat=Stat.global) - ?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[]) + ?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) + ?(on_conflict=[]) ?(on_propagate=[]) ?(on_is_subterm=[]) ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in @@ -865,6 +870,7 @@ module Make (A: CC_ARG) on_new_term; on_conflict; on_propagate; + on_is_subterm; pending=Vec.create(); combine=Vec.create(); undo=Backtrack_stack.create(); diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 659ac42a..434ac895 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -224,6 +224,7 @@ module type CC_S = sig type ev_on_new_term = t -> N.t -> term -> unit type ev_on_conflict = t -> th:bool -> lit list -> unit type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit + type ev_on_is_subterm = N.t -> term -> unit val create : ?stat:Stat.t -> @@ -232,6 +233,7 @@ module type CC_S = sig ?on_new_term:ev_on_new_term list -> ?on_conflict:ev_on_conflict list -> ?on_propagate:ev_on_propagate list -> + ?on_is_subterm:ev_on_is_subterm list -> ?size:[`Small | `Big] -> term_state -> t @@ -261,6 +263,9 @@ module type CC_S = sig val on_propagate : t -> ev_on_propagate -> unit (** Called when the congruence closure propagates a literal *) + val on_is_subterm : t -> ev_on_is_subterm -> unit + (** Called on terms that are subterms of function symbols *) + val set_as_lit : t -> N.t -> lit -> unit (** map the given node to a literal. *) @@ -475,6 +480,10 @@ module type SOLVER_INTERNAL = sig (** Callback to add data on terms when they are added to the congruence closure *) + val on_cc_is_subterm : t -> (CC.N.t -> term -> unit) -> unit + (** Callback for when a term is a subterm of another term in the + congruence closure *) + val on_cc_conflict : t -> (CC.t -> th:bool -> lit list -> unit) -> unit (** Callback called on every CC conflict *) diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index b8179236..5baeb0a7 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -280,6 +280,7 @@ module Make(A : ARG) let on_cc_post_merge self f = CC.on_post_merge (cc self) f let on_cc_conflict self f = CC.on_conflict (cc self) f let on_cc_propagate self f = CC.on_propagate (cc self) f + let on_cc_is_subterm self f = CC.on_is_subterm (cc self) f let cc_add_term self t = CC.add_term (cc self) t let cc_mem_term self t = CC.mem_term (cc self) t