wip: theory combination by exposing model (classes) directly to CC

This commit is contained in:
Simon Cruanes 2022-02-17 13:49:47 -05:00
parent 65d4a90df1
commit fd66039c8d
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
5 changed files with 273 additions and 57 deletions

View file

@ -90,13 +90,14 @@ module Make (A: CC_ARG)
(* atomic explanation in the congruence closure *) (* atomic explanation in the congruence closure *)
and explanation = and explanation =
| E_reduction (* by pure reduction, tautologically equal *) | E_trivial (* by pure reduction, tautologically equal *)
| E_lit of lit (* because of this literal *) | E_lit of lit (* because of this literal *)
| E_merge of node * node | E_merge of node * node
| E_merge_t of term * term | E_merge_t of term * term
| E_congruence of node * node (* caused by normal congruence *) | E_congruence of node * node (* caused by normal congruence *)
| E_and of explanation * explanation | E_and of explanation * explanation
| E_theory of term * term * (term * term * explanation list) list * proof_step | E_theory of term * term * (term * term * explanation list) list * proof_step
| E_same_val of node * node
type repr = node type repr = node
@ -162,7 +163,7 @@ module Make (A: CC_ARG)
type t = explanation type t = explanation
let rec pp out (e:explanation) = match e with let rec pp out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction" | E_trivial -> Fmt.string out "reduction"
| E_lit lit -> Lit.pp out lit | E_lit lit -> Lit.pp out lit
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
@ -174,25 +175,49 @@ module Make (A: CC_ARG)
(Util.pp_list @@ Fmt.Dump.triple Term.pp Term.pp (Fmt.Dump.list pp)) es (Util.pp_list @@ Fmt.Dump.triple Term.pp Term.pp (Fmt.Dump.list pp)) es
| E_and (a,b) -> | E_and (a,b) ->
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
| E_same_val (n1,n2) ->
Fmt.fprintf out "(@[same-value@ %a@ %a@])" N.pp n1 N.pp n2
let mk_reduction : t = E_reduction let mk_trivial : t = E_trivial
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) let[@inline] mk_merge a b : t = if N.equal a b then mk_trivial else E_merge (a,b)
let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b) let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_trivial else E_merge_t (a,b)
let[@inline] mk_lit l : t = E_lit l let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_theory t u es pr = E_theory (t,u,es,pr) let[@inline] mk_theory t u es pr = E_theory (t,u,es,pr)
let[@inline] mk_same_value t u = if N.equal t u then mk_trivial else E_same_val (t,u)
let rec mk_list l = let rec mk_list l =
match l with match l with
| [] -> mk_reduction | [] -> mk_trivial
| [x] -> x | [x] -> x
| E_reduction :: tl -> mk_list tl | E_trivial :: tl -> mk_list tl
| x :: y -> | x :: y ->
match mk_list y with match mk_list y with
| E_reduction -> x | E_trivial -> x
| y' -> E_and (x,y') | y' -> E_and (x,y')
end end
module Resolved_expl = struct
type t = {
lits: lit list;
same_value: (N.t * N.t) list;
pr: proof -> proof_step;
}
let[@inline] is_semantic (self:t) : bool =
match self.same_value with [] -> false | _::_ -> true
let pp out (self:t) =
if not (is_semantic self) then (
Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits
) else (
let {lits; same_value; pr=_} = self in
Fmt.fprintf out "(@[resolved-expl@ (@[%a@])@ :same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list @@ Fmt.Dump.pair N.pp N.pp) same_value
)
end
(** A signature is a shallow term shape where immediate subterms (** A signature is a shallow term shape where immediate subterms
are representative *) are representative *)
module Signature = struct module Signature = struct
@ -274,6 +299,7 @@ module Make (A: CC_ARG)
count_conflict: int Stat.counter; count_conflict: int Stat.counter;
count_props: int Stat.counter; count_props: int Stat.counter;
count_merge: int Stat.counter; count_merge: int Stat.counter;
count_semantic_conflict: int Stat.counter;
} }
(* TODO: an additional union-find to keep track, for each term, (* TODO: an additional union-find to keep track, for each term,
of the terms they are known to be equal to, according of the terms they are known to be equal to, according
@ -443,28 +469,71 @@ module Make (A: CC_ARG)
module Expl_state = struct module Expl_state = struct
type t = { type t = {
mutable lits: Lit.t list; mutable lits: Lit.t list;
mutable same_val: (N.t * N.t) list;
mutable th_lemmas: mutable th_lemmas:
(Lit.t * (Lit.t * Lit.t list) list * proof_step) list; (Lit.t * (Lit.t * Lit.t list) list * proof_step) list;
} }
let create(): t = { lits=[]; th_lemmas=[] } let create(): t = { lits=[]; same_val=[]; th_lemmas=[] }
let[@inline] copy self : t = {self with lits=self.lits} let[@inline] copy self : t = {self with lits=self.lits}
let[@inline] add_lit (self:t) lit = self.lits <- lit :: self.lits let[@inline] add_lit (self:t) lit = self.lits <- lit :: self.lits
let[@inline] add_th (self:t) lit hyps pr : unit = let[@inline] add_th (self:t) lit hyps pr : unit =
self.th_lemmas <- (lit,hyps,pr) :: self.th_lemmas self.th_lemmas <- (lit,hyps,pr) :: self.th_lemmas
let[@inline] add_same_val (self:t) n1 n2 : unit =
self.same_val <- (n1,n2) :: self.same_val
(** Does this explanation contain at least one merge caused by
"same value"? *)
let[@inline] is_semantic (self:t): bool = self.same_val <> []
let merge self other = let merge self other =
let {lits=o_lits; th_lemmas=o_lemmas} = other in let {lits=o_lits; th_lemmas=o_lemmas;same_val=o_same_val} = other in
self.lits <- List.rev_append o_lits self.lits; self.lits <- List.rev_append o_lits self.lits;
self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas;
self.same_val <- List.rev_append o_same_val self.same_val;
()
(* proof of [\/_i ¬lits[i]] *)
let proof_of_th_lemmas (self:t) (proof:proof) : proof_step =
let p_lits1 = Iter.of_list self.lits |> Iter.map Lit.neg in
let p_lits2 =
Iter.of_list self.th_lemmas
|> Iter.map (fun (lit_t_u,_,_) -> Lit.neg lit_t_u)
in
let p_cc = P.lemma_cc (Iter.append p_lits1 p_lits2) proof in
let resolve_with_th_proof pr (lit_t_u,sub_proofs,pr_th) =
(* pr_th: [sub_proofs |- t=u].
now resolve away [sub_proofs] to get literals that were
asserted in the congruence closure *)
let pr_th = List.fold_left
(fun pr_th (lit_i,hyps_i) ->
(* [hyps_i |- lit_i] *)
let lemma_i =
P.lemma_cc Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) proof
in
(* resolve [lit_i] away. *)
P.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th proof)
pr_th sub_proofs
in
P.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr proof
in
(* resolve with theory proofs responsible for some merges, if any. *)
List.fold_left resolve_with_th_proof p_cc self.th_lemmas
let to_resolved_expl (self:t) : Resolved_expl.t =
(* FIXME: package the th lemmas too *)
let {lits; same_val; th_lemmas=_} = self in
let s2 = copy self in
let pr proof = proof_of_th_lemmas s2 proof in
{Resolved_expl.lits; same_value=same_val; pr}
end end
(* decompose explanation [e] into a list of literals added to [acc] *) (* decompose explanation [e] into a list of literals added to [acc] *)
let rec explain_decompose_expl cc (st:Expl_state.t) (e:explanation) : unit = let rec explain_decompose_expl cc (st:Expl_state.t) (e:explanation) : unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
match e with match e with
| E_reduction -> () | E_trivial -> ()
| E_congruence (n1, n2) -> | E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with begin match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
@ -482,6 +551,7 @@ module Make (A: CC_ARG)
assert false assert false
end end
| E_lit lit -> Expl_state.add_lit st lit | E_lit lit -> Expl_state.add_lit st lit
| E_same_val (n1, n2) -> Expl_state.add_same_val st n1 n2
| E_theory (t, u, expl_sets, pr) -> | E_theory (t, u, expl_sets, pr) ->
let sub_proofs = let sub_proofs =
List.map List.map
@ -625,37 +695,9 @@ module Make (A: CC_ARG)
merges. *) merges. *)
let lits_and_proof_of_expl let lits_and_proof_of_expl
(self:t) (st:Expl_state.t) : Lit.t list * proof_step = (self:t) (st:Expl_state.t) : Lit.t list * proof_step =
let {Expl_state.lits; th_lemmas} = st in let {Expl_state.lits; th_lemmas=_; same_val} = st in
let proof = self.proof in assert (same_val = []);
(* proof of [\/_i ¬lits[i]] *) let pr = Expl_state.proof_of_th_lemmas st self.proof in
let pr =
let p_lits1 = Iter.of_list lits |> Iter.map Lit.neg in
let p_lits2 =
Iter.of_list th_lemmas
|> Iter.map (fun (lit_t_u,_,_) -> Lit.neg lit_t_u)
in
let p_cc = P.lemma_cc (Iter.append p_lits1 p_lits2) proof in
let resolve_with_th_proof pr (lit_t_u,sub_proofs,pr_th) =
(* pr_th: [sub_proofs |- t=u].
now resolve away [sub_proofs] to get literals that were
asserted in the congruence closure *)
let pr_th = List.fold_left
(fun pr_th (lit_i,hyps_i) ->
(* [hyps_i |- lit_i] *)
let lemma_i =
P.lemma_cc Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) proof
in
(* resolve [lit_i] away. *)
P.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th proof)
pr_th sub_proofs
in
P.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr proof
in
(* resolve with theory proofs responsible for some merges, if any. *)
List.fold_left resolve_with_th_proof p_cc th_lemmas
in
lits, pr lits, pr
(* main CC algo: add terms from [pending] to the signature table, (* main CC algo: add terms from [pending] to the signature table,
@ -740,8 +782,21 @@ module Make (A: CC_ARG)
explain_equal_rec_ cc expl_st a ra; explain_equal_rec_ cc expl_st a ra;
explain_equal_rec_ cc expl_st b rb; explain_equal_rec_ cc expl_st b rb;
if Expl_state.is_semantic expl_st then (
(* conflict involving some semantic values *)
let lits = expl_st.lits in
let same_val =
expl_st.same_val
|> List.rev_map (fun (t,u) -> N.term t, N.term u) in
assert (same_val <> []);
Stat.incr cc.count_semantic_conflict;
Actions.raise_semantic_conflict acts lits same_val
) else (
(* regular conflict *)
let lits, pr = lits_and_proof_of_expl cc expl_st in let lits, pr = lits_and_proof_of_expl cc expl_st in
raise_conflict_ cc ~th:!th acts (List.rev_map Lit.neg lits) pr raise_conflict_ cc ~th:!th acts (List.rev_map Lit.neg lits) pr
)
); );
(* We will merge [r_from] into [r_into]. (* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always we try to ensure that [size ra <= size rb] in general, but always
@ -943,11 +998,14 @@ module Make (A: CC_ARG)
let[@inline] merge_t cc t1 t2 expl = let[@inline] merge_t cc t1 t2 expl =
merge cc (add_term cc t1) (add_term cc t2) expl merge cc (add_term cc t1) (add_term cc t2) expl
let explain_eq cc n1 n2 : lit list = let merge_same_value cc n1 n2 = merge cc n1 n2 (Expl.mk_same_value n1 n2)
let merge_same_value_t cc t1 t2 = merge_same_value cc (add_term cc t1) (add_term cc t2)
let explain_eq cc n1 n2 : Resolved_expl.t =
let st = Expl_state.create() in let st = Expl_state.create() in
explain_equal_rec_ cc st n1 n2; explain_equal_rec_ cc st n1 n2;
(* FIXME: also need to return the proof? *) (* FIXME: also need to return the proof? *)
st.lits Expl_state.to_resolved_expl st
let on_pre_merge cc f = cc.on_pre_merge <- f :: cc.on_pre_merge let on_pre_merge cc f = cc.on_pre_merge <- f :: cc.on_pre_merge
let on_post_merge cc f = cc.on_post_merge <- f :: cc.on_post_merge let on_post_merge cc f = cc.on_post_merge <- f :: cc.on_post_merge
@ -986,6 +1044,7 @@ module Make (A: CC_ARG)
count_conflict=Stat.mk_int stat "cc.conflicts"; count_conflict=Stat.mk_int stat "cc.conflicts";
count_props=Stat.mk_int stat "cc.propagations"; count_props=Stat.mk_int stat "cc.propagations";
count_merge=Stat.mk_int stat "cc.merges"; count_merge=Stat.mk_int stat "cc.merges";
count_semantic_conflict=Stat.mk_int stat "cc.semantic-conflicts";
} and true_ = lazy ( } and true_ = lazy (
add_term cc (Term.bool tst true) add_term cc (Term.bool tst true)
) and false_ = lazy ( ) and false_ = lazy (

View file

@ -332,6 +332,15 @@ module type CC_ACTIONS = sig
exception). exception).
@param pr the proof of [c] being a tautology *) @param pr the proof of [c] being a tautology *)
val raise_semantic_conflict : t -> Lit.t list -> (T.Term.t * T.Term.t) list -> 'a
(** [raise_semantic_conflict acts lits same_val] declares that
the conjunction of all [lits] (literals true in current trail)
and pairs [t_i = u_i] (which are pairs of terms with the same value
in the current model), implies false.
This does not return. It should raise an exception.
*)
val propagate : t -> Lit.t -> reason:(unit -> Lit.t list * proof_step) -> unit val propagate : t -> Lit.t -> reason:(unit -> Lit.t list * proof_step) -> unit
(** [propagate acts lit ~reason pr] declares that [reason() => lit] (** [propagate acts lit ~reason pr] declares that [reason() => lit]
is a tautology. is a tautology.
@ -484,6 +493,7 @@ module type CC_S = sig
val pp : t Fmt.printer val pp : t Fmt.printer
val mk_merge : N.t -> N.t -> t val mk_merge : N.t -> N.t -> t
(** Explanation: the nodes were explicitly merged *)
val mk_merge_t : term -> term -> t val mk_merge_t : term -> term -> t
(** Explanation: the terms were explicitly merged *) (** Explanation: the terms were explicitly merged *)
@ -493,6 +503,8 @@ module type CC_S = sig
or we merged [t] and [true] because of literal [t], or we merged [t] and [true] because of literal [t],
or [t] and [false] because of literal [¬t] *) or [t] and [false] because of literal [¬t] *)
val mk_same_value : N.t -> N.t -> t
val mk_list : t list -> t val mk_list : t list -> t
(** Conjunction of explanations *) (** Conjunction of explanations *)
@ -520,6 +532,31 @@ module type CC_S = sig
*) *)
end end
(** Resolved explanations.
The congruence closure keeps explanations for why terms are in the same
class. However these are represented in a compact, cheap form.
To use these explanations we need to {b resolve} them into a
resolved explanation, typically a list of
literals that are true in the current trail and are responsible for
merges.
However, we can also have merged classes because they have the same value
in the current model. *)
module Resolved_expl : sig
type t = {
lits: lit list;
same_value: (N.t * N.t) list;
pr: proof -> proof_step;
}
val is_semantic : t -> bool
(** [is_semantic expl] is [true] if there's at least one
pair in [expl.same_value]. *)
val pp : t Fmt.printer
end
type node = N.t type node = N.t
(** A node of the congruence closure *) (** A node of the congruence closure *)
@ -660,16 +697,17 @@ module type CC_S = sig
val assert_lits : t -> lit Iter.t -> unit val assert_lits : t -> lit Iter.t -> unit
(** Addition of many literals *) (** Addition of many literals *)
(* FIXME: this needs to return [lit list * (term*term*P.t) list]. val explain_eq : t -> N.t -> N.t -> Resolved_expl.t
the explanation is [/\_i lit_i /\ /\_j (|- t_j=u_j) |- n1=n2] *)
val explain_eq : t -> N.t -> N.t -> lit list
(** Explain why the two nodes are equal. (** Explain why the two nodes are equal.
Fails if they are not, in an unspecified way *) Fails if they are not, in an unspecified way. *)
val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a
(** Raise a conflict with the given explanation (** Raise a conflict with the given explanation.
it must be a theory tautology that [expl ==> absurd]. It must be a theory tautology that [expl ==> absurd].
To be used in theories. *) To be used in theories.
This fails in an unspecified way if the explanation, once resolved,
satisfies {!Resolved_expl.is_semantic}. *)
val n_true : t -> N.t val n_true : t -> N.t
(** Node for [true] *) (** Node for [true] *)
@ -688,6 +726,12 @@ module type CC_S = sig
val merge_t : t -> term -> term -> Expl.t -> unit val merge_t : t -> term -> term -> Expl.t -> unit
(** Shortcut for adding + merging *) (** Shortcut for adding + merging *)
val merge_same_value : t -> N.t -> N.t -> unit
(** Merge these two nodes because they have the same value
in the model. The explanation will be {!Expl.mk_same_value}. *)
val merge_same_value_t : t -> term -> term -> unit
val check : t -> actions -> unit val check : t -> actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the {!actions} to propagate literals, declare conflicts, etc. *) Will use the {!actions} to propagate literals, declare conflicts, etc. *)
@ -985,6 +1029,12 @@ module type SOLVER_INTERNAL = sig
is given the whole trail. is given the whole trail.
*) *)
val on_th_combination : t -> (t -> theory_actions -> term list Iter.t) -> unit
(** Add a hook called during theory combination.
The hook must return an iterator of lists, each list [t1tn]
is a set of terms that have the same value in the model
(and therefore must be merged). *)
val declare_pb_is_incomplete : t -> unit val declare_pb_is_incomplete : t -> unit
(** Declare that, in some theory, the problem is outside the logic fragment (** Declare that, in some theory, the problem is outside the logic fragment
that is decidable (e.g. if we meet proper NIA formulas). that is decidable (e.g. if we meet proper NIA formulas).

View file

@ -128,7 +128,9 @@ module Make(A : ARG) : S with module A = A = struct
| By_def -> [] | By_def -> []
| Lit l -> [l] | Lit l -> [l]
| CC_eq (n1,n2) -> | CC_eq (n1,n2) ->
SI.CC.explain_eq (SI.cc si) n1 n2 let r = SI.CC.explain_eq (SI.cc si) n1 n2 in
assert (not (SI.CC.Resolved_expl.is_semantic r));
r.lits
end end
module SimpVar module SimpVar
@ -155,6 +157,7 @@ module Make(A : ARG) : S with module A = A = struct
let mk_lit _ _ _ = assert false let mk_lit _ _ _ = assert false
end) end)
module Subst = SimpSolver.Subst module Subst = SimpSolver.Subst
module Q_tbl = CCHashtbl.Make(A.Q)
module Comb_map = CCMap.Make(LE_.Comb) module Comb_map = CCMap.Make(LE_.Comb)
@ -558,6 +561,7 @@ module Make(A : ARG) : S with module A = A = struct
CCList.flat_map (Tag.to_lits si) reason, pr) CCList.flat_map (Tag.to_lits si) reason, pr)
| _ -> () | _ -> ()
(** Check satisfiability of simplex, and sets [self.last_res] *)
let check_simplex_ self si acts : SimpSolver.Subst.t = let check_simplex_ self si acts : SimpSolver.Subst.t =
Log.debug 5 "(lra.check-simplex)"; Log.debug 5 "(lra.check-simplex)";
let res = let res =
@ -615,6 +619,7 @@ module Make(A : ARG) : S with module A = A = struct
let t2 = N.term n2 in let t2 = N.term n2 in
add_local_eq_t self si acts t1 t2 ~tag:(Tag.CC_eq (n1, n2)) add_local_eq_t self si acts t1 t2 ~tag:(Tag.CC_eq (n1, n2))
(*
(* theory combination: add decisions [t=u] whenever [t] and [u] (* theory combination: add decisions [t=u] whenever [t] and [u]
have the same value in [subst] and both occur under function symbols *) have the same value in [subst] and both occur under function symbols *)
let do_th_combination (self:state) si acts (subst:Subst.t) : unit = let do_th_combination (self:state) si acts (subst:Subst.t) : unit =
@ -683,6 +688,29 @@ module Make(A : ARG) : S with module A = A = struct
end; end;
Log.debugf 1 (fun k->k "(@[lra.do-th-combinations.done@ :new-lits %d@])" !n); Log.debugf 1 (fun k->k "(@[lra.do-th-combinations.done@ :new-lits %d@])" !n);
() ()
*)
let do_th_combination (self:state) _si _acts : A.term list Iter.t =
Log.debug 1 "(lra.do-th-combinations)";
let model = match self.last_res with
| Some (SimpSolver.Sat m) -> m
| _ -> assert false
in
(* gather terms by their model value *)
let tbl = Q_tbl.create 32 in
Subst.to_iter model
(fun (t,q) ->
let l = Q_tbl.get_or ~default:[] tbl q in
Q_tbl.replace tbl q (t :: l));
(* now return classes of terms *)
Q_tbl.to_iter tbl
|> Iter.filter_map
(fun (_q, l) ->
match l with
| [] | [_] -> None
| l -> Some l)
(* partial checks is where we add literals from the trail to the (* partial checks is where we add literals from the trail to the
simplex. *) simplex. *)
@ -757,7 +785,6 @@ module Make(A : ARG) : S with module A = A = struct
let model = check_simplex_ self si acts in let model = check_simplex_ self si acts in
Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model); Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model);
Log.debug 5 "(lra: solver returns SAT)"; Log.debug 5 "(lra: solver returns SAT)";
do_th_combination self si acts model;
() ()
(* help generating model *) (* help generating model *)
@ -812,6 +839,7 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 30 (fun k->k "(@[lra.merge-incompatible-consts@ %a@ %a@])" N.pp n1 N.pp n2); Log.debugf 30 (fun k->k "(@[lra.merge-incompatible-consts@ %a@ %a@])" N.pp n1 N.pp n2);
SI.CC.raise_conflict_from_expl si acts expl SI.CC.raise_conflict_from_expl si acts expl
| _ -> ()); | _ -> ());
SI.on_th_combination si (do_th_combination st);
st st
let theory = let theory =

