wip(smt): theory combination

This commit is contained in:
Simon Cruanes 2022-08-27 21:38:20 -04:00
parent 2a0feed32c
commit ccb3753668
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
7 changed files with 174 additions and 121 deletions

View file

@ -55,6 +55,16 @@ let unfold_app (e : term) : term * term list =
in
aux [] e
let[@inline] is_const e =
match e.view with
| E_const _ -> true
| _ -> false
let[@inline] is_app e =
match e.view with
| E_app _ -> true
| _ -> false
(* debug printer *)
let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let rec loop k ~depth names out e =

View file

@ -53,6 +53,8 @@ include WITH_SET_MAP_TBL with type t := t
val view : t -> view
val unfold_app : t -> t * t list
val is_app : t -> bool
val is_const : t -> bool
val iter_dag : ?seen:unit Tbl.t -> iter_ty:bool -> f:(t -> unit) -> t -> unit
(** [iter_dag t ~f] calls [f] once on each subterm of [t], [t] included.

View file

@ -21,6 +21,7 @@ module type PREPROCESS_ACTS = sig
val mk_lit : ?sign:bool -> term -> lit
val add_clause : lit list -> step_id -> unit
val add_lit : ?default_pol:bool -> lit -> unit
val add_term_needing_combination : term -> unit
end
type preprocess_actions = (module PREPROCESS_ACTS)
@ -39,10 +40,9 @@ type t = {
proof: proof_trace; (** proof logger *)
registry: Registry.t;
on_progress: (unit, unit) Event.Emitter.t;
th_comb: Th_combination.t;
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_th_combination:
(t -> theory_actions -> (term * value) Iter.t) list;
mutable preprocess: preprocess_hook list;
mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list;
@ -82,11 +82,11 @@ let add_simplifier (self : t) f : unit = Simplify.add_hook self.simp f
let[@inline] has_delayed_actions self =
not (Queue.is_empty self.delayed_actions)
let on_th_combination self f =
self.on_th_combination <- f :: self.on_th_combination
let on_preprocess self f = self.preprocess <- f :: self.preprocess
let add_term_needing_combination self t =
Th_combination.add_term_needing_combination self.th_comb t
let on_model ?ask ?complete self =
Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
Option.iter
@ -130,6 +130,9 @@ let preprocess_term_ (self : t) (t0 : term) : unit =
let mk_lit ?sign t : Lit.t = Lit.atom ?sign self.tst t
let add_lit ?default_pol lit : unit = delayed_add_lit self ?default_pol lit
let add_clause c pr : unit = delayed_add_clause self ~keep:true c pr
let add_term_needing_combination t =
Th_combination.add_term_needing_combination self.th_comb t
end in
let acts = (module A : PREPROCESS_ACTS) in
@ -397,33 +400,12 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
(* do theory combination using the congruence closure. Each theory
can merge classes, *)
let check_th_combination_ (self : t) (_acts : theory_actions) lits :
(Model.t, th_combination_conflict) result =
(* FIXME
(* enter model mode, disabling most of congruence closure *)
CC.with_model_mode cc @@ fun () ->
let set_val (t, v) : unit =
Log.debugf 50 (fun k ->
k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])" Term.pp_debug t
Term.pp_debug v);
CC.set_model_value cc t v
in
(* obtain assignments from the hook, and communicate them to the CC *)
let add_th_values f : unit =
let vals = f self acts in
Iter.iter set_val vals
in
try
List.iter add_th_values self.on_th_combination;
CC.check cc;
let m = mk_model_ self in
Ok m
with Semantic_conflict c -> Error c
*)
let m = mk_model_ self lits in
Ok m
let check_th_combination_ (self : t) (acts : theory_actions) _lits : unit =
let lits_to_decide = Th_combination.pop_new_lits self.th_comb in
if lits_to_decide <> [] then (
let (module A) = acts in
List.iter (fun lit -> A.add_lit ~default_pol:false lit) lits_to_decide
)
(* call congruence closure, perform the actions it scheduled *)
let check_cc_with_acts_ (self : t) (acts : theory_actions) =
@ -471,40 +453,13 @@ let assert_lits_ ~final (self : t) (acts : theory_actions) (lits : Lit.t Iter.t)
(* do actual theory combination if nothing changed by pure "final check" *)
if not new_work then (
match check_th_combination_ self acts lits with
| Ok m -> self.last_model <- Some m
| Error { lits; semantic } ->
(* bad model, we add a clause to remove it *)
Log.debugf 5 (fun k ->
k
"(@[solver.th-comb.conflict@ :lits (@[%a@])@ :same-val \
(@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list
@@ Fmt.Dump.(triple bool Term.pp_debug Term.pp_debug))
semantic);
check_th_combination_ self acts lits;
let c1 = List.rev_map Lit.neg lits in
let c2 =
semantic
|> List.rev_map (fun (sign, t, u) ->
let eqn = Term.eq self.tst t u in
let lit = Lit.atom ~sign:(not sign) self.tst eqn in
(* make sure to consider the new lit *)
add_lit self acts lit;
lit)
in
let c = List.rev_append c1 c2 in
let pr =
Proof_trace.add_step self.proof @@ fun () -> Proof_core.lemma_cc c
in
Log.debugf 20 (fun k ->
k "(@[solver.th-comb.add-semantic-conflict-clause@ %a@])"
(Util.pp_list Lit.pp) c);
(* will add a delayed action *)
add_clause_temp self acts c pr
(* if theory combination didn't add new clauses, compute a model *)
if not (has_delayed_actions self) then (
let m = mk_model_ self lits in
self.last_model <- Some m
)
);
Perform_delayed_th.top self acts
@ -585,6 +540,7 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
stat;
simp = Simplify.create tst ~proof;
last_model = None;
th_comb = Th_combination.create ~stat tst;
on_progress = Event.Emitter.create ();
preprocess = [];
model_ask = [];
@ -598,7 +554,6 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
count_conflict = Stat.mk_int stat "smt.solver.th-conflicts";
on_partial_check = [];
on_final_check = [];
on_th_combination = [];
level = 0;
complete = true;
}

