refactor(th-comb): provide full model to the CC

this way it can fail on merges of classes assigned conflicting value.
This commit is contained in:
Simon Cruanes 2022-02-17 16:36:07 -05:00
parent fd66039c8d
commit 95f84b4854
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
6 changed files with 217 additions and 155 deletions

View file

@ -24,6 +24,7 @@ module Make (A: CC_ARG)
module Actions = A.Actions module Actions = A.Actions
module P = Actions.P module P = Actions.P
type term = T.Term.t type term = T.Term.t
type value = term
type term_store = T.Term.store type term_store = T.Term.store
type lit = Lit.t type lit = Lit.t
type fun_ = T.Fun.t type fun_ = T.Fun.t
@ -267,12 +268,15 @@ module Make (A: CC_ARG)
type combine_task = type combine_task =
| CT_merge of node * node * explanation | CT_merge of node * node * explanation
| CT_set_val of node * value
type t = { type t = {
tst: term_store; tst: term_store;
tbl: node T_tbl.t;
proof: proof; proof: proof;
tbl: node T_tbl.t;
(* internalization [term -> node] *) (* internalization [term -> node] *)
signatures_tbl : node Sig_tbl.t; signatures_tbl : node Sig_tbl.t;
(* map a signature to the corresponding node in some equivalence class. (* map a signature to the corresponding node in some equivalence class.
A signature is a [term_cell] in which every immediate subterm A signature is a [term_cell] in which every immediate subterm
@ -281,9 +285,21 @@ module Make (A: CC_ARG)
The critical property is that all members of an equivalence class The critical property is that all members of an equivalence class
that have the same "shape" (including head symbol) that have the same "shape" (including head symbol)
have the same signature *) have the same signature *)
pending: node Vec.t; pending: node Vec.t;
combine: combine_task Vec.t; combine: combine_task Vec.t;
t_to_val: (node*value) T_tbl.t;
(* [repr -> (t,val)] where [repr = t]
and [t := val] in the model *)
val_to_t: node T_tbl.t; (* [val -> t] where [t := val] in the model *)
undo: (unit -> unit) Backtrack_stack.t; undo: (unit -> unit) Backtrack_stack.t;
bitgen: Bits.bitfield_gen;
field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *)
true_ : node lazy_t;
false_ : node lazy_t;
mutable on_pre_merge: ev_on_pre_merge list; mutable on_pre_merge: ev_on_pre_merge list;
mutable on_post_merge: ev_on_post_merge list; mutable on_post_merge: ev_on_post_merge list;
mutable on_new_term: ev_on_new_term list; mutable on_new_term: ev_on_new_term list;
@ -291,10 +307,6 @@ module Make (A: CC_ARG)
mutable on_propagate: ev_on_propagate list; mutable on_propagate: ev_on_propagate list;
mutable on_is_subterm : ev_on_is_subterm list; mutable on_is_subterm : ev_on_is_subterm list;
mutable new_merges: bool; mutable new_merges: bool;
bitgen: Bits.bitfield_gen;
field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *)
true_ : node lazy_t;
false_ : node lazy_t;
stat: Stat.t; stat: Stat.t;
count_conflict: int Stat.counter; count_conflict: int Stat.counter;
count_props: int Stat.counter; count_props: int Stat.counter;
@ -622,7 +634,7 @@ module Make (A: CC_ARG)
(* remove term when we backtrack *) (* remove term when we backtrack *)
on_backtrack cc on_backtrack cc
(fun () -> (fun () ->
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t); Log.debugf 30 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t);
T_tbl.remove cc.tbl t); T_tbl.remove cc.tbl t);
(* add term to the table *) (* add term to the table *)
T_tbl.add cc.tbl t n; T_tbl.add cc.tbl t n;
@ -755,6 +767,51 @@ module Make (A: CC_ARG)
cc.new_merges <- true; cc.new_merges <- true;
task_merge_ cc acts a b e_ab task_merge_ cc acts a b e_ab
| CT_set_val (n, v) ->
task_set_val_ cc acts n v
and task_set_val_ cc acts n v =
let repr_n = find_ n in
(* - if repr(n) has value [v], do nothing
- else if repr(n) has value [v'], semantic conflict
- else add [repr(n) -> (n,v)] to cc.t_to_val *)
begin match T_tbl.find_opt cc.t_to_val repr_n.n_term with
| Some (n', v') when not (Term.equal v v') ->
(* semantic conflict *)
let expl = [Expl.mk_merge n n'] in
let expl_st = explain_expls cc expl in
let lits = expl_st.lits in
let tuples =
List.rev_map (fun (t,u) -> true, t.n_term, u.n_term) expl_st.same_val
in
let tuples = (false, n.n_term, n'.n_term) :: tuples in
Log.debugf 20
(fun k->k "(@[cc.semantic-conflict.set-val@ (@[set-val %a@ := %a@])@ \
(@[existing-val %a@ := %a@])@])"
N.pp n Term.pp v N.pp n' Term.pp v');
Stat.incr cc.count_semantic_conflict;
Actions.raise_semantic_conflict acts lits tuples
| Some _ -> ()
| None ->
T_tbl.add cc.t_to_val repr_n.n_term (n, v);
on_backtrack cc (fun () -> T_tbl.remove cc.t_to_val repr_n.n_term);
end;
(* now for the reverse map, look in self.val_to_t for [v].
- if present, push a merge command with Expl.mk_same_value
- if not, add [v -> n] *)
begin match T_tbl.find_opt cc.val_to_t v with
| None ->
T_tbl.add cc.val_to_t v n;
on_backtrack cc (fun () -> T_tbl.remove cc.val_to_t v);
| Some n' when not (same_class n n') ->
merge_classes cc n n' (Expl.mk_same_value n n')
| Some _ -> ()
end
(* main CC algo: merge equivalence classes in [st.combine]. (* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *) @raise Exn_unsat if merge fails *)
and task_merge_ cc acts a b e_ab : unit = and task_merge_ cc acts a b e_ab : unit =
@ -787,7 +844,7 @@ module Make (A: CC_ARG)
let lits = expl_st.lits in let lits = expl_st.lits in
let same_val = let same_val =
expl_st.same_val expl_st.same_val
|> List.rev_map (fun (t,u) -> N.term t, N.term u) in |> List.rev_map (fun (t,u) -> true, N.term t, N.term u) in
assert (same_val <> []); assert (same_val <> []);
Stat.incr cc.count_semantic_conflict; Stat.incr cc.count_semantic_conflict;
Actions.raise_semantic_conflict acts lits same_val Actions.raise_semantic_conflict acts lits same_val
@ -817,14 +874,17 @@ module Make (A: CC_ARG)
in in
merge_bool ra a rb b; merge_bool ra a rb b;
merge_bool rb b ra a; merge_bool rb b ra a;
(* perform [union r_from r_into] *) (* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call [on_pre_merge] functions, and merge theory data items *) (* call [on_pre_merge] functions, and merge theory data items *)
begin begin
(* explanation is [a=ra & e_ab & b=rb] *) (* explanation is [a=ra & e_ab & b=rb] *)
let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in
List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge; List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge;
end; end;
begin begin
(* parents might have a different signature, check for collisions *) (* parents might have a different signature, check for collisions *)
N.iter_parents r_from N.iter_parents r_from
@ -848,8 +908,8 @@ module Make (A: CC_ARG)
(* on backtrack, unmerge classes and restore the pointers to [r_from] *) (* on backtrack, unmerge classes and restore the pointers to [r_from] *)
on_backtrack cc on_backtrack cc
(fun () -> (fun () ->
Log.debugf 15 Log.debugf 30
(fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" (fun k->k "(@[cc.undo_merge@ :from %a@ :into %a@])"
N.pp r_from N.pp r_into); N.pp r_from N.pp r_into);
r_into.n_bits <- r_into_old_bits; r_into.n_bits <- r_into_old_bits;
r_into.n_next <- r_into_old_next; r_into.n_next <- r_into_old_next;
@ -861,6 +921,42 @@ module Make (A: CC_ARG)
r_into.n_size <- r_into.n_size - r_from.n_size; r_into.n_size <- r_into.n_size - r_from.n_size;
); );
end; end;
(* check for semantic values, update the one of [r_into]
if [r_from] has a value *)
begin match T_tbl.find_opt cc.t_to_val r_from.n_term with
| None -> ()
| Some (n_from, v_from) ->
begin match T_tbl.find_opt cc.t_to_val r_into.n_term with
| None ->
T_tbl.add cc.t_to_val r_into.n_term (n_from,v_from);
on_backtrack cc (fun () -> T_tbl.remove cc.t_to_val r_into.n_term);
| Some (n_into,v_into) when not (Term.equal v_from v_into) ->
(* semantic conflict, including [n_from != n_into] in model *)
let expl = [
e_ab; Expl.mk_merge r_from n_from;
Expl.mk_merge r_into n_into] in
let expl_st = explain_expls cc expl in
let lits = expl_st.lits in
let tuples =
List.rev_map (fun (t,u) -> true, t.n_term, u.n_term) expl_st.same_val
in
let tuples = (false, n_from.n_term, n_into.n_term) :: tuples in
Log.debugf 20
(fun k->k "(@[cc.semantic-conflict.post-merge@ \
(@[n-from %a@ := %a@])@ (@[n-into %a@ := %a@])@])"
N.pp n_from Term.pp v_from N.pp n_into Term.pp v_into);
Stat.incr cc.count_semantic_conflict;
Actions.raise_semantic_conflict acts
lits tuples
| Some _ -> ()
end
end;
(* update explanations (a -> b), arbitrarily. (* update explanations (a -> b), arbitrarily.
Note that here we merge the classes by adding a bridge between [a] Note that here we merge the classes by adding a bridge between [a]
and [b], not their roots. *) and [b], not their roots. *)
@ -908,23 +1004,24 @@ module Make (A: CC_ARG)
let lit = if sign then lit else Lit.neg lit in (* apply sign *) let lit = if sign then lit else Lit.neg lit in (* apply sign *)
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *) (* complete explanation with the [u1=t1] chunk *)
let reason = let lazy st = half_expl_and_pr in
let e = lazy ( let st = Expl_state.copy st in (* do not modify shared st *)
let lazy st = half_expl_and_pr in explain_equal_rec_ cc st u1 t1;
explain_equal_rec_ cc st u1 t1;
(* propagate only if this doesn't depend on some semantic values *)
if not (Expl_state.is_semantic st) then (
let reason () =
(* true literals explaining why t1=t2 *) (* true literals explaining why t1=t2 *)
let guard = st.lits in let guard = st.lits in
(* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *) (* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *)
let st = Expl_state.copy st in (* do not modify shared st *)
Expl_state.add_lit st (Lit.neg lit); Expl_state.add_lit st (Lit.neg lit);
let _, pr = lits_and_proof_of_expl cc st in let _, pr = lits_and_proof_of_expl cc st in
guard, pr guard, pr
) in in
fun () -> Lazy.force e List.iter (fun f -> f cc lit reason) cc.on_propagate;
in Stat.incr cc.count_props;
List.iter (fun f -> f cc lit reason) cc.on_propagate; Actions.propagate acts lit ~reason
Stat.incr cc.count_props; )
Actions.propagate acts lit ~reason
| _ -> ()) | _ -> ())
module Debug_ = struct module Debug_ = struct
@ -998,8 +1095,9 @@ module Make (A: CC_ARG)
let[@inline] merge_t cc t1 t2 expl = let[@inline] merge_t cc t1 t2 expl =
merge cc (add_term cc t1) (add_term cc t2) expl merge cc (add_term cc t1) (add_term cc t2) expl
let merge_same_value cc n1 n2 = merge cc n1 n2 (Expl.mk_same_value n1 n2) let set_model_value (self:t) (t:term) (v:value) : unit =
let merge_same_value_t cc t1 t2 = merge_same_value cc (add_term cc t1) (add_term cc t2) let n = add_term self t in
Vec.push self.combine (CT_set_val (n,v))
let explain_eq cc n1 n2 : Resolved_expl.t = let explain_eq cc n1 n2 : Resolved_expl.t =
let st = Expl_state.create() in let st = Expl_state.create() in
@ -1027,6 +1125,8 @@ module Make (A: CC_ARG)
tbl = T_tbl.create size; tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size; signatures_tbl = Sig_tbl.create size;
bitgen; bitgen;
t_to_val=T_tbl.create 32;
val_to_t=T_tbl.create 32;
on_pre_merge; on_pre_merge;
on_post_merge; on_post_merge;
on_new_term; on_new_term;

