wip: theory combination by exposing model (classes) directly to CC

This commit is contained in:
Simon Cruanes 2022-02-17 13:49:47 -05:00
parent 65d4a90df1
commit fd66039c8d
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
5 changed files with 273 additions and 57 deletions

View file

@ -90,13 +90,14 @@ module Make (A: CC_ARG)
(* atomic explanation in the congruence closure *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_trivial (* by pure reduction, tautologically equal *)
| E_lit of lit (* because of this literal *)
| E_merge of node * node
| E_merge_t of term * term
| E_congruence of node * node (* caused by normal congruence *)
| E_and of explanation * explanation
| E_theory of term * term * (term * term * explanation list) list * proof_step
| E_same_val of node * node
type repr = node
@ -162,7 +163,7 @@ module Make (A: CC_ARG)
type t = explanation
let rec pp out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction"
| E_trivial -> Fmt.string out "reduction"
| E_lit lit -> Lit.pp out lit
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
@ -174,25 +175,49 @@ module Make (A: CC_ARG)
(Util.pp_list @@ Fmt.Dump.triple Term.pp Term.pp (Fmt.Dump.list pp)) es
| E_and (a,b) ->
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
| E_same_val (n1,n2) ->
Fmt.fprintf out "(@[same-value@ %a@ %a@])" N.pp n1 N.pp n2
let mk_reduction : t = E_reduction
let mk_trivial : t = E_trivial
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b)
let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b)
let[@inline] mk_merge a b : t = if N.equal a b then mk_trivial else E_merge (a,b)
let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_trivial else E_merge_t (a,b)
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_theory t u es pr = E_theory (t,u,es,pr)
let[@inline] mk_same_value t u = if N.equal t u then mk_trivial else E_same_val (t,u)
let rec mk_list l =
match l with
| [] -> mk_reduction
| [] -> mk_trivial
| [x] -> x
| E_reduction :: tl -> mk_list tl
| E_trivial :: tl -> mk_list tl
| x :: y ->
match mk_list y with
| E_reduction -> x
| E_trivial -> x
| y' -> E_and (x,y')
end
module Resolved_expl = struct
type t = {
lits: lit list;
same_value: (N.t * N.t) list;
pr: proof -> proof_step;
}
let[@inline] is_semantic (self:t) : bool =
match self.same_value with [] -> false | _::_ -> true
let pp out (self:t) =
if not (is_semantic self) then (
Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits
) else (
let {lits; same_value; pr=_} = self in
Fmt.fprintf out "(@[resolved-expl@ (@[%a@])@ :same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list @@ Fmt.Dump.pair N.pp N.pp) same_value
)
end
(** A signature is a shallow term shape where immediate subterms
are representative *)
module Signature = struct
@ -274,6 +299,7 @@ module Make (A: CC_ARG)
count_conflict: int Stat.counter;
count_props: int Stat.counter;
count_merge: int Stat.counter;
count_semantic_conflict: int Stat.counter;
}
(* TODO: an additional union-find to keep track, for each term,
of the terms they are known to be equal to, according
@ -443,28 +469,71 @@ module Make (A: CC_ARG)
module Expl_state = struct
type t = {
mutable lits: Lit.t list;
mutable same_val: (N.t * N.t) list;
mutable th_lemmas:
(Lit.t * (Lit.t * Lit.t list) list * proof_step) list;
}
let create(): t = { lits=[]; th_lemmas=[] }
let create(): t = { lits=[]; same_val=[]; th_lemmas=[] }
let[@inline] copy self : t = {self with lits=self.lits}
let[@inline] add_lit (self:t) lit = self.lits <- lit :: self.lits
let[@inline] add_th (self:t) lit hyps pr : unit =
self.th_lemmas <- (lit,hyps,pr) :: self.th_lemmas
let[@inline] add_same_val (self:t) n1 n2 : unit =
self.same_val <- (n1,n2) :: self.same_val
(** Does this explanation contain at least one merge caused by
"same value"? *)
let[@inline] is_semantic (self:t): bool = self.same_val <> []
let merge self other =
let {lits=o_lits; th_lemmas=o_lemmas} = other in
let {lits=o_lits; th_lemmas=o_lemmas;same_val=o_same_val} = other in
self.lits <- List.rev_append o_lits self.lits;
self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas
self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas;
self.same_val <- List.rev_append o_same_val self.same_val;
()
(* proof of [\/_i ¬lits[i]] *)
let proof_of_th_lemmas (self:t) (proof:proof) : proof_step =
let p_lits1 = Iter.of_list self.lits |> Iter.map Lit.neg in
let p_lits2 =
Iter.of_list self.th_lemmas
|> Iter.map (fun (lit_t_u,_,_) -> Lit.neg lit_t_u)
in
let p_cc = P.lemma_cc (Iter.append p_lits1 p_lits2) proof in
let resolve_with_th_proof pr (lit_t_u,sub_proofs,pr_th) =
(* pr_th: [sub_proofs |- t=u].
now resolve away [sub_proofs] to get literals that were
asserted in the congruence closure *)
let pr_th = List.fold_left
(fun pr_th (lit_i,hyps_i) ->
(* [hyps_i |- lit_i] *)
let lemma_i =
P.lemma_cc Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) proof
in
(* resolve [lit_i] away. *)
P.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th proof)
pr_th sub_proofs
in
P.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr proof
in
(* resolve with theory proofs responsible for some merges, if any. *)
List.fold_left resolve_with_th_proof p_cc self.th_lemmas
let to_resolved_expl (self:t) : Resolved_expl.t =
(* FIXME: package the th lemmas too *)
let {lits; same_val; th_lemmas=_} = self in
let s2 = copy self in
let pr proof = proof_of_th_lemmas s2 proof in
{Resolved_expl.lits; same_value=same_val; pr}
end
(* decompose explanation [e] into a list of literals added to [acc] *)
let rec explain_decompose_expl cc (st:Expl_state.t) (e:explanation) : unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
match e with
| E_reduction -> ()
| E_trivial -> ()
| E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
@ -482,6 +551,7 @@ module Make (A: CC_ARG)
assert false
end
| E_lit lit -> Expl_state.add_lit st lit
| E_same_val (n1, n2) -> Expl_state.add_same_val st n1 n2
| E_theory (t, u, expl_sets, pr) ->
let sub_proofs =
List.map
@ -625,37 +695,9 @@ module Make (A: CC_ARG)
merges. *)
let lits_and_proof_of_expl
(self:t) (st:Expl_state.t) : Lit.t list * proof_step =
let {Expl_state.lits; th_lemmas} = st in
let proof = self.proof in
(* proof of [\/_i ¬lits[i]] *)
let pr =
let p_lits1 = Iter.of_list lits |> Iter.map Lit.neg in
let p_lits2 =
Iter.of_list th_lemmas
|> Iter.map (fun (lit_t_u,_,_) -> Lit.neg lit_t_u)
in
let p_cc = P.lemma_cc (Iter.append p_lits1 p_lits2) proof in
let resolve_with_th_proof pr (lit_t_u,sub_proofs,pr_th) =
(* pr_th: [sub_proofs |- t=u].
now resolve away [sub_proofs] to get literals that were
asserted in the congruence closure *)
let pr_th = List.fold_left
(fun pr_th (lit_i,hyps_i) ->
(* [hyps_i |- lit_i] *)
let lemma_i =
P.lemma_cc Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) proof
in
(* resolve [lit_i] away. *)
P.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th proof)
pr_th sub_proofs
in
P.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr proof
in
(* resolve with theory proofs responsible for some merges, if any. *)
List.fold_left resolve_with_th_proof p_cc th_lemmas
in
let {Expl_state.lits; th_lemmas=_; same_val} = st in
assert (same_val = []);
let pr = Expl_state.proof_of_th_lemmas st self.proof in
lits, pr
(* main CC algo: add terms from [pending] to the signature table,
@ -740,8 +782,21 @@ module Make (A: CC_ARG)
explain_equal_rec_ cc expl_st a ra;
explain_equal_rec_ cc expl_st b rb;
let lits, pr = lits_and_proof_of_expl cc expl_st in
raise_conflict_ cc ~th:!th acts (List.rev_map Lit.neg lits) pr
if Expl_state.is_semantic expl_st then (
(* conflict involving some semantic values *)
let lits = expl_st.lits in
let same_val =
expl_st.same_val
|> List.rev_map (fun (t,u) -> N.term t, N.term u) in
assert (same_val <> []);
Stat.incr cc.count_semantic_conflict;
Actions.raise_semantic_conflict acts lits same_val
) else (
(* regular conflict *)
let lits, pr = lits_and_proof_of_expl cc expl_st in
raise_conflict_ cc ~th:!th acts (List.rev_map Lit.neg lits) pr
)
);
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always
@ -943,11 +998,14 @@ module Make (A: CC_ARG)
let[@inline] merge_t cc t1 t2 expl =
merge cc (add_term cc t1) (add_term cc t2) expl
let explain_eq cc n1 n2 : lit list =
let merge_same_value cc n1 n2 = merge cc n1 n2 (Expl.mk_same_value n1 n2)
let merge_same_value_t cc t1 t2 = merge_same_value cc (add_term cc t1) (add_term cc t2)
let explain_eq cc n1 n2 : Resolved_expl.t =
let st = Expl_state.create() in
explain_equal_rec_ cc st n1 n2;
(* FIXME: also need to return the proof? *)
st.lits
Expl_state.to_resolved_expl st
let on_pre_merge cc f = cc.on_pre_merge <- f :: cc.on_pre_merge
let on_post_merge cc f = cc.on_post_merge <- f :: cc.on_post_merge
@ -986,6 +1044,7 @@ module Make (A: CC_ARG)
count_conflict=Stat.mk_int stat "cc.conflicts";
count_props=Stat.mk_int stat "cc.propagations";
count_merge=Stat.mk_int stat "cc.merges";
count_semantic_conflict=Stat.mk_int stat "cc.semantic-conflicts";
} and true_ = lazy (
add_term cc (Term.bool tst true)
) and false_ = lazy (

View file

@ -332,6 +332,15 @@ module type CC_ACTIONS = sig
exception).
@param pr the proof of [c] being a tautology *)
val raise_semantic_conflict : t -> Lit.t list -> (T.Term.t * T.Term.t) list -> 'a
(** [raise_semantic_conflict acts lits same_val] declares that
the conjunction of all [lits] (literals true in current trail)
and pairs [t_i = u_i] (which are pairs of terms with the same value
in the current model), implies false.
This does not return. It should raise an exception.
*)
val propagate : t -> Lit.t -> reason:(unit -> Lit.t list * proof_step) -> unit
(** [propagate acts lit ~reason pr] declares that [reason() => lit]
is a tautology.
@ -484,6 +493,7 @@ module type CC_S = sig
val pp : t Fmt.printer
val mk_merge : N.t -> N.t -> t
(** Explanation: the nodes were explicitly merged *)
val mk_merge_t : term -> term -> t
(** Explanation: the terms were explicitly merged *)
@ -493,6 +503,8 @@ module type CC_S = sig
or we merged [t] and [true] because of literal [t],
or [t] and [false] because of literal [¬t] *)
val mk_same_value : N.t -> N.t -> t
val mk_list : t list -> t
(** Conjunction of explanations *)
@ -520,6 +532,31 @@ module type CC_S = sig
*)
end
(** Resolved explanations.
The congruence closure keeps explanations for why terms are in the same
class. However these are represented in a compact, cheap form.
To use these explanations we need to {b resolve} them into a
resolved explanation, typically a list of
literals that are true in the current trail and are responsible for
merges.
However, we can also have merged classes because they have the same value
in the current model. *)
module Resolved_expl : sig
type t = {
lits: lit list;
same_value: (N.t * N.t) list;
pr: proof -> proof_step;
}
val is_semantic : t -> bool
(** [is_semantic expl] is [true] if there's at least one
pair in [expl.same_value]. *)
val pp : t Fmt.printer
end
type node = N.t
(** A node of the congruence closure *)
@ -660,16 +697,17 @@ module type CC_S = sig
val assert_lits : t -> lit Iter.t -> unit
(** Addition of many literals *)
(* FIXME: this needs to return [lit list * (term*term*P.t) list].
the explanation is [/\_i lit_i /\ /\_j (|- t_j=u_j) |- n1=n2] *)
val explain_eq : t -> N.t -> N.t -> lit list
val explain_eq : t -> N.t -> N.t -> Resolved_expl.t
(** Explain why the two nodes are equal.
Fails if they are not, in an unspecified way *)
Fails if they are not, in an unspecified way. *)
val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a
(** Raise a conflict with the given explanation
it must be a theory tautology that [expl ==> absurd].
To be used in theories. *)
(** Raise a conflict with the given explanation.
It must be a theory tautology that [expl ==> absurd].
To be used in theories.
This fails in an unspecified way if the explanation, once resolved,
satisfies {!Resolved_expl.is_semantic}. *)
val n_true : t -> N.t
(** Node for [true] *)
@ -688,6 +726,12 @@ module type CC_S = sig
val merge_t : t -> term -> term -> Expl.t -> unit
(** Shortcut for adding + merging *)
val merge_same_value : t -> N.t -> N.t -> unit
(** Merge these two nodes because they have the same value
in the model. The explanation will be {!Expl.mk_same_value}. *)
val merge_same_value_t : t -> term -> term -> unit
val check : t -> actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the {!actions} to propagate literals, declare conflicts, etc. *)
@ -985,6 +1029,12 @@ module type SOLVER_INTERNAL = sig
is given the whole trail.
*)
val on_th_combination : t -> (t -> theory_actions -> term list Iter.t) -> unit
(** Add a hook called during theory combination.
The hook must return an iterator of lists, each list [t1tn]
is a set of terms that have the same value in the model
(and therefore must be merged). *)
val declare_pb_is_incomplete : t -> unit
(** Declare that, in some theory, the problem is outside the logic fragment
that is decidable (e.g. if we meet proper NIA formulas).

View file

@ -128,7 +128,9 @@ module Make(A : ARG) : S with module A = A = struct
| By_def -> []
| Lit l -> [l]
| CC_eq (n1,n2) ->
SI.CC.explain_eq (SI.cc si) n1 n2
let r = SI.CC.explain_eq (SI.cc si) n1 n2 in
assert (not (SI.CC.Resolved_expl.is_semantic r));
r.lits
end
module SimpVar
@ -155,6 +157,7 @@ module Make(A : ARG) : S with module A = A = struct
let mk_lit _ _ _ = assert false
end)
module Subst = SimpSolver.Subst
module Q_tbl = CCHashtbl.Make(A.Q)
module Comb_map = CCMap.Make(LE_.Comb)
@ -558,6 +561,7 @@ module Make(A : ARG) : S with module A = A = struct
CCList.flat_map (Tag.to_lits si) reason, pr)
| _ -> ()
(** Check satisfiability of simplex, and sets [self.last_res] *)
let check_simplex_ self si acts : SimpSolver.Subst.t =
Log.debug 5 "(lra.check-simplex)";
let res =
@ -615,6 +619,7 @@ module Make(A : ARG) : S with module A = A = struct
let t2 = N.term n2 in
add_local_eq_t self si acts t1 t2 ~tag:(Tag.CC_eq (n1, n2))
(*
(* theory combination: add decisions [t=u] whenever [t] and [u]
have the same value in [subst] and both occur under function symbols *)
let do_th_combination (self:state) si acts (subst:Subst.t) : unit =
@ -683,6 +688,29 @@ module Make(A : ARG) : S with module A = A = struct
end;
Log.debugf 1 (fun k->k "(@[lra.do-th-combinations.done@ :new-lits %d@])" !n);
()
*)
let do_th_combination (self:state) _si _acts : A.term list Iter.t =
Log.debug 1 "(lra.do-th-combinations)";
let model = match self.last_res with
| Some (SimpSolver.Sat m) -> m
| _ -> assert false
in
(* gather terms by their model value *)
let tbl = Q_tbl.create 32 in
Subst.to_iter model
(fun (t,q) ->
let l = Q_tbl.get_or ~default:[] tbl q in
Q_tbl.replace tbl q (t :: l));
(* now return classes of terms *)
Q_tbl.to_iter tbl
|> Iter.filter_map
(fun (_q, l) ->
match l with
| [] | [_] -> None
| l -> Some l)
(* partial checks is where we add literals from the trail to the
simplex. *)
@ -757,7 +785,6 @@ module Make(A : ARG) : S with module A = A = struct
let model = check_simplex_ self si acts in
Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model);
Log.debug 5 "(lra: solver returns SAT)";
do_th_combination self si acts model;
()
(* help generating model *)
@ -812,6 +839,7 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 30 (fun k->k "(@[lra.merge-incompatible-consts@ %a@ %a@])" N.pp n1 N.pp n2);
SI.CC.raise_conflict_from_expl si acts expl
| _ -> ());
SI.on_th_combination si (do_th_combination st);
st
let theory =

