From 95f84b4854001a998b5c23947408270013e1ed63 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 17 Feb 2022 16:36:07 -0500 Subject: [PATCH] refactor(th-comb): provide full model to the CC this way it can fail on merges of classes assigned conflicting value. --- src/cc/Sidekick_cc.ml | 144 ++++++++++++++++++++++---- src/core/Sidekick_core.ml | 30 +++--- src/lra/sidekick_arith_lra.ml | 142 +++++++++---------------- src/simplex/linear_expr.ml | 4 +- src/simplex/linear_expr_intf.ml | 2 +- src/smt-solver/Sidekick_smt_solver.ml | 50 ++++----- 6 files changed, 217 insertions(+), 155 deletions(-) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 65b030a6..a1942a1b 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -24,6 +24,7 @@ module Make (A: CC_ARG) module Actions = A.Actions module P = Actions.P type term = T.Term.t + type value = term type term_store = T.Term.store type lit = Lit.t type fun_ = T.Fun.t @@ -267,12 +268,15 @@ module Make (A: CC_ARG) type combine_task = | CT_merge of node * node * explanation + | CT_set_val of node * value type t = { tst: term_store; - tbl: node T_tbl.t; proof: proof; + + tbl: node T_tbl.t; (* internalization [term -> node] *) + signatures_tbl : node Sig_tbl.t; (* map a signature to the corresponding node in some equivalence class. A signature is a [term_cell] in which every immediate subterm @@ -281,9 +285,21 @@ module Make (A: CC_ARG) The critical property is that all members of an equivalence class that have the same "shape" (including head symbol) have the same signature *) + pending: node Vec.t; combine: combine_task Vec.t; + + t_to_val: (node*value) T_tbl.t; + (* [repr -> (t,val)] where [repr = t] + and [t := val] in the model *) + val_to_t: node T_tbl.t; (* [val -> t] where [t := val] in the model *) + undo: (unit -> unit) Backtrack_stack.t; + bitgen: Bits.bitfield_gen; + field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) + true_ : node lazy_t; + false_ : node lazy_t; + mutable on_pre_merge: ev_on_pre_merge list; mutable on_post_merge: ev_on_post_merge list; mutable on_new_term: ev_on_new_term list; @@ -291,10 +307,6 @@ module Make (A: CC_ARG) mutable on_propagate: ev_on_propagate list; mutable on_is_subterm : ev_on_is_subterm list; mutable new_merges: bool; - bitgen: Bits.bitfield_gen; - field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) - true_ : node lazy_t; - false_ : node lazy_t; stat: Stat.t; count_conflict: int Stat.counter; count_props: int Stat.counter; @@ -622,7 +634,7 @@ module Make (A: CC_ARG) (* remove term when we backtrack *) on_backtrack cc (fun () -> - Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t); + Log.debugf 30 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t); T_tbl.remove cc.tbl t); (* add term to the table *) T_tbl.add cc.tbl t n; @@ -755,6 +767,51 @@ module Make (A: CC_ARG) cc.new_merges <- true; task_merge_ cc acts a b e_ab + | CT_set_val (n, v) -> + task_set_val_ cc acts n v + + and task_set_val_ cc acts n v = + let repr_n = find_ n in + (* - if repr(n) has value [v], do nothing + - else if repr(n) has value [v'], semantic conflict + - else add [repr(n) -> (n,v)] to cc.t_to_val *) + begin match T_tbl.find_opt cc.t_to_val repr_n.n_term with + | Some (n', v') when not (Term.equal v v') -> + (* semantic conflict *) + let expl = [Expl.mk_merge n n'] in + let expl_st = explain_expls cc expl in + let lits = expl_st.lits in + let tuples = + List.rev_map (fun (t,u) -> true, t.n_term, u.n_term) expl_st.same_val + in + let tuples = (false, n.n_term, n'.n_term) :: tuples in + Log.debugf 20 + (fun k->k "(@[cc.semantic-conflict.set-val@ (@[set-val %a@ := %a@])@ \ + (@[existing-val %a@ := %a@])@])" + N.pp n Term.pp v N.pp n' Term.pp v'); + + Stat.incr cc.count_semantic_conflict; + Actions.raise_semantic_conflict acts lits tuples + + | Some _ -> () + | None -> + T_tbl.add cc.t_to_val repr_n.n_term (n, v); + on_backtrack cc (fun () -> T_tbl.remove cc.t_to_val repr_n.n_term); + end; + (* now for the reverse map, look in self.val_to_t for [v]. + - if present, push a merge command with Expl.mk_same_value + - if not, add [v -> n] *) + begin match T_tbl.find_opt cc.val_to_t v with + | None -> + T_tbl.add cc.val_to_t v n; + on_backtrack cc (fun () -> T_tbl.remove cc.val_to_t v); + + | Some n' when not (same_class n n') -> + merge_classes cc n n' (Expl.mk_same_value n n') + + | Some _ -> () + end + (* main CC algo: merge equivalence classes in [st.combine]. @raise Exn_unsat if merge fails *) and task_merge_ cc acts a b e_ab : unit = @@ -787,7 +844,7 @@ module Make (A: CC_ARG) let lits = expl_st.lits in let same_val = expl_st.same_val - |> List.rev_map (fun (t,u) -> N.term t, N.term u) in + |> List.rev_map (fun (t,u) -> true, N.term t, N.term u) in assert (same_val <> []); Stat.incr cc.count_semantic_conflict; Actions.raise_semantic_conflict acts lits same_val @@ -817,14 +874,17 @@ module Make (A: CC_ARG) in merge_bool ra a rb b; merge_bool rb b ra a; + (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); + (* call [on_pre_merge] functions, and merge theory data items *) begin (* explanation is [a=ra & e_ab & b=rb] *) let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge; end; + begin (* parents might have a different signature, check for collisions *) N.iter_parents r_from @@ -848,8 +908,8 @@ module Make (A: CC_ARG) (* on backtrack, unmerge classes and restore the pointers to [r_from] *) on_backtrack cc (fun () -> - Log.debugf 15 - (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" + Log.debugf 30 + (fun k->k "(@[cc.undo_merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); r_into.n_bits <- r_into_old_bits; r_into.n_next <- r_into_old_next; @@ -861,6 +921,42 @@ module Make (A: CC_ARG) r_into.n_size <- r_into.n_size - r_from.n_size; ); end; + + (* check for semantic values, update the one of [r_into] + if [r_from] has a value *) + begin match T_tbl.find_opt cc.t_to_val r_from.n_term with + | None -> () + | Some (n_from, v_from) -> + begin match T_tbl.find_opt cc.t_to_val r_into.n_term with + | None -> + T_tbl.add cc.t_to_val r_into.n_term (n_from,v_from); + on_backtrack cc (fun () -> T_tbl.remove cc.t_to_val r_into.n_term); + + | Some (n_into,v_into) when not (Term.equal v_from v_into) -> + (* semantic conflict, including [n_from != n_into] in model *) + let expl = [ + e_ab; Expl.mk_merge r_from n_from; + Expl.mk_merge r_into n_into] in + let expl_st = explain_expls cc expl in + let lits = expl_st.lits in + let tuples = + List.rev_map (fun (t,u) -> true, t.n_term, u.n_term) expl_st.same_val + in + let tuples = (false, n_from.n_term, n_into.n_term) :: tuples in + + Log.debugf 20 + (fun k->k "(@[cc.semantic-conflict.post-merge@ \ + (@[n-from %a@ := %a@])@ (@[n-into %a@ := %a@])@])" + N.pp n_from Term.pp v_from N.pp n_into Term.pp v_into); + + Stat.incr cc.count_semantic_conflict; + Actions.raise_semantic_conflict acts + lits tuples + + | Some _ -> () + end + end; + (* update explanations (a -> b), arbitrarily. Note that here we merge the classes by adding a bridge between [a] and [b], not their roots. *) @@ -908,23 +1004,24 @@ module Make (A: CC_ARG) let lit = if sign then lit else Lit.neg lit in (* apply sign *) Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); (* complete explanation with the [u1=t1] chunk *) - let reason = - let e = lazy ( - let lazy st = half_expl_and_pr in - explain_equal_rec_ cc st u1 t1; + let lazy st = half_expl_and_pr in + let st = Expl_state.copy st in (* do not modify shared st *) + explain_equal_rec_ cc st u1 t1; + + (* propagate only if this doesn't depend on some semantic values *) + if not (Expl_state.is_semantic st) then ( + let reason () = (* true literals explaining why t1=t2 *) let guard = st.lits in (* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *) - let st = Expl_state.copy st in (* do not modify shared st *) Expl_state.add_lit st (Lit.neg lit); let _, pr = lits_and_proof_of_expl cc st in guard, pr - ) in - fun () -> Lazy.force e - in - List.iter (fun f -> f cc lit reason) cc.on_propagate; - Stat.incr cc.count_props; - Actions.propagate acts lit ~reason + in + List.iter (fun f -> f cc lit reason) cc.on_propagate; + Stat.incr cc.count_props; + Actions.propagate acts lit ~reason + ) | _ -> ()) module Debug_ = struct @@ -998,8 +1095,9 @@ module Make (A: CC_ARG) let[@inline] merge_t cc t1 t2 expl = merge cc (add_term cc t1) (add_term cc t2) expl - let merge_same_value cc n1 n2 = merge cc n1 n2 (Expl.mk_same_value n1 n2) - let merge_same_value_t cc t1 t2 = merge_same_value cc (add_term cc t1) (add_term cc t2) + let set_model_value (self:t) (t:term) (v:value) : unit = + let n = add_term self t in + Vec.push self.combine (CT_set_val (n,v)) let explain_eq cc n1 n2 : Resolved_expl.t = let st = Expl_state.create() in @@ -1027,6 +1125,8 @@ module Make (A: CC_ARG) tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; bitgen; + t_to_val=T_tbl.create 32; + val_to_t=T_tbl.create 32; on_pre_merge; on_post_merge; on_new_term; diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 196f20ff..9530d366 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -332,11 +332,13 @@ module type CC_ACTIONS = sig exception). @param pr the proof of [c] being a tautology *) - val raise_semantic_conflict : t -> Lit.t list -> (T.Term.t * T.Term.t) list -> 'a + val raise_semantic_conflict : t -> Lit.t list -> (bool * T.Term.t * T.Term.t) list -> 'a (** [raise_semantic_conflict acts lits same_val] declares that - the conjunction of all [lits] (literals true in current trail) - and pairs [t_i = u_i] (which are pairs of terms with the same value - in the current model), implies false. + the conjunction of all [lits] (literals true in current trail) and tuples + [{=,≠}, t_i, u_i] implies false. + + The [{=,≠}, t_i, u_i] are pairs of terms with the same value (if [=] / true) + or distinct value (if [≠] / false)) in the current model. This does not return. It should raise an exception. *) @@ -410,6 +412,7 @@ module type CC_S = sig and type proof_step = proof_step type term_store = T.Term.store type term = T.Term.t + type value = term type fun_ = T.Fun.t type lit = Lit.t type actions = Actions.t @@ -726,11 +729,8 @@ module type CC_S = sig val merge_t : t -> term -> term -> Expl.t -> unit (** Shortcut for adding + merging *) - val merge_same_value : t -> N.t -> N.t -> unit - (** Merge these two nodes because they have the same value - in the model. The explanation will be {!Expl.mk_same_value}. *) - - val merge_same_value_t : t -> term -> term -> unit + val set_model_value : t -> term -> value -> unit + (** Set the value of a term in the model. *) val check : t -> actions -> unit (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. @@ -783,6 +783,7 @@ module type SOLVER_INTERNAL = sig type ty = T.Ty.t type term = T.Term.t + type value = T.Term.t type term_store = T.Term.store type ty_store = T.Ty.store type clause_pool @@ -1029,11 +1030,14 @@ module type SOLVER_INTERNAL = sig is given the whole trail. *) - val on_th_combination : t -> (t -> theory_actions -> term list Iter.t) -> unit + val on_th_combination : t -> (t -> theory_actions -> (term * value) Iter.t) -> unit (** Add a hook called during theory combination. - The hook must return an iterator of lists, each list [t1…tn] - is a set of terms that have the same value in the model - (and therefore must be merged). *) + The hook must return an iterator of pairs [(t, v)] + which mean that term [t] has value [v] in the model. + + Terms with the same value (according to {!Term.equal}) will be + merged in the CC; if two terms with different values are merged, + we get a semantic conflict and must pick another model. *) val declare_pb_is_incomplete : t -> unit (** Declare that, in some theory, the problem is outside the logic fragment diff --git a/src/lra/sidekick_arith_lra.ml b/src/lra/sidekick_arith_lra.ml index ec64af15..4ec01a55 100644 --- a/src/lra/sidekick_arith_lra.ml +++ b/src/lra/sidekick_arith_lra.ml @@ -238,6 +238,7 @@ module Make(A : ARG) : S with module A = A = struct encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *) simp_preds: (T.t * S_op.t * A.Q.t) T.Tbl.t; (* term -> its simplex meaning *) + simp_defined: LE.t T.Tbl.t; (* (rational) terms that are equal to a linexp *) st_exprs : ST_exprs.t; mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *) simplex: SimpSolver.t; @@ -255,6 +256,7 @@ module Make(A : ARG) : S with module A = A = struct st_exprs=ST_exprs.create_and_setup si; gensym=A.Gensym.create tst; simp_preds=T.Tbl.create 32; + simp_defined=T.Tbl.create 16; encoded_eqs=T.Tbl.create 8; needs_th_combination=T.Tbl.create 8; encoded_le=Comb_map.empty; @@ -293,7 +295,6 @@ module Make(A : ARG) : S with module A = A = struct let[@inline] as_const_ t = match A.view_as_lra t with LRA_const n -> Some n | _ -> None let[@inline] is_zero t = match A.view_as_lra t with LRA_const n -> A.Q.(n = zero) | _ -> false - let t_of_comb (self:state) (comb:LE_.Comb.t) ~(init:T.t) : T.t = let[@inline] (+) a b = A.mk_lra self.tst (LRA_op (Plus, a, b)) in let[@inline] ( * ) a b = A.mk_lra self.tst (LRA_mult (a, b)) in @@ -379,11 +380,11 @@ module Make(A : ARG) : S with module A = A = struct Log.debugf 50 (fun k->k "(@[lra.cc-on-subterm@ %a@])" T.pp t); match A.view_as_lra t with | LRA_other _ when not (A.has_ty_real t) -> () - | LRA_pred _ -> () - | LRA_op _ | LRA_const _ | LRA_other _ | LRA_mult _ -> + | LRA_pred _ | LRA_const _ -> () + | LRA_op _ | LRA_other _ | LRA_mult _ -> if not (T.Tbl.mem self.needs_th_combination t) then ( Log.debugf 5 (fun k->k "(@[lra.needs-th-combination@ %a@])" T.pp t); - T.Tbl.add self.needs_th_combination t () + T.Tbl.add self.needs_th_combination t (); ) (* preprocess linear expressions away *) @@ -464,12 +465,11 @@ module Make(A : ARG) : S with module A = A = struct T.pp t SimpSolver.Constraint.pp constr); | LRA_op _ | LRA_mult _ -> - (* NOTE: we don't need to do anything for rational subterms, at least - not at first. Only when theory combination mandates we compare - two terms (by deciding [t1 = t2]) do they impact the simplex; and - then they're moved into an equation, which means they are - preprocessed in the LRA_pred case above. *) - () + if not (T.Tbl.mem self.simp_defined t) then ( + (* we define these terms so their value in the model make sense *) + let le = as_linexp t in + T.Tbl.add self.simp_defined t le; + ); | LRA_const _n -> () @@ -619,98 +619,54 @@ module Make(A : ARG) : S with module A = A = struct let t2 = N.term n2 in add_local_eq_t self si acts t1 t2 ~tag:(Tag.CC_eq (n1, n2)) - (* - (* theory combination: add decisions [t=u] whenever [t] and [u] - have the same value in [subst] and both occur under function symbols *) - let do_th_combination (self:state) si acts (subst:Subst.t) : unit = - Log.debug 1 "(lra.do-th-combinations)"; - let n_th_comb = T.Tbl.keys self.needs_th_combination |> Iter.length in - if n_th_comb > 0 then ( - Log.debugf 5 - (fun k->k "(@[lra.needs-th-combination@ :n-lits %d@])" n_th_comb); - Log.debugf 50 - (fun k->k "(@[lra.needs-th-combination@ :terms [@[%a@]]@])" - (Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination)); - ); + (* evaluate a term directly, as a variable *) + let eval_in_subst_ subst t = match A.view_as_lra t with + | LRA_const n -> n + | _ -> Subst.eval subst t |> CCOpt.get_or ~default:A.Q.zero - let eval_in_subst_ subst t = match A.view_as_lra t with - | LRA_const n -> n - | _ -> Subst.eval subst t |> CCOpt.get_or ~default:A.Q.zero - in + (* evaluate a linear expression *) + let eval_le_in_subst_ subst (le:LE.t) = + LE.eval (eval_in_subst_ subst) le - let n = ref 0 in - (* theory combination: for [t1,t2] terms in [self.needs_th_combination] - that have same value, but are not provably equal, push - decision [t1=t2] into the SAT solver. *) - begin - let by_val: T.t list Q_map.t = - T.Tbl.keys self.needs_th_combination - |> Iter.map (fun t -> eval_in_subst_ subst t, t) - |> Iter.fold - (fun m (q,t) -> - let l = Q_map.get_or ~default:[] q m in - Q_map.add q (t::l) m) - Q_map.empty - in - Q_map.iter - (fun _q ts -> - begin match ts with - | [] | [_] -> () - | ts -> - (* several terms! see if they are already equal *) - CCList.diagonal ts - |> List.iter - (fun (t1,t2) -> - Log.debugf 50 - (fun k->k "(@[LRA.th-comb.check-pair[val=%a]@ %a@ %a@])" - A.Q.pp _q T.pp t1 T.pp t2); - assert(SI.cc_mem_term si t1); - assert(SI.cc_mem_term si t2); - (* if both [t1] and [t2] are relevant to the congruence - closure, and are not equal in it yet, add [t1=t2] as - the next decision to do *) - if not (SI.cc_are_equal si t1 t2) then ( - Log.debugf 50 - (fun k->k - "(@[lra.th-comb.must-decide-equal@ :t1 %a@ :t2 %a@])" T.pp t1 T.pp t2); - Stat.incr self.stat_th_comb; - Profile.instant "lra.th-comb-assert-eq"; - - let t = A.mk_eq (SI.tst si) t1 t2 in - let lit = SI.mk_lit si acts t in - incr n; - SI.push_decision si acts lit - ) - ) - end) - by_val; - () - end; - Log.debugf 1 (fun k->k "(@[lra.do-th-combinations.done@ :new-lits %d@])" !n); - () - *) - - let do_th_combination (self:state) _si _acts : A.term list Iter.t = + (* FIXME: rename, this is more "provide_model_to_cc" *) + let do_th_combination (self:state) _si _acts : _ Iter.t = Log.debug 1 "(lra.do-th-combinations)"; let model = match self.last_res with | Some (SimpSolver.Sat m) -> m | _ -> assert false in - (* gather terms by their model value *) - let tbl = Q_tbl.create 32 in - Subst.to_iter model - (fun (t,q) -> - let l = Q_tbl.get_or ~default:[] tbl q in - Q_tbl.replace tbl q (t :: l)); + let vals = + Subst.to_iter model |> T.Tbl.of_iter + in - (* now return classes of terms *) - Q_tbl.to_iter tbl - |> Iter.filter_map - (fun (_q, l) -> - match l with - | [] | [_] -> None - | l -> Some l) + (* also include terms that occur under function symbols, if they're + not in the model already *) + T.Tbl.iter + (fun t () -> + if not (T.Tbl.mem vals t) then ( + let v = eval_in_subst_ model t in + T.Tbl.add vals t v; + )) + self.needs_th_combination; + + (* also consider subterms that are linear expressions, + and evaluate them using the value of each variable + in that linear expression. For example a term [a + 2b] + is evaluated as [eval(a) + 2 × eval(b)]. *) + T.Tbl.iter + (fun t le -> + if not (T.Tbl.mem vals t) then ( + let v = eval_le_in_subst_ model le in + T.Tbl.add vals t v + )) + self.simp_defined; + + (* return whole model *) + begin + T.Tbl.to_iter vals + |> Iter.map (fun (t,v) -> t, t_const self v) + end (* partial checks is where we add literals from the trail to the simplex. *) diff --git a/src/simplex/linear_expr.ml b/src/simplex/linear_expr.ml index 1eb3b091..9873e93c 100644 --- a/src/simplex/linear_expr.ml +++ b/src/simplex/linear_expr.ml @@ -15,7 +15,7 @@ module Make(C : COEFF)(Var : VAR) = struct module Var = Var type var = Var.t - type subst = C.t Var_map.t + type subst = (Var.t -> C.t) (** Linear combination of variables. *) module Comb = struct @@ -87,7 +87,7 @@ module Make(C : COEFF)(Var : VAR) = struct let eval (subst : subst) (e:t) : C.t = Var_map.fold - (fun x c acc -> C.(acc + c * (Var_map.find x subst))) + (fun x c acc -> C.(acc + c * subst x)) e C.zero end diff --git a/src/simplex/linear_expr_intf.ml b/src/simplex/linear_expr_intf.ml index 0dcd9ab4..253d28ba 100644 --- a/src/simplex/linear_expr_intf.ml +++ b/src/simplex/linear_expr_intf.ml @@ -84,7 +84,7 @@ module type S = sig module Var_map : CCMap.S with type key = var (** Maps from variables, used for expressions as well as substitutions. *) - type subst = C.t Var_map.t + type subst = Var.t -> C.t (** Type for substitutions. *) (** Combinations. diff --git a/src/smt-solver/Sidekick_smt_solver.ml b/src/smt-solver/Sidekick_smt_solver.ml index fb964c6f..95ee8283 100644 --- a/src/smt-solver/Sidekick_smt_solver.ml +++ b/src/smt-solver/Sidekick_smt_solver.ml @@ -88,6 +88,7 @@ module Make(A : ARG) module Term = T.Term module Lit = A.Lit type term = Term.t + type value = term type ty = Ty.t type proof = A.proof type proof_step = A.proof_step @@ -101,7 +102,8 @@ module Make(A : ARG) and doesn't need to kill the current trail. *) type th_combination_conflict = { lits: lit list; - same_val: (term*term) list; + semantic: (bool*term*term) list; + (* set of semantic eqns/diseqns (ie true only in current model) *) } exception Semantic_conflict of th_combination_conflict @@ -128,8 +130,8 @@ module Make(A : ARG) let[@inline] raise_conflict (a:t) lits (pr:proof_step) = let (module A) = a in A.raise_conflict lits pr - let[@inline] raise_semantic_conflict (_:t) lits same_val = - raise (Semantic_conflict {lits; same_val}) + let[@inline] raise_semantic_conflict (_:t) lits semantic = + raise (Semantic_conflict {lits; semantic}) let[@inline] propagate (a:t) lit ~reason = let (module A) = a in let reason = Sidekick_sat.Consequence reason in @@ -163,6 +165,7 @@ module Make(A : ARG) type nonrec proof = proof type nonrec proof_step = proof_step type term = Term.t + type value = term type ty = Ty.t type lit = Lit.t type term_store = Term.store @@ -274,7 +277,7 @@ module Make(A : ARG) mutable on_progress: unit -> unit; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; - mutable on_th_combination: (t -> theory_actions -> term list Iter.t) list; + mutable on_th_combination: (t -> theory_actions -> (term*value) Iter.t) list; mutable preprocess: preprocess_hook list; mutable model_ask: model_ask_hook list; mutable model_complete: model_completion_hook list; @@ -573,23 +576,17 @@ module Make(A : ARG) let cc = cc self in with_cc_level_ cc @@ fun () -> - (* merge all terms in the class *) - let merge_cls (cls:term list) : unit = - match cls with - | [] -> assert false - | [_] -> () - | t :: ts -> - Log.debugf 50 - (fun k->k "(@[solver.th-comb.merge-cls@ %a@])" - (Util.pp_list Term.pp) cls); - - List.iter (fun u -> CC.merge_same_value_t cc t u) ts + let set_val (t,v) : unit = + Log.debugf 50 + (fun k->k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])" + Term.pp t Term.pp v); + CC.set_model_value cc t v in (* obtain classes of equal terms from the hook, and merge them *) let add_th_equalities f : unit = - let cls = f self acts in - Iter.iter merge_cls cls + let vals = f self acts in + Iter.iter set_val vals in try @@ -624,33 +621,38 @@ module Make(A : ARG) ) done; + CC.check cc acts; + let new_merges_in_cc = CC.new_merges cc in + begin match check_th_combination_ self acts with | Ok () -> () - | Error {lits; same_val} -> + | Error {lits; semantic} -> (* bad model, we add a clause to remove it *) Log.debugf 10 (fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \ :same-val (@[%a@])@])" (Util.pp_list Lit.pp) lits - (Util.pp_list @@ Fmt.Dump.pair Term.pp Term.pp) same_val); + (Util.pp_list @@ Fmt.Dump.(triple bool Term.pp Term.pp)) semantic); let c1 = List.rev_map Lit.neg lits in let c2 = - List.rev_map (fun (t,u) -> - Lit.atom ~sign:false self.tst @@ A.mk_eq self.tst t u) same_val + semantic + |> List.rev_map + (fun (sign,t,u) -> + Lit.atom ~sign:(not sign) self.tst @@ A.mk_eq self.tst t u) in let c = List.rev_append c1 c2 in let pr = P.lemma_cc (Iter.of_list c) self.proof in Log.debugf 20 - (fun k->k "(@[solver.th-comb.add-clause@ %a@])" + (fun k->k "(@[solver.th-comb.add-semantic-conflict-clause@ %a@])" (Util.pp_list Lit.pp) c); + (* will add a delayed action *) add_clause_temp self acts c pr; end; - CC.check cc acts; - if not (CC.new_merges cc) && not (has_delayed_actions self) then ( + if not new_merges_in_cc && not (has_delayed_actions self) then ( continue := false; ); done;