View file

@ -332,11 +332,13 @@ module type CC_ACTIONS = sig
exception). exception).
@param pr the proof of [c] being a tautology *) @param pr the proof of [c] being a tautology *)
val raise_semantic_conflict : t -> Lit.t list -> (T.Term.t * T.Term.t) list -> 'a val raise_semantic_conflict : t -> Lit.t list -> (bool * T.Term.t * T.Term.t) list -> 'a
(** [raise_semantic_conflict acts lits same_val] declares that (** [raise_semantic_conflict acts lits same_val] declares that
the conjunction of all [lits] (literals true in current trail) the conjunction of all [lits] (literals true in current trail) and tuples
and pairs [t_i = u_i] (which are pairs of terms with the same value [{=,}, t_i, u_i] implies false.
in the current model), implies false.
The [{=,}, t_i, u_i] are pairs of terms with the same value (if [=] / true)
or distinct value (if [] / false)) in the current model.
This does not return. It should raise an exception. This does not return. It should raise an exception.
*) *)
@ -410,6 +412,7 @@ module type CC_S = sig
and type proof_step = proof_step and type proof_step = proof_step
type term_store = T.Term.store type term_store = T.Term.store
type term = T.Term.t type term = T.Term.t
type value = term
type fun_ = T.Fun.t type fun_ = T.Fun.t
type lit = Lit.t type lit = Lit.t
type actions = Actions.t type actions = Actions.t
@ -726,11 +729,8 @@ module type CC_S = sig
val merge_t : t -> term -> term -> Expl.t -> unit val merge_t : t -> term -> term -> Expl.t -> unit
(** Shortcut for adding + merging *) (** Shortcut for adding + merging *)
val merge_same_value : t -> N.t -> N.t -> unit val set_model_value : t -> term -> value -> unit
(** Merge these two nodes because they have the same value (** Set the value of a term in the model. *)
in the model. The explanation will be {!Expl.mk_same_value}. *)
val merge_same_value_t : t -> term -> term -> unit
val check : t -> actions -> unit val check : t -> actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
@ -783,6 +783,7 @@ module type SOLVER_INTERNAL = sig
type ty = T.Ty.t type ty = T.Ty.t
type term = T.Term.t type term = T.Term.t
type value = T.Term.t
type term_store = T.Term.store type term_store = T.Term.store
type ty_store = T.Ty.store type ty_store = T.Ty.store
type clause_pool type clause_pool
@ -1029,11 +1030,14 @@ module type SOLVER_INTERNAL = sig
is given the whole trail. is given the whole trail.
*) *)
val on_th_combination : t -> (t -> theory_actions -> term list Iter.t) -> unit val on_th_combination : t -> (t -> theory_actions -> (term * value) Iter.t) -> unit
(** Add a hook called during theory combination. (** Add a hook called during theory combination.
The hook must return an iterator of lists, each list [t1tn] The hook must return an iterator of pairs [(t, v)]
is a set of terms that have the same value in the model which mean that term [t] has value [v] in the model.
(and therefore must be merged). *)
Terms with the same value (according to {!Term.equal}) will be
merged in the CC; if two terms with different values are merged,
we get a semantic conflict and must pick another model. *)
val declare_pb_is_incomplete : t -> unit val declare_pb_is_incomplete : t -> unit
(** Declare that, in some theory, the problem is outside the logic fragment (** Declare that, in some theory, the problem is outside the logic fragment

View file

@ -238,6 +238,7 @@ module Make(A : ARG) : S with module A = A = struct
encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *)
needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *) needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *)
simp_preds: (T.t * S_op.t * A.Q.t) T.Tbl.t; (* term -> its simplex meaning *) simp_preds: (T.t * S_op.t * A.Q.t) T.Tbl.t; (* term -> its simplex meaning *)
simp_defined: LE.t T.Tbl.t; (* (rational) terms that are equal to a linexp *)
st_exprs : ST_exprs.t; st_exprs : ST_exprs.t;
mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *) mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *)
simplex: SimpSolver.t; simplex: SimpSolver.t;
@ -255,6 +256,7 @@ module Make(A : ARG) : S with module A = A = struct
st_exprs=ST_exprs.create_and_setup si; st_exprs=ST_exprs.create_and_setup si;
gensym=A.Gensym.create tst; gensym=A.Gensym.create tst;
simp_preds=T.Tbl.create 32; simp_preds=T.Tbl.create 32;
simp_defined=T.Tbl.create 16;
encoded_eqs=T.Tbl.create 8; encoded_eqs=T.Tbl.create 8;
needs_th_combination=T.Tbl.create 8; needs_th_combination=T.Tbl.create 8;
encoded_le=Comb_map.empty; encoded_le=Comb_map.empty;
@ -293,7 +295,6 @@ module Make(A : ARG) : S with module A = A = struct
let[@inline] as_const_ t = match A.view_as_lra t with LRA_const n -> Some n | _ -> None let[@inline] as_const_ t = match A.view_as_lra t with LRA_const n -> Some n | _ -> None
let[@inline] is_zero t = match A.view_as_lra t with LRA_const n -> A.Q.(n = zero) | _ -> false let[@inline] is_zero t = match A.view_as_lra t with LRA_const n -> A.Q.(n = zero) | _ -> false
let t_of_comb (self:state) (comb:LE_.Comb.t) ~(init:T.t) : T.t = let t_of_comb (self:state) (comb:LE_.Comb.t) ~(init:T.t) : T.t =
let[@inline] (+) a b = A.mk_lra self.tst (LRA_op (Plus, a, b)) in let[@inline] (+) a b = A.mk_lra self.tst (LRA_op (Plus, a, b)) in
let[@inline] ( * ) a b = A.mk_lra self.tst (LRA_mult (a, b)) in let[@inline] ( * ) a b = A.mk_lra self.tst (LRA_mult (a, b)) in
@ -379,11 +380,11 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 50 (fun k->k "(@[lra.cc-on-subterm@ %a@])" T.pp t); Log.debugf 50 (fun k->k "(@[lra.cc-on-subterm@ %a@])" T.pp t);
match A.view_as_lra t with match A.view_as_lra t with
| LRA_other _ when not (A.has_ty_real t) -> () | LRA_other _ when not (A.has_ty_real t) -> ()
| LRA_pred _ -> () | LRA_pred _ | LRA_const _ -> ()
| LRA_op _ | LRA_const _ | LRA_other _ | LRA_mult _ -> | LRA_op _ | LRA_other _ | LRA_mult _ ->
if not (T.Tbl.mem self.needs_th_combination t) then ( if not (T.Tbl.mem self.needs_th_combination t) then (
Log.debugf 5 (fun k->k "(@[lra.needs-th-combination@ %a@])" T.pp t); Log.debugf 5 (fun k->k "(@[lra.needs-th-combination@ %a@])" T.pp t);
T.Tbl.add self.needs_th_combination t () T.Tbl.add self.needs_th_combination t ();
) )
(* preprocess linear expressions away *) (* preprocess linear expressions away *)
@ -464,12 +465,11 @@ module Make(A : ARG) : S with module A = A = struct
T.pp t SimpSolver.Constraint.pp constr); T.pp t SimpSolver.Constraint.pp constr);
| LRA_op _ | LRA_mult _ -> | LRA_op _ | LRA_mult _ ->
(* NOTE: we don't need to do anything for rational subterms, at least if not (T.Tbl.mem self.simp_defined t) then (
not at first. Only when theory combination mandates we compare (* we define these terms so their value in the model make sense *)
two terms (by deciding [t1 = t2]) do they impact the simplex; and let le = as_linexp t in
then they're moved into an equation, which means they are T.Tbl.add self.simp_defined t le;
preprocessed in the LRA_pred case above. *) );
()
| LRA_const _n -> () | LRA_const _n -> ()
@ -619,98 +619,54 @@ module Make(A : ARG) : S with module A = A = struct
let t2 = N.term n2 in let t2 = N.term n2 in
add_local_eq_t self si acts t1 t2 ~tag:(Tag.CC_eq (n1, n2)) add_local_eq_t self si acts t1 t2 ~tag:(Tag.CC_eq (n1, n2))
(* (* evaluate a term directly, as a variable *)
(* theory combination: add decisions [t=u] whenever [t] and [u] let eval_in_subst_ subst t = match A.view_as_lra t with
have the same value in [subst] and both occur under function symbols *) | LRA_const n -> n
let do_th_combination (self:state) si acts (subst:Subst.t) : unit = | _ -> Subst.eval subst t |> CCOpt.get_or ~default:A.Q.zero
Log.debug 1 "(lra.do-th-combinations)";
let n_th_comb = T.Tbl.keys self.needs_th_combination |> Iter.length in
if n_th_comb > 0 then (
Log.debugf 5
(fun k->k "(@[lra.needs-th-combination@ :n-lits %d@])" n_th_comb);
Log.debugf 50
(fun k->k "(@[lra.needs-th-combination@ :terms [@[%a@]]@])"
(Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination));
);
let eval_in_subst_ subst t = match A.view_as_lra t with (* evaluate a linear expression *)
| LRA_const n -> n let eval_le_in_subst_ subst (le:LE.t) =
| _ -> Subst.eval subst t |> CCOpt.get_or ~default:A.Q.zero LE.eval (eval_in_subst_ subst) le
in
let n = ref 0 in (* FIXME: rename, this is more "provide_model_to_cc" *)
(* theory combination: for [t1,t2] terms in [self.needs_th_combination] let do_th_combination (self:state) _si _acts : _ Iter.t =
that have same value, but are not provably equal, push
decision [t1=t2] into the SAT solver. *)
begin
let by_val: T.t list Q_map.t =
T.Tbl.keys self.needs_th_combination
|> Iter.map (fun t -> eval_in_subst_ subst t, t)
|> Iter.fold
(fun m (q,t) ->
let l = Q_map.get_or ~default:[] q m in
Q_map.add q (t::l) m)
Q_map.empty
in
Q_map.iter
(fun _q ts ->
begin match ts with
| [] | [_] -> ()
| ts ->
(* several terms! see if they are already equal *)
CCList.diagonal ts
|> List.iter
(fun (t1,t2) ->
Log.debugf 50
(fun k->k "(@[LRA.th-comb.check-pair[val=%a]@ %a@ %a@])"
A.Q.pp _q T.pp t1 T.pp t2);
assert(SI.cc_mem_term si t1);
assert(SI.cc_mem_term si t2);
(* if both [t1] and [t2] are relevant to the congruence
closure, and are not equal in it yet, add [t1=t2] as
the next decision to do *)
if not (SI.cc_are_equal si t1 t2) then (
Log.debugf 50
(fun k->k
"(@[lra.th-comb.must-decide-equal@ :t1 %a@ :t2 %a@])" T.pp t1 T.pp t2);
Stat.incr self.stat_th_comb;
Profile.instant "lra.th-comb-assert-eq";
let t = A.mk_eq (SI.tst si) t1 t2 in
let lit = SI.mk_lit si acts t in
incr n;
SI.push_decision si acts lit
)
)
end)
by_val;
()
end;
Log.debugf 1 (fun k->k "(@[lra.do-th-combinations.done@ :new-lits %d@])" !n);
()
*)
let do_th_combination (self:state) _si _acts : A.term list Iter.t =
Log.debug 1 "(lra.do-th-combinations)"; Log.debug 1 "(lra.do-th-combinations)";
let model = match self.last_res with let model = match self.last_res with
| Some (SimpSolver.Sat m) -> m | Some (SimpSolver.Sat m) -> m
| _ -> assert false | _ -> assert false
in in
(* gather terms by their model value *) let vals =
let tbl = Q_tbl.create 32 in Subst.to_iter model |> T.Tbl.of_iter
Subst.to_iter model in
(fun (t,q) ->
let l = Q_tbl.get_or ~default:[] tbl q in
Q_tbl.replace tbl q (t :: l));
(* now return classes of terms *) (* also include terms that occur under function symbols, if they're
Q_tbl.to_iter tbl not in the model already *)
|> Iter.filter_map T.Tbl.iter
(fun (_q, l) -> (fun t () ->
match l with if not (T.Tbl.mem vals t) then (
| [] | [_] -> None let v = eval_in_subst_ model t in
| l -> Some l) T.Tbl.add vals t v;
))
self.needs_th_combination;
(* also consider subterms that are linear expressions,
and evaluate them using the value of each variable
in that linear expression. For example a term [a + 2b]
is evaluated as [eval(a) + 2 × eval(b)]. *)
T.Tbl.iter
(fun t le ->
if not (T.Tbl.mem vals t) then (
let v = eval_le_in_subst_ model le in
T.Tbl.add vals t v
))
self.simp_defined;
(* return whole model *)
begin
T.Tbl.to_iter vals
|> Iter.map (fun (t,v) -> t, t_const self v)
end
(* partial checks is where we add literals from the trail to the (* partial checks is where we add literals from the trail to the
simplex. *) simplex. *)