View file

@ -73,6 +73,10 @@ module type PREPROCESS_ACTS = sig
val add_lit : ?default_pol:bool -> lit -> unit
(** Ensure the literal will be decided/handled by the SAT solver. *)
val add_term_needing_combination : term -> unit
(** Declare this term as being a foreign variable in the theory,
which means it needs to go through theory combination. *)
end
type preprocess_actions = (module PREPROCESS_ACTS)
@ -98,6 +102,10 @@ val preprocess_clause_array : t -> lit array -> step_id -> lit array * step_id
val simplify_and_preproc_lit : t -> lit -> lit * step_id option
(** Simplify literal then preprocess it *)
val add_term_needing_combination : t -> term -> unit
(** Declare this term as being a foreign variable in the theory,
which means it needs to go through theory combination. *)
(** {3 hooks for the theory} *)
val raise_conflict : t -> theory_actions -> lit list -> step_id -> 'a
@ -216,16 +224,6 @@ val on_final_check : t -> (t -> theory_actions -> lit Iter.t -> unit) -> unit
is given the whole trail.
*)
val on_th_combination :
t -> (t -> theory_actions -> (term * value) Iter.t) -> unit
(** Add a hook called during theory combination.
The hook must return an iterator of pairs [(t, v)]
which mean that term [t] has value [v] in the model.
Terms with the same value (according to {!Term.equal}) will be
merged in the CC; if two terms with different values are merged,
we get a semantic conflict and must pick another model. *)
val declare_pb_is_incomplete : t -> unit
(** Declare that, in some theory, the problem is outside the logic fragment
that is decidable (e.g. if we meet proper NIA formulas).

62
src/smt/th_combination.ml Normal file
View file

@ -0,0 +1,62 @@
open Sidekick_core
module T = Term
type t = {
tst: Term.store;
processed: T.Set.t T.Tbl.t; (** type -> set of terms *)
unprocessed: T.t Vec.t;
new_lits: Lit.t Vec.t;
n_terms: int Stat.counter;
n_lits: int Stat.counter;
}
let create ?(stat = Stat.global) tst : t =
{
tst;
processed = T.Tbl.create 8;
unprocessed = Vec.create ();
new_lits = Vec.create ();
n_terms = Stat.mk_int stat "smt.thcomb.terms";
n_lits = Stat.mk_int stat "smt.thcomb.intf-lits";
}
let processed_ (self : t) t : bool =
let ty = T.ty t in
match T.Tbl.find_opt self.processed ty with
| None -> false
| Some set -> T.Set.mem t set
let add_term_needing_combination (self : t) (t : T.t) : unit =
if not (processed_ self t) then (
Log.debugf 50 (fun k -> k "(@[th.comb.add-term-needing-comb@ %a@])" T.pp t);
Vec.push self.unprocessed t
)
let pop_new_lits (self : t) : Lit.t list =
(* first, process new terms, if any *)
while not (Vec.is_empty self.unprocessed) do
let t = Vec.pop_exn self.unprocessed in
let ty = T.ty t in
let set_for_ty =
try T.Tbl.find self.processed ty with Not_found -> T.Set.empty
in
if not (T.Set.mem t set_for_ty) then (
Stat.incr self.n_terms;
(* now create [t=u] for each [u] in [set_for_ty] *)
T.Set.iter
(fun u ->
let lit = Lit.make_eq self.tst t u in
Stat.incr self.n_lits;
Vec.push self.new_lits lit)
set_for_ty;
(* add [t] to the set of processed terms *)
let new_set_for_ty = T.Set.add t set_for_ty in
T.Tbl.replace self.processed ty new_set_for_ty
)
done;
let lits = Vec.to_list self.new_lits in
Vec.clear self.new_lits;
lits

View file

@ -0,0 +1,17 @@
(** Delayed Theory Combination *)
open Sidekick_core
type t
val create : ?stat:Stat.t -> Term.store -> t
val add_term_needing_combination : t -> Term.t -> unit
(** [add_term_needing_combination self t] means that [t] occurs as a foreign
variable in another term, so it is important that its theory, and the
theory in which it occurs, agree on it being equal to other
foreign terms. *)
val pop_new_lits : t -> Lit.t list
(** Get the new literals that the solver needs to decide, so that the
SMT solver gives each theory the same partition of interface equalities. *)

View file

@ -130,8 +130,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
in_model: unit Term.Tbl.t; (* terms to add to model *)
encoded_eqs: unit Term.Tbl.t;
(* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *)
needs_th_combination: unit Term.Tbl.t;
(* terms that require theory combination *)
simp_preds: (Term.t * S_op.t * A.Q.t) Term.Tbl.t;
(* term -> its simplex meaning *)
simp_defined: LE.t Term.Tbl.t;
@ -157,7 +155,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
simp_preds = Term.Tbl.create 32;
simp_defined = Term.Tbl.create 16;
encoded_eqs = Term.Tbl.create 8;
needs_th_combination = Term.Tbl.create 8;
encoded_le = Comb_map.empty;
simplex = SimpSolver.create ~stat ();
last_res = None;
@ -275,6 +272,11 @@ module Make (A : ARG) = (* : S with module A = A *) struct
| Geq -> S_op.Geq
| Gt -> S_op.Gt
(* add [t] to the theory combination system if it's not just a constant
of type Real *)
let add_lra_var_to_th_combination (si : SI.t) (t : term) : unit =
if not (Term.is_const t) then SI.add_term_needing_combination si t
(* TODO: refactor that and {!var_encoding_comb} *)
(* turn a linear expression into a single constant and a coeff.
This might define a side variable in the simplex. *)
@ -300,17 +302,20 @@ module Make (A : ARG) = (* : S with module A = A *) struct
proxy, A.Q.one)
(* look for subterms of type Real, for they will need theory combination *)
let on_subterm (self : state) (t : Term.t) : unit =
let on_subterm (_self : state) (si : SI.t) (t : Term.t) : unit =
Log.debugf 50 (fun k -> k "(@[lra.cc-on-subterm@ %a@])" Term.pp_debug t);
match A.view_as_lra t with
| LRA_other _ when not (A.has_ty_real t) -> ()
| LRA_other _ when not (A.has_ty_real t) ->
(* for a non-LRA term [f args], if any of [args] is in LRA,
it needs theory combination *)
let _, args = Term.unfold_app t in
List.iter
(fun arg ->
if A.has_ty_real arg then SI.add_term_needing_combination si arg)
args
| LRA_pred _ | LRA_const _ -> ()
| LRA_op _ | LRA_other _ | LRA_mult _ ->
if not (Term.Tbl.mem self.needs_th_combination t) then (
Log.debugf 5 (fun k ->
k "(@[lra.needs-th-combination@ %a@])" Term.pp_debug t);
Term.Tbl.add self.needs_th_combination t ()
)
SI.add_term_needing_combination si t
(* preprocess linear expressions away *)
let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS)
@ -323,7 +328,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
Log.debugf 50 (fun k ->
k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t);
ignore (CC.add_term (SI.cc si) t : E_node.t);
if sub then on_subterm self t
if sub then on_subterm self si t
in
match A.view_as_lra t with
@ -369,7 +374,11 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* obtain a single variable for the linear combination *)
let v, c_v = le_comb_to_singleton_ self le_comb in
declare_term_to_cc ~sub:false v;
LE_.Comb.iter (fun v _ -> declare_term_to_cc ~sub:true v) le_comb;
LE_.Comb.iter
(fun v _ ->
declare_term_to_cc ~sub:true v;
add_lra_var_to_th_combination si v)
le_comb;
(* turn into simplex constraint. For example,
[c . v <= const] becomes a direct simplex constraint [v <= const/c]
@ -568,7 +577,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* evaluate a linear expression *)
let eval_le_in_subst_ subst (le : LE.t) = LE.eval (eval_in_subst_ subst) le
(* FIXME: rename, this is more "provide_model_to_cc" *)
(* FIXME: rework into model creation
let do_th_combination (self : state) _si _acts : _ Iter.t =
Log.debug 1 "(lra.do-th-combinations)";
let model =
@ -603,6 +612,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* return whole model *)
Term.Tbl.to_iter vals |> Iter.map (fun (t, v) -> t, t_const self v)
*)
(* partial checks is where we add literals from the trail to the
simplex. *)
@ -714,7 +724,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
SI.on_partial_check si (partial_check_ st);
SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st);
SI.on_cc_is_subterm si (fun (_, _, t) ->
on_subterm st t;
on_subterm st si t;
[]);
SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) ->
match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with
@ -725,7 +735,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
E_node.pp n2);
Error (CC.Handler_action.Conflict expl)
| _ -> Ok []);
SI.on_th_combination si (do_th_combination st);
st
let theory =