View file

@ -74,6 +74,7 @@ module type S = sig
module Subst : sig module Subst : sig
type t = num V_map.t type t = num V_map.t
val eval : t -> V.t -> Q.t option val eval : t -> V.t -> Q.t option
val to_iter : t -> (V.t * Q.t) Iter.t
val pp : t Fmt.printer val pp : t Fmt.printer
val to_string : t -> string val to_string : t -> string
end end
@ -211,6 +212,7 @@ module Make(Arg: ARG)
module Subst = struct module Subst = struct
type t = num V_map.t type t = num V_map.t
let eval self t = V_map.get t self let eval self t = V_map.get t self
let to_iter self f = V_map.iter (fun k v -> f (k,v)) self
let pp out (self:t) : unit = let pp out (self:t) : unit =
let pp_pair out (v,n) = let pp_pair out (v,n) =
Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in

View file

@ -96,6 +96,15 @@ module Make(A : ARG)
(* actions from the sat solver *) (* actions from the sat solver *)
type sat_acts = (lit, proof, proof_step) Sidekick_sat.acts type sat_acts = (lit, proof, proof_step) Sidekick_sat.acts
(** Conflict obtained during theory combination. It involves equalities
merged because of the current model so it's not a "true" conflict
and doesn't need to kill the current trail. *)
type th_combination_conflict = {
lits: lit list;
same_val: (term*term) list;
}
exception Semantic_conflict of th_combination_conflict
(* the full argument to the congruence closure *) (* the full argument to the congruence closure *)
module CC_actions = struct module CC_actions = struct
module T = T module T = T
@ -119,6 +128,8 @@ module Make(A : ARG)
let[@inline] raise_conflict (a:t) lits (pr:proof_step) = let[@inline] raise_conflict (a:t) lits (pr:proof_step) =
let (module A) = a in let (module A) = a in
A.raise_conflict lits pr A.raise_conflict lits pr
let[@inline] raise_semantic_conflict (_:t) lits same_val =
raise (Semantic_conflict {lits; same_val})
let[@inline] propagate (a:t) lit ~reason = let[@inline] propagate (a:t) lit ~reason =
let (module A) = a in let (module A) = a in
let reason = Sidekick_sat.Consequence reason in let reason = Sidekick_sat.Consequence reason in
@ -263,6 +274,7 @@ module Make(A : ARG)
mutable on_progress: unit -> unit; mutable on_progress: unit -> unit;
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_th_combination: (t -> theory_actions -> term list Iter.t) list;
mutable preprocess: preprocess_hook list; mutable preprocess: preprocess_hook list;
mutable model_ask: model_ask_hook list; mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list; mutable model_complete: model_completion_hook list;
@ -313,6 +325,7 @@ module Make(A : ARG)
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
let on_th_combination self f = self.on_th_combination <- f :: self.on_th_combination
let on_preprocess self f = self.preprocess <- f :: self.preprocess let on_preprocess self f = self.preprocess <- f :: self.preprocess
let on_model ?ask ?complete self = let on_model ?ask ?complete self =
CCOpt.iter (fun f -> self.model_ask <- f :: self.model_ask) ask; CCOpt.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
@ -548,6 +561,43 @@ module Make(A : ARG)
CC.pop_levels (cc self) n; CC.pop_levels (cc self) n;
pop_lvls_ n self.th_states pop_lvls_ n self.th_states
(* run [f] in a local congruence closure level *)
let with_cc_level_ cc f =
CC.push_level cc;
CCFun.protect ~finally:(fun() -> CC.pop_levels cc 1) f
(* do theory combination using the congruence closure. Each theory
can merge classes, *)
let check_th_combination_
(self:t) (acts:theory_actions) : (unit, th_combination_conflict) result =
let cc = cc self in
with_cc_level_ cc @@ fun () ->
(* merge all terms in the class *)
let merge_cls (cls:term list) : unit =
match cls with
| [] -> assert false
| [_] -> ()
| t :: ts ->
Log.debugf 50
(fun k->k "(@[solver.th-comb.merge-cls@ %a@])"
(Util.pp_list Term.pp) cls);
List.iter (fun u -> CC.merge_same_value_t cc t u) ts
in
(* obtain classes of equal terms from the hook, and merge them *)
let add_th_equalities f : unit =
let cls = f self acts in
Iter.iter merge_cls cls
in
try
List.iter add_th_equalities self.on_th_combination;
CC.check cc acts;
Ok ()
with Semantic_conflict c -> Error c
(* handle a literal assumed by the SAT solver *) (* handle a literal assumed by the SAT solver *)
let assert_lits_ ~final (self:t) (acts:theory_actions) (lits:Lit.t Iter.t) : unit = let assert_lits_ ~final (self:t) (acts:theory_actions) (lits:Lit.t Iter.t) : unit =
Log.debugf 2 Log.debugf 2
@ -563,9 +613,9 @@ module Make(A : ARG)
if final then ( if final then (
let continue = ref true in let continue = ref true in
while !continue do while !continue do
(* do final checks in a loop *)
let fcheck = ref true in let fcheck = ref true in
while !fcheck do while !fcheck do
(* TODO: theory combination *)
List.iter (fun f -> f self acts lits) self.on_final_check; List.iter (fun f -> f self acts lits) self.on_final_check;
if has_delayed_actions self then ( if has_delayed_actions self then (
Perform_delayed_th.top self acts; Perform_delayed_th.top self acts;
@ -573,6 +623,32 @@ module Make(A : ARG)
fcheck := false fcheck := false
) )
done; done;
begin match check_th_combination_ self acts with
| Ok () -> ()
| Error {lits; same_val} ->
(* bad model, we add a clause to remove it *)
Log.debugf 10
(fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \
:same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list @@ Fmt.Dump.pair Term.pp Term.pp) same_val);
let c1 = List.rev_map Lit.neg lits in
let c2 =
List.rev_map (fun (t,u) ->
Lit.atom ~sign:false self.tst @@ A.mk_eq self.tst t u) same_val
in
let c = List.rev_append c1 c2 in
let pr = P.lemma_cc (Iter.of_list c) self.proof in
Log.debugf 20
(fun k->k "(@[solver.th-comb.add-clause@ %a@])"
(Util.pp_list Lit.pp) c);
add_clause_temp self acts c pr;
end;
CC.check cc acts; CC.check cc acts;
if not (CC.new_merges cc) && not (has_delayed_actions self) then ( if not (CC.new_merges cc) && not (has_delayed_actions self) then (
continue := false; continue := false;
@ -638,6 +714,7 @@ module Make(A : ARG)
t_defs=[]; t_defs=[];
on_partial_check=[]; on_partial_check=[];
on_final_check=[]; on_final_check=[];
on_th_combination=[];
level=0; level=0;
complete=true; complete=true;
} in } in