diff --git a/src/core-logic/term.ml b/src/core-logic/term.ml index 689f06c7..fd2cff21 100644 --- a/src/core-logic/term.ml +++ b/src/core-logic/term.ml @@ -55,6 +55,16 @@ let unfold_app (e : term) : term * term list = in aux [] e +let[@inline] is_const e = + match e.view with + | E_const _ -> true + | _ -> false + +let[@inline] is_app e = + match e.view with + | E_app _ -> true + | _ -> false + (* debug printer *) let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit = let rec loop k ~depth names out e = diff --git a/src/core-logic/term.mli b/src/core-logic/term.mli index b9d2ac67..adcdff02 100644 --- a/src/core-logic/term.mli +++ b/src/core-logic/term.mli @@ -53,6 +53,8 @@ include WITH_SET_MAP_TBL with type t := t val view : t -> view val unfold_app : t -> t * t list +val is_app : t -> bool +val is_const : t -> bool val iter_dag : ?seen:unit Tbl.t -> iter_ty:bool -> f:(t -> unit) -> t -> unit (** [iter_dag t ~f] calls [f] once on each subterm of [t], [t] included. diff --git a/src/smt/solver_internal.ml b/src/smt/solver_internal.ml index 77ce6b6e..d7e20f3f 100644 --- a/src/smt/solver_internal.ml +++ b/src/smt/solver_internal.ml @@ -21,6 +21,7 @@ module type PREPROCESS_ACTS = sig val mk_lit : ?sign:bool -> term -> lit val add_clause : lit list -> step_id -> unit val add_lit : ?default_pol:bool -> lit -> unit + val add_term_needing_combination : term -> unit end type preprocess_actions = (module PREPROCESS_ACTS) @@ -39,10 +40,9 @@ type t = { proof: proof_trace; (** proof logger *) registry: Registry.t; on_progress: (unit, unit) Event.Emitter.t; + th_comb: Th_combination.t; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; - mutable on_th_combination: - (t -> theory_actions -> (term * value) Iter.t) list; mutable preprocess: preprocess_hook list; mutable model_ask: model_ask_hook list; mutable model_complete: model_completion_hook list; @@ -82,11 +82,11 @@ let add_simplifier (self : t) f : unit = Simplify.add_hook self.simp f let[@inline] has_delayed_actions self = not (Queue.is_empty self.delayed_actions) -let on_th_combination self f = - self.on_th_combination <- f :: self.on_th_combination - let on_preprocess self f = self.preprocess <- f :: self.preprocess +let add_term_needing_combination self t = + Th_combination.add_term_needing_combination self.th_comb t + let on_model ?ask ?complete self = Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask; Option.iter @@ -130,6 +130,9 @@ let preprocess_term_ (self : t) (t0 : term) : unit = let mk_lit ?sign t : Lit.t = Lit.atom ?sign self.tst t let add_lit ?default_pol lit : unit = delayed_add_lit self ?default_pol lit let add_clause c pr : unit = delayed_add_clause self ~keep:true c pr + + let add_term_needing_combination t = + Th_combination.add_term_needing_combination self.th_comb t end in let acts = (module A : PREPROCESS_ACTS) in @@ -397,33 +400,12 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = (* do theory combination using the congruence closure. Each theory can merge classes, *) -let check_th_combination_ (self : t) (_acts : theory_actions) lits : - (Model.t, th_combination_conflict) result = - (* FIXME - - (* enter model mode, disabling most of congruence closure *) - CC.with_model_mode cc @@ fun () -> - let set_val (t, v) : unit = - Log.debugf 50 (fun k -> - k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])" Term.pp_debug t - Term.pp_debug v); - CC.set_model_value cc t v - in - - (* obtain assignments from the hook, and communicate them to the CC *) - let add_th_values f : unit = - let vals = f self acts in - Iter.iter set_val vals - in - try - List.iter add_th_values self.on_th_combination; - CC.check cc; - let m = mk_model_ self in - Ok m - with Semantic_conflict c -> Error c - *) - let m = mk_model_ self lits in - Ok m +let check_th_combination_ (self : t) (acts : theory_actions) _lits : unit = + let lits_to_decide = Th_combination.pop_new_lits self.th_comb in + if lits_to_decide <> [] then ( + let (module A) = acts in + List.iter (fun lit -> A.add_lit ~default_pol:false lit) lits_to_decide + ) (* call congruence closure, perform the actions it scheduled *) let check_cc_with_acts_ (self : t) (acts : theory_actions) = @@ -471,40 +453,13 @@ let assert_lits_ ~final (self : t) (acts : theory_actions) (lits : Lit.t Iter.t) (* do actual theory combination if nothing changed by pure "final check" *) if not new_work then ( - match check_th_combination_ self acts lits with - | Ok m -> self.last_model <- Some m - | Error { lits; semantic } -> - (* bad model, we add a clause to remove it *) - Log.debugf 5 (fun k -> - k - "(@[solver.th-comb.conflict@ :lits (@[%a@])@ :same-val \ - (@[%a@])@])" - (Util.pp_list Lit.pp) lits - (Util.pp_list - @@ Fmt.Dump.(triple bool Term.pp_debug Term.pp_debug)) - semantic); + check_th_combination_ self acts lits; - let c1 = List.rev_map Lit.neg lits in - let c2 = - semantic - |> List.rev_map (fun (sign, t, u) -> - let eqn = Term.eq self.tst t u in - let lit = Lit.atom ~sign:(not sign) self.tst eqn in - (* make sure to consider the new lit *) - add_lit self acts lit; - lit) - in - - let c = List.rev_append c1 c2 in - let pr = - Proof_trace.add_step self.proof @@ fun () -> Proof_core.lemma_cc c - in - - Log.debugf 20 (fun k -> - k "(@[solver.th-comb.add-semantic-conflict-clause@ %a@])" - (Util.pp_list Lit.pp) c); - (* will add a delayed action *) - add_clause_temp self acts c pr + (* if theory combination didn't add new clauses, compute a model *) + if not (has_delayed_actions self) then ( + let m = mk_model_ self lits in + self.last_model <- Some m + ) ); Perform_delayed_th.top self acts @@ -585,6 +540,7 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t = stat; simp = Simplify.create tst ~proof; last_model = None; + th_comb = Th_combination.create ~stat tst; on_progress = Event.Emitter.create (); preprocess = []; model_ask = []; @@ -598,7 +554,6 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t = count_conflict = Stat.mk_int stat "smt.solver.th-conflicts"; on_partial_check = []; on_final_check = []; - on_th_combination = []; level = 0; complete = true; } diff --git a/src/smt/solver_internal.mli b/src/smt/solver_internal.mli index c9c03255..a28db9e5 100644 --- a/src/smt/solver_internal.mli +++ b/src/smt/solver_internal.mli @@ -73,6 +73,10 @@ module type PREPROCESS_ACTS = sig val add_lit : ?default_pol:bool -> lit -> unit (** Ensure the literal will be decided/handled by the SAT solver. *) + + val add_term_needing_combination : term -> unit + (** Declare this term as being a foreign variable in the theory, + which means it needs to go through theory combination. *) end type preprocess_actions = (module PREPROCESS_ACTS) @@ -98,6 +102,10 @@ val preprocess_clause_array : t -> lit array -> step_id -> lit array * step_id val simplify_and_preproc_lit : t -> lit -> lit * step_id option (** Simplify literal then preprocess it *) +val add_term_needing_combination : t -> term -> unit +(** Declare this term as being a foreign variable in the theory, + which means it needs to go through theory combination. *) + (** {3 hooks for the theory} *) val raise_conflict : t -> theory_actions -> lit list -> step_id -> 'a @@ -216,16 +224,6 @@ val on_final_check : t -> (t -> theory_actions -> lit Iter.t -> unit) -> unit is given the whole trail. *) -val on_th_combination : - t -> (t -> theory_actions -> (term * value) Iter.t) -> unit -(** Add a hook called during theory combination. - The hook must return an iterator of pairs [(t, v)] - which mean that term [t] has value [v] in the model. - - Terms with the same value (according to {!Term.equal}) will be - merged in the CC; if two terms with different values are merged, - we get a semantic conflict and must pick another model. *) - val declare_pb_is_incomplete : t -> unit (** Declare that, in some theory, the problem is outside the logic fragment that is decidable (e.g. if we meet proper NIA formulas). diff --git a/src/smt/th_combination.ml b/src/smt/th_combination.ml new file mode 100644 index 00000000..e051a0ea --- /dev/null +++ b/src/smt/th_combination.ml @@ -0,0 +1,62 @@ +open Sidekick_core +module T = Term + +type t = { + tst: Term.store; + processed: T.Set.t T.Tbl.t; (** type -> set of terms *) + unprocessed: T.t Vec.t; + new_lits: Lit.t Vec.t; + n_terms: int Stat.counter; + n_lits: int Stat.counter; +} + +let create ?(stat = Stat.global) tst : t = + { + tst; + processed = T.Tbl.create 8; + unprocessed = Vec.create (); + new_lits = Vec.create (); + n_terms = Stat.mk_int stat "smt.thcomb.terms"; + n_lits = Stat.mk_int stat "smt.thcomb.intf-lits"; + } + +let processed_ (self : t) t : bool = + let ty = T.ty t in + match T.Tbl.find_opt self.processed ty with + | None -> false + | Some set -> T.Set.mem t set + +let add_term_needing_combination (self : t) (t : T.t) : unit = + if not (processed_ self t) then ( + Log.debugf 50 (fun k -> k "(@[th.comb.add-term-needing-comb@ %a@])" T.pp t); + Vec.push self.unprocessed t + ) + +let pop_new_lits (self : t) : Lit.t list = + (* first, process new terms, if any *) + while not (Vec.is_empty self.unprocessed) do + let t = Vec.pop_exn self.unprocessed in + let ty = T.ty t in + let set_for_ty = + try T.Tbl.find self.processed ty with Not_found -> T.Set.empty + in + if not (T.Set.mem t set_for_ty) then ( + Stat.incr self.n_terms; + + (* now create [t=u] for each [u] in [set_for_ty] *) + T.Set.iter + (fun u -> + let lit = Lit.make_eq self.tst t u in + Stat.incr self.n_lits; + Vec.push self.new_lits lit) + set_for_ty; + + (* add [t] to the set of processed terms *) + let new_set_for_ty = T.Set.add t set_for_ty in + T.Tbl.replace self.processed ty new_set_for_ty + ) + done; + + let lits = Vec.to_list self.new_lits in + Vec.clear self.new_lits; + lits diff --git a/src/smt/th_combination.mli b/src/smt/th_combination.mli new file mode 100644 index 00000000..5a782e3e --- /dev/null +++ b/src/smt/th_combination.mli @@ -0,0 +1,17 @@ +(** Delayed Theory Combination *) + +open Sidekick_core + +type t + +val create : ?stat:Stat.t -> Term.store -> t + +val add_term_needing_combination : t -> Term.t -> unit +(** [add_term_needing_combination self t] means that [t] occurs as a foreign + variable in another term, so it is important that its theory, and the + theory in which it occurs, agree on it being equal to other + foreign terms. *) + +val pop_new_lits : t -> Lit.t list +(** Get the new literals that the solver needs to decide, so that the + SMT solver gives each theory the same partition of interface equalities. *) diff --git a/src/th-lra/sidekick_th_lra.ml b/src/th-lra/sidekick_th_lra.ml index 918466d0..f9b4b3d6 100644 --- a/src/th-lra/sidekick_th_lra.ml +++ b/src/th-lra/sidekick_th_lra.ml @@ -130,8 +130,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct in_model: unit Term.Tbl.t; (* terms to add to model *) encoded_eqs: unit Term.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) - needs_th_combination: unit Term.Tbl.t; - (* terms that require theory combination *) simp_preds: (Term.t * S_op.t * A.Q.t) Term.Tbl.t; (* term -> its simplex meaning *) simp_defined: LE.t Term.Tbl.t; @@ -157,7 +155,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct simp_preds = Term.Tbl.create 32; simp_defined = Term.Tbl.create 16; encoded_eqs = Term.Tbl.create 8; - needs_th_combination = Term.Tbl.create 8; encoded_le = Comb_map.empty; simplex = SimpSolver.create ~stat (); last_res = None; @@ -275,6 +272,11 @@ module Make (A : ARG) = (* : S with module A = A *) struct | Geq -> S_op.Geq | Gt -> S_op.Gt + (* add [t] to the theory combination system if it's not just a constant + of type Real *) + let add_lra_var_to_th_combination (si : SI.t) (t : term) : unit = + if not (Term.is_const t) then SI.add_term_needing_combination si t + (* TODO: refactor that and {!var_encoding_comb} *) (* turn a linear expression into a single constant and a coeff. This might define a side variable in the simplex. *) @@ -300,17 +302,20 @@ module Make (A : ARG) = (* : S with module A = A *) struct proxy, A.Q.one) (* look for subterms of type Real, for they will need theory combination *) - let on_subterm (self : state) (t : Term.t) : unit = + let on_subterm (_self : state) (si : SI.t) (t : Term.t) : unit = Log.debugf 50 (fun k -> k "(@[lra.cc-on-subterm@ %a@])" Term.pp_debug t); match A.view_as_lra t with - | LRA_other _ when not (A.has_ty_real t) -> () + | LRA_other _ when not (A.has_ty_real t) -> + (* for a non-LRA term [f args], if any of [args] is in LRA, + it needs theory combination *) + let _, args = Term.unfold_app t in + List.iter + (fun arg -> + if A.has_ty_real arg then SI.add_term_needing_combination si arg) + args | LRA_pred _ | LRA_const _ -> () | LRA_op _ | LRA_other _ | LRA_mult _ -> - if not (Term.Tbl.mem self.needs_th_combination t) then ( - Log.debugf 5 (fun k -> - k "(@[lra.needs-th-combination@ %a@])" Term.pp_debug t); - Term.Tbl.add self.needs_th_combination t () - ) + SI.add_term_needing_combination si t (* preprocess linear expressions away *) let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS) @@ -323,7 +328,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct Log.debugf 50 (fun k -> k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t); ignore (CC.add_term (SI.cc si) t : E_node.t); - if sub then on_subterm self t + if sub then on_subterm self si t in match A.view_as_lra t with @@ -369,7 +374,11 @@ module Make (A : ARG) = (* : S with module A = A *) struct (* obtain a single variable for the linear combination *) let v, c_v = le_comb_to_singleton_ self le_comb in declare_term_to_cc ~sub:false v; - LE_.Comb.iter (fun v _ -> declare_term_to_cc ~sub:true v) le_comb; + LE_.Comb.iter + (fun v _ -> + declare_term_to_cc ~sub:true v; + add_lra_var_to_th_combination si v) + le_comb; (* turn into simplex constraint. For example, [c . v <= const] becomes a direct simplex constraint [v <= const/c] @@ -568,41 +577,42 @@ module Make (A : ARG) = (* : S with module A = A *) struct (* evaluate a linear expression *) let eval_le_in_subst_ subst (le : LE.t) = LE.eval (eval_in_subst_ subst) le - (* FIXME: rename, this is more "provide_model_to_cc" *) - let do_th_combination (self : state) _si _acts : _ Iter.t = - Log.debug 1 "(lra.do-th-combinations)"; - let model = - match self.last_res with - | Some (SimpSolver.Sat m) -> m - | _ -> assert false - in + (* FIXME: rework into model creation + let do_th_combination (self : state) _si _acts : _ Iter.t = + Log.debug 1 "(lra.do-th-combinations)"; + let model = + match self.last_res with + | Some (SimpSolver.Sat m) -> m + | _ -> assert false + in - let vals = Subst.to_iter model |> Term.Tbl.of_iter in + let vals = Subst.to_iter model |> Term.Tbl.of_iter in - (* also include terms that occur under function symbols, if they're - not in the model already *) - Term.Tbl.iter - (fun t () -> - if not (Term.Tbl.mem vals t) then ( - let v = eval_in_subst_ model t in - Term.Tbl.add vals t v - )) - self.needs_th_combination; + (* also include terms that occur under function symbols, if they're + not in the model already *) + Term.Tbl.iter + (fun t () -> + if not (Term.Tbl.mem vals t) then ( + let v = eval_in_subst_ model t in + Term.Tbl.add vals t v + )) + self.needs_th_combination; - (* also consider subterms that are linear expressions, - and evaluate them using the value of each variable - in that linear expression. For example a term [a + 2b] - is evaluated as [eval(a) + 2 × eval(b)]. *) - Term.Tbl.iter - (fun t le -> - if not (Term.Tbl.mem vals t) then ( - let v = eval_le_in_subst_ model le in - Term.Tbl.add vals t v - )) - self.simp_defined; + (* also consider subterms that are linear expressions, + and evaluate them using the value of each variable + in that linear expression. For example a term [a + 2b] + is evaluated as [eval(a) + 2 × eval(b)]. *) + Term.Tbl.iter + (fun t le -> + if not (Term.Tbl.mem vals t) then ( + let v = eval_le_in_subst_ model le in + Term.Tbl.add vals t v + )) + self.simp_defined; - (* return whole model *) - Term.Tbl.to_iter vals |> Iter.map (fun (t, v) -> t, t_const self v) + (* return whole model *) + Term.Tbl.to_iter vals |> Iter.map (fun (t, v) -> t, t_const self v) + *) (* partial checks is where we add literals from the trail to the simplex. *) @@ -714,7 +724,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct SI.on_partial_check si (partial_check_ st); SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st); SI.on_cc_is_subterm si (fun (_, _, t) -> - on_subterm st t; + on_subterm st si t; []); SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) -> match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with @@ -725,7 +735,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct E_node.pp n2); Error (CC.Handler_action.Conflict expl) | _ -> Ok []); - SI.on_th_combination si (do_th_combination st); st let theory =