View file

@ -15,7 +15,7 @@ module Make(C : COEFF)(Var : VAR) = struct
module Var = Var module Var = Var
type var = Var.t type var = Var.t
type subst = C.t Var_map.t type subst = (Var.t -> C.t)
(** Linear combination of variables. *) (** Linear combination of variables. *)
module Comb = struct module Comb = struct
@ -87,7 +87,7 @@ module Make(C : COEFF)(Var : VAR) = struct
let eval (subst : subst) (e:t) : C.t = let eval (subst : subst) (e:t) : C.t =
Var_map.fold Var_map.fold
(fun x c acc -> C.(acc + c * (Var_map.find x subst))) (fun x c acc -> C.(acc + c * subst x))
e C.zero e C.zero
end end

View file

@ -84,7 +84,7 @@ module type S = sig
module Var_map : CCMap.S with type key = var module Var_map : CCMap.S with type key = var
(** Maps from variables, used for expressions as well as substitutions. *) (** Maps from variables, used for expressions as well as substitutions. *)
type subst = C.t Var_map.t type subst = Var.t -> C.t
(** Type for substitutions. *) (** Type for substitutions. *)
(** Combinations. (** Combinations.

View file

@ -88,6 +88,7 @@ module Make(A : ARG)
module Term = T.Term module Term = T.Term
module Lit = A.Lit module Lit = A.Lit
type term = Term.t type term = Term.t
type value = term
type ty = Ty.t type ty = Ty.t
type proof = A.proof type proof = A.proof
type proof_step = A.proof_step type proof_step = A.proof_step
@ -101,7 +102,8 @@ module Make(A : ARG)
and doesn't need to kill the current trail. *) and doesn't need to kill the current trail. *)
type th_combination_conflict = { type th_combination_conflict = {
lits: lit list; lits: lit list;
same_val: (term*term) list; semantic: (bool*term*term) list;
(* set of semantic eqns/diseqns (ie true only in current model) *)
} }
exception Semantic_conflict of th_combination_conflict exception Semantic_conflict of th_combination_conflict
@ -128,8 +130,8 @@ module Make(A : ARG)
let[@inline] raise_conflict (a:t) lits (pr:proof_step) = let[@inline] raise_conflict (a:t) lits (pr:proof_step) =
let (module A) = a in let (module A) = a in
A.raise_conflict lits pr A.raise_conflict lits pr
let[@inline] raise_semantic_conflict (_:t) lits same_val = let[@inline] raise_semantic_conflict (_:t) lits semantic =
raise (Semantic_conflict {lits; same_val}) raise (Semantic_conflict {lits; semantic})
let[@inline] propagate (a:t) lit ~reason = let[@inline] propagate (a:t) lit ~reason =
let (module A) = a in let (module A) = a in
let reason = Sidekick_sat.Consequence reason in let reason = Sidekick_sat.Consequence reason in
@ -163,6 +165,7 @@ module Make(A : ARG)
type nonrec proof = proof type nonrec proof = proof
type nonrec proof_step = proof_step type nonrec proof_step = proof_step
type term = Term.t type term = Term.t
type value = term
type ty = Ty.t type ty = Ty.t
type lit = Lit.t type lit = Lit.t
type term_store = Term.store type term_store = Term.store
@ -274,7 +277,7 @@ module Make(A : ARG)
mutable on_progress: unit -> unit; mutable on_progress: unit -> unit;
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_th_combination: (t -> theory_actions -> term list Iter.t) list; mutable on_th_combination: (t -> theory_actions -> (term*value) Iter.t) list;
mutable preprocess: preprocess_hook list; mutable preprocess: preprocess_hook list;
mutable model_ask: model_ask_hook list; mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list; mutable model_complete: model_completion_hook list;
@ -573,23 +576,17 @@ module Make(A : ARG)
let cc = cc self in let cc = cc self in
with_cc_level_ cc @@ fun () -> with_cc_level_ cc @@ fun () ->
(* merge all terms in the class *) let set_val (t,v) : unit =
let merge_cls (cls:term list) : unit = Log.debugf 50
match cls with (fun k->k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])"
| [] -> assert false Term.pp t Term.pp v);
| [_] -> () CC.set_model_value cc t v
| t :: ts ->
Log.debugf 50
(fun k->k "(@[solver.th-comb.merge-cls@ %a@])"
(Util.pp_list Term.pp) cls);
List.iter (fun u -> CC.merge_same_value_t cc t u) ts
in in
(* obtain classes of equal terms from the hook, and merge them *) (* obtain classes of equal terms from the hook, and merge them *)
let add_th_equalities f : unit = let add_th_equalities f : unit =
let cls = f self acts in let vals = f self acts in
Iter.iter merge_cls cls Iter.iter set_val vals
in in
try try
@ -624,33 +621,38 @@ module Make(A : ARG)
) )
done; done;
CC.check cc acts;
let new_merges_in_cc = CC.new_merges cc in
begin match check_th_combination_ self acts with begin match check_th_combination_ self acts with
| Ok () -> () | Ok () -> ()
| Error {lits; same_val} -> | Error {lits; semantic} ->
(* bad model, we add a clause to remove it *) (* bad model, we add a clause to remove it *)
Log.debugf 10 Log.debugf 10
(fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \ (fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \
:same-val (@[%a@])@])" :same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits (Util.pp_list Lit.pp) lits
(Util.pp_list @@ Fmt.Dump.pair Term.pp Term.pp) same_val); (Util.pp_list @@ Fmt.Dump.(triple bool Term.pp Term.pp)) semantic);
let c1 = List.rev_map Lit.neg lits in let c1 = List.rev_map Lit.neg lits in
let c2 = let c2 =
List.rev_map (fun (t,u) -> semantic
Lit.atom ~sign:false self.tst @@ A.mk_eq self.tst t u) same_val |> List.rev_map
(fun (sign,t,u) ->
Lit.atom ~sign:(not sign) self.tst @@ A.mk_eq self.tst t u)
in in
let c = List.rev_append c1 c2 in let c = List.rev_append c1 c2 in
let pr = P.lemma_cc (Iter.of_list c) self.proof in let pr = P.lemma_cc (Iter.of_list c) self.proof in
Log.debugf 20 Log.debugf 20
(fun k->k "(@[solver.th-comb.add-clause@ %a@])" (fun k->k "(@[solver.th-comb.add-semantic-conflict-clause@ %a@])"
(Util.pp_list Lit.pp) c); (Util.pp_list Lit.pp) c);
(* will add a delayed action *)
add_clause_temp self acts c pr; add_clause_temp self acts c pr;
end; end;
CC.check cc acts; if not new_merges_in_cc && not (has_delayed_actions self) then (
if not (CC.new_merges cc) && not (has_delayed_actions self) then (
continue := false; continue := false;
); );
done; done;