View file

@ -74,6 +74,7 @@ module type S = sig
module Subst : sig
type t = num V_map.t
val eval : t -> V.t -> Q.t option
val to_iter : t -> (V.t * Q.t) Iter.t
val pp : t Fmt.printer
val to_string : t -> string
end
@ -211,6 +212,7 @@ module Make(Arg: ARG)
module Subst = struct
type t = num V_map.t
let eval self t = V_map.get t self
let to_iter self f = V_map.iter (fun k v -> f (k,v)) self
let pp out (self:t) : unit =
let pp_pair out (v,n) =
Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in

View file

@ -96,6 +96,15 @@ module Make(A : ARG)
(* actions from the sat solver *)
type sat_acts = (lit, proof, proof_step) Sidekick_sat.acts
(** Conflict obtained during theory combination. It involves equalities
merged because of the current model so it's not a "true" conflict
and doesn't need to kill the current trail. *)
type th_combination_conflict = {
lits: lit list;
same_val: (term*term) list;
}
exception Semantic_conflict of th_combination_conflict
(* the full argument to the congruence closure *)
module CC_actions = struct
module T = T
@ -119,6 +128,8 @@ module Make(A : ARG)
let[@inline] raise_conflict (a:t) lits (pr:proof_step) =
let (module A) = a in
A.raise_conflict lits pr
let[@inline] raise_semantic_conflict (_:t) lits same_val =
raise (Semantic_conflict {lits; same_val})
let[@inline] propagate (a:t) lit ~reason =
let (module A) = a in
let reason = Sidekick_sat.Consequence reason in
@ -263,6 +274,7 @@ module Make(A : ARG)
mutable on_progress: unit -> unit;
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_th_combination: (t -> theory_actions -> term list Iter.t) list;
mutable preprocess: preprocess_hook list;
mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list;
@ -313,6 +325,7 @@ module Make(A : ARG)
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
let on_th_combination self f = self.on_th_combination <- f :: self.on_th_combination
let on_preprocess self f = self.preprocess <- f :: self.preprocess
let on_model ?ask ?complete self =
CCOpt.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
@ -548,6 +561,43 @@ module Make(A : ARG)
CC.pop_levels (cc self) n;
pop_lvls_ n self.th_states
(* run [f] in a local congruence closure level *)
let with_cc_level_ cc f =
CC.push_level cc;
CCFun.protect ~finally:(fun() -> CC.pop_levels cc 1) f
(* do theory combination using the congruence closure. Each theory
can merge classes, *)
let check_th_combination_
(self:t) (acts:theory_actions) : (unit, th_combination_conflict) result =
let cc = cc self in
with_cc_level_ cc @@ fun () ->
(* merge all terms in the class *)
let merge_cls (cls:term list) : unit =
match cls with
| [] -> assert false
| [_] -> ()
| t :: ts ->
Log.debugf 50
(fun k->k "(@[solver.th-comb.merge-cls@ %a@])"
(Util.pp_list Term.pp) cls);
List.iter (fun u -> CC.merge_same_value_t cc t u) ts
in
(* obtain classes of equal terms from the hook, and merge them *)
let add_th_equalities f : unit =
let cls = f self acts in
Iter.iter merge_cls cls
in
try
List.iter add_th_equalities self.on_th_combination;
CC.check cc acts;
Ok ()
with Semantic_conflict c -> Error c
(* handle a literal assumed by the SAT solver *)
let assert_lits_ ~final (self:t) (acts:theory_actions) (lits:Lit.t Iter.t) : unit =
Log.debugf 2
@ -563,9 +613,9 @@ module Make(A : ARG)
if final then (
let continue = ref true in
while !continue do
(* do final checks in a loop *)
let fcheck = ref true in
while !fcheck do
(* TODO: theory combination *)
List.iter (fun f -> f self acts lits) self.on_final_check;
if has_delayed_actions self then (
Perform_delayed_th.top self acts;
@ -573,6 +623,32 @@ module Make(A : ARG)
fcheck := false
)
done;
begin match check_th_combination_ self acts with
| Ok () -> ()
| Error {lits; same_val} ->
(* bad model, we add a clause to remove it *)
Log.debugf 10
(fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \
:same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list @@ Fmt.Dump.pair Term.pp Term.pp) same_val);
let c1 = List.rev_map Lit.neg lits in
let c2 =
List.rev_map (fun (t,u) ->
Lit.atom ~sign:false self.tst @@ A.mk_eq self.tst t u) same_val
in
let c = List.rev_append c1 c2 in
let pr = P.lemma_cc (Iter.of_list c) self.proof in
Log.debugf 20
(fun k->k "(@[solver.th-comb.add-clause@ %a@])"
(Util.pp_list Lit.pp) c);
add_clause_temp self acts c pr;
end;
CC.check cc acts;
if not (CC.new_merges cc) && not (has_delayed_actions self) then (
continue := false;
@ -638,6 +714,7 @@ module Make(A : ARG)
t_defs=[];
on_partial_check=[];
on_final_check=[];
on_th_combination=[];
level=0;
complete=true;
} in