wip(smt): theory combination

This commit is contained in:
Simon Cruanes 2022-08-27 21:38:20 -04:00
parent 2a0feed32c
commit ccb3753668
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
7 changed files with 174 additions and 121 deletions

View file

@ -55,6 +55,16 @@ let unfold_app (e : term) : term * term list =
in in
aux [] e aux [] e
let[@inline] is_const e =
match e.view with
| E_const _ -> true
| _ -> false
let[@inline] is_app e =
match e.view with
| E_app _ -> true
| _ -> false
(* debug printer *) (* debug printer *)
let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit = let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let rec loop k ~depth names out e = let rec loop k ~depth names out e =

View file

@ -53,6 +53,8 @@ include WITH_SET_MAP_TBL with type t := t
val view : t -> view val view : t -> view
val unfold_app : t -> t * t list val unfold_app : t -> t * t list
val is_app : t -> bool
val is_const : t -> bool
val iter_dag : ?seen:unit Tbl.t -> iter_ty:bool -> f:(t -> unit) -> t -> unit val iter_dag : ?seen:unit Tbl.t -> iter_ty:bool -> f:(t -> unit) -> t -> unit
(** [iter_dag t ~f] calls [f] once on each subterm of [t], [t] included. (** [iter_dag t ~f] calls [f] once on each subterm of [t], [t] included.

View file

@ -21,6 +21,7 @@ module type PREPROCESS_ACTS = sig
val mk_lit : ?sign:bool -> term -> lit val mk_lit : ?sign:bool -> term -> lit
val add_clause : lit list -> step_id -> unit val add_clause : lit list -> step_id -> unit
val add_lit : ?default_pol:bool -> lit -> unit val add_lit : ?default_pol:bool -> lit -> unit
val add_term_needing_combination : term -> unit
end end
type preprocess_actions = (module PREPROCESS_ACTS) type preprocess_actions = (module PREPROCESS_ACTS)
@ -39,10 +40,9 @@ type t = {
proof: proof_trace; (** proof logger *) proof: proof_trace; (** proof logger *)
registry: Registry.t; registry: Registry.t;
on_progress: (unit, unit) Event.Emitter.t; on_progress: (unit, unit) Event.Emitter.t;
th_comb: Th_combination.t;
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_th_combination:
(t -> theory_actions -> (term * value) Iter.t) list;
mutable preprocess: preprocess_hook list; mutable preprocess: preprocess_hook list;
mutable model_ask: model_ask_hook list; mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list; mutable model_complete: model_completion_hook list;
@ -82,11 +82,11 @@ let add_simplifier (self : t) f : unit = Simplify.add_hook self.simp f
let[@inline] has_delayed_actions self = let[@inline] has_delayed_actions self =
not (Queue.is_empty self.delayed_actions) not (Queue.is_empty self.delayed_actions)
let on_th_combination self f =
self.on_th_combination <- f :: self.on_th_combination
let on_preprocess self f = self.preprocess <- f :: self.preprocess let on_preprocess self f = self.preprocess <- f :: self.preprocess
let add_term_needing_combination self t =
Th_combination.add_term_needing_combination self.th_comb t
let on_model ?ask ?complete self = let on_model ?ask ?complete self =
Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask; Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
Option.iter Option.iter
@ -130,6 +130,9 @@ let preprocess_term_ (self : t) (t0 : term) : unit =
let mk_lit ?sign t : Lit.t = Lit.atom ?sign self.tst t let mk_lit ?sign t : Lit.t = Lit.atom ?sign self.tst t
let add_lit ?default_pol lit : unit = delayed_add_lit self ?default_pol lit let add_lit ?default_pol lit : unit = delayed_add_lit self ?default_pol lit
let add_clause c pr : unit = delayed_add_clause self ~keep:true c pr let add_clause c pr : unit = delayed_add_clause self ~keep:true c pr
let add_term_needing_combination t =
Th_combination.add_term_needing_combination self.th_comb t
end in end in
let acts = (module A : PREPROCESS_ACTS) in let acts = (module A : PREPROCESS_ACTS) in
@ -397,33 +400,12 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
(* do theory combination using the congruence closure. Each theory (* do theory combination using the congruence closure. Each theory
can merge classes, *) can merge classes, *)
let check_th_combination_ (self : t) (_acts : theory_actions) lits : let check_th_combination_ (self : t) (acts : theory_actions) _lits : unit =
(Model.t, th_combination_conflict) result = let lits_to_decide = Th_combination.pop_new_lits self.th_comb in
(* FIXME if lits_to_decide <> [] then (
let (module A) = acts in
(* enter model mode, disabling most of congruence closure *) List.iter (fun lit -> A.add_lit ~default_pol:false lit) lits_to_decide
CC.with_model_mode cc @@ fun () -> )
let set_val (t, v) : unit =
Log.debugf 50 (fun k ->
k "(@[solver.th-comb.cc-set-term-value@ %a@ :val %a@])" Term.pp_debug t
Term.pp_debug v);
CC.set_model_value cc t v
in
(* obtain assignments from the hook, and communicate them to the CC *)
let add_th_values f : unit =
let vals = f self acts in
Iter.iter set_val vals
in
try
List.iter add_th_values self.on_th_combination;
CC.check cc;
let m = mk_model_ self in
Ok m
with Semantic_conflict c -> Error c
*)
let m = mk_model_ self lits in
Ok m
(* call congruence closure, perform the actions it scheduled *) (* call congruence closure, perform the actions it scheduled *)
let check_cc_with_acts_ (self : t) (acts : theory_actions) = let check_cc_with_acts_ (self : t) (acts : theory_actions) =
@ -471,40 +453,13 @@ let assert_lits_ ~final (self : t) (acts : theory_actions) (lits : Lit.t Iter.t)
(* do actual theory combination if nothing changed by pure "final check" *) (* do actual theory combination if nothing changed by pure "final check" *)
if not new_work then ( if not new_work then (
match check_th_combination_ self acts lits with check_th_combination_ self acts lits;
| Ok m -> self.last_model <- Some m
| Error { lits; semantic } ->
(* bad model, we add a clause to remove it *)
Log.debugf 5 (fun k ->
k
"(@[solver.th-comb.conflict@ :lits (@[%a@])@ :same-val \
(@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list
@@ Fmt.Dump.(triple bool Term.pp_debug Term.pp_debug))
semantic);
let c1 = List.rev_map Lit.neg lits in (* if theory combination didn't add new clauses, compute a model *)
let c2 = if not (has_delayed_actions self) then (
semantic let m = mk_model_ self lits in
|> List.rev_map (fun (sign, t, u) -> self.last_model <- Some m
let eqn = Term.eq self.tst t u in )
let lit = Lit.atom ~sign:(not sign) self.tst eqn in
(* make sure to consider the new lit *)
add_lit self acts lit;
lit)
in
let c = List.rev_append c1 c2 in
let pr =
Proof_trace.add_step self.proof @@ fun () -> Proof_core.lemma_cc c
in
Log.debugf 20 (fun k ->
k "(@[solver.th-comb.add-semantic-conflict-clause@ %a@])"
(Util.pp_list Lit.pp) c);
(* will add a delayed action *)
add_clause_temp self acts c pr
); );
Perform_delayed_th.top self acts Perform_delayed_th.top self acts
@ -585,6 +540,7 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
stat; stat;
simp = Simplify.create tst ~proof; simp = Simplify.create tst ~proof;
last_model = None; last_model = None;
th_comb = Th_combination.create ~stat tst;
on_progress = Event.Emitter.create (); on_progress = Event.Emitter.create ();
preprocess = []; preprocess = [];
model_ask = []; model_ask = [];
@ -598,7 +554,6 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
count_conflict = Stat.mk_int stat "smt.solver.th-conflicts"; count_conflict = Stat.mk_int stat "smt.solver.th-conflicts";
on_partial_check = []; on_partial_check = [];
on_final_check = []; on_final_check = [];
on_th_combination = [];
level = 0; level = 0;
complete = true; complete = true;
} }

View file

@ -73,6 +73,10 @@ module type PREPROCESS_ACTS = sig
val add_lit : ?default_pol:bool -> lit -> unit val add_lit : ?default_pol:bool -> lit -> unit
(** Ensure the literal will be decided/handled by the SAT solver. *) (** Ensure the literal will be decided/handled by the SAT solver. *)
val add_term_needing_combination : term -> unit
(** Declare this term as being a foreign variable in the theory,
which means it needs to go through theory combination. *)
end end
type preprocess_actions = (module PREPROCESS_ACTS) type preprocess_actions = (module PREPROCESS_ACTS)
@ -98,6 +102,10 @@ val preprocess_clause_array : t -> lit array -> step_id -> lit array * step_id
val simplify_and_preproc_lit : t -> lit -> lit * step_id option val simplify_and_preproc_lit : t -> lit -> lit * step_id option
(** Simplify literal then preprocess it *) (** Simplify literal then preprocess it *)
val add_term_needing_combination : t -> term -> unit
(** Declare this term as being a foreign variable in the theory,
which means it needs to go through theory combination. *)
(** {3 hooks for the theory} *) (** {3 hooks for the theory} *)
val raise_conflict : t -> theory_actions -> lit list -> step_id -> 'a val raise_conflict : t -> theory_actions -> lit list -> step_id -> 'a
@ -216,16 +224,6 @@ val on_final_check : t -> (t -> theory_actions -> lit Iter.t -> unit) -> unit
is given the whole trail. is given the whole trail.
*) *)
val on_th_combination :
t -> (t -> theory_actions -> (term * value) Iter.t) -> unit
(** Add a hook called during theory combination.
The hook must return an iterator of pairs [(t, v)]
which mean that term [t] has value [v] in the model.
Terms with the same value (according to {!Term.equal}) will be
merged in the CC; if two terms with different values are merged,
we get a semantic conflict and must pick another model. *)
val declare_pb_is_incomplete : t -> unit val declare_pb_is_incomplete : t -> unit
(** Declare that, in some theory, the problem is outside the logic fragment (** Declare that, in some theory, the problem is outside the logic fragment
that is decidable (e.g. if we meet proper NIA formulas). that is decidable (e.g. if we meet proper NIA formulas).

62
src/smt/th_combination.ml Normal file
View file

@ -0,0 +1,62 @@
open Sidekick_core
module T = Term
type t = {
tst: Term.store;
processed: T.Set.t T.Tbl.t; (** type -> set of terms *)
unprocessed: T.t Vec.t;
new_lits: Lit.t Vec.t;
n_terms: int Stat.counter;
n_lits: int Stat.counter;
}
let create ?(stat = Stat.global) tst : t =
{
tst;
processed = T.Tbl.create 8;
unprocessed = Vec.create ();
new_lits = Vec.create ();
n_terms = Stat.mk_int stat "smt.thcomb.terms";
n_lits = Stat.mk_int stat "smt.thcomb.intf-lits";
}
let processed_ (self : t) t : bool =
let ty = T.ty t in
match T.Tbl.find_opt self.processed ty with
| None -> false
| Some set -> T.Set.mem t set
let add_term_needing_combination (self : t) (t : T.t) : unit =
if not (processed_ self t) then (
Log.debugf 50 (fun k -> k "(@[th.comb.add-term-needing-comb@ %a@])" T.pp t);
Vec.push self.unprocessed t
)
let pop_new_lits (self : t) : Lit.t list =
(* first, process new terms, if any *)
while not (Vec.is_empty self.unprocessed) do
let t = Vec.pop_exn self.unprocessed in
let ty = T.ty t in
let set_for_ty =
try T.Tbl.find self.processed ty with Not_found -> T.Set.empty
in
if not (T.Set.mem t set_for_ty) then (
Stat.incr self.n_terms;
(* now create [t=u] for each [u] in [set_for_ty] *)
T.Set.iter
(fun u ->
let lit = Lit.make_eq self.tst t u in
Stat.incr self.n_lits;
Vec.push self.new_lits lit)
set_for_ty;
(* add [t] to the set of processed terms *)
let new_set_for_ty = T.Set.add t set_for_ty in
T.Tbl.replace self.processed ty new_set_for_ty
)
done;
let lits = Vec.to_list self.new_lits in
Vec.clear self.new_lits;
lits

View file

@ -0,0 +1,17 @@
(** Delayed Theory Combination *)
open Sidekick_core
type t
val create : ?stat:Stat.t -> Term.store -> t
val add_term_needing_combination : t -> Term.t -> unit
(** [add_term_needing_combination self t] means that [t] occurs as a foreign
variable in another term, so it is important that its theory, and the
theory in which it occurs, agree on it being equal to other
foreign terms. *)
val pop_new_lits : t -> Lit.t list
(** Get the new literals that the solver needs to decide, so that the
SMT solver gives each theory the same partition of interface equalities. *)

View file

@ -130,8 +130,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
in_model: unit Term.Tbl.t; (* terms to add to model *) in_model: unit Term.Tbl.t; (* terms to add to model *)
encoded_eqs: unit Term.Tbl.t; encoded_eqs: unit Term.Tbl.t;
(* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *)
needs_th_combination: unit Term.Tbl.t;
(* terms that require theory combination *)
simp_preds: (Term.t * S_op.t * A.Q.t) Term.Tbl.t; simp_preds: (Term.t * S_op.t * A.Q.t) Term.Tbl.t;
(* term -> its simplex meaning *) (* term -> its simplex meaning *)
simp_defined: LE.t Term.Tbl.t; simp_defined: LE.t Term.Tbl.t;
@ -157,7 +155,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
simp_preds = Term.Tbl.create 32; simp_preds = Term.Tbl.create 32;
simp_defined = Term.Tbl.create 16; simp_defined = Term.Tbl.create 16;
encoded_eqs = Term.Tbl.create 8; encoded_eqs = Term.Tbl.create 8;
needs_th_combination = Term.Tbl.create 8;
encoded_le = Comb_map.empty; encoded_le = Comb_map.empty;
simplex = SimpSolver.create ~stat (); simplex = SimpSolver.create ~stat ();
last_res = None; last_res = None;
@ -275,6 +272,11 @@ module Make (A : ARG) = (* : S with module A = A *) struct
| Geq -> S_op.Geq | Geq -> S_op.Geq
| Gt -> S_op.Gt | Gt -> S_op.Gt
(* add [t] to the theory combination system if it's not just a constant
of type Real *)
let add_lra_var_to_th_combination (si : SI.t) (t : term) : unit =
if not (Term.is_const t) then SI.add_term_needing_combination si t
(* TODO: refactor that and {!var_encoding_comb} *) (* TODO: refactor that and {!var_encoding_comb} *)
(* turn a linear expression into a single constant and a coeff. (* turn a linear expression into a single constant and a coeff.
This might define a side variable in the simplex. *) This might define a side variable in the simplex. *)
@ -300,17 +302,20 @@ module Make (A : ARG) = (* : S with module A = A *) struct
proxy, A.Q.one) proxy, A.Q.one)
(* look for subterms of type Real, for they will need theory combination *) (* look for subterms of type Real, for they will need theory combination *)
let on_subterm (self : state) (t : Term.t) : unit = let on_subterm (_self : state) (si : SI.t) (t : Term.t) : unit =
Log.debugf 50 (fun k -> k "(@[lra.cc-on-subterm@ %a@])" Term.pp_debug t); Log.debugf 50 (fun k -> k "(@[lra.cc-on-subterm@ %a@])" Term.pp_debug t);
match A.view_as_lra t with match A.view_as_lra t with
| LRA_other _ when not (A.has_ty_real t) -> () | LRA_other _ when not (A.has_ty_real t) ->
(* for a non-LRA term [f args], if any of [args] is in LRA,
it needs theory combination *)
let _, args = Term.unfold_app t in
List.iter
(fun arg ->
if A.has_ty_real arg then SI.add_term_needing_combination si arg)
args
| LRA_pred _ | LRA_const _ -> () | LRA_pred _ | LRA_const _ -> ()
| LRA_op _ | LRA_other _ | LRA_mult _ -> | LRA_op _ | LRA_other _ | LRA_mult _ ->
if not (Term.Tbl.mem self.needs_th_combination t) then ( SI.add_term_needing_combination si t
Log.debugf 5 (fun k ->
k "(@[lra.needs-th-combination@ %a@])" Term.pp_debug t);
Term.Tbl.add self.needs_th_combination t ()
)
(* preprocess linear expressions away *) (* preprocess linear expressions away *)
let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS) let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS)
@ -323,7 +328,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
Log.debugf 50 (fun k -> Log.debugf 50 (fun k ->
k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t); k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t);
ignore (CC.add_term (SI.cc si) t : E_node.t); ignore (CC.add_term (SI.cc si) t : E_node.t);
if sub then on_subterm self t if sub then on_subterm self si t
in in
match A.view_as_lra t with match A.view_as_lra t with
@ -369,7 +374,11 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* obtain a single variable for the linear combination *) (* obtain a single variable for the linear combination *)
let v, c_v = le_comb_to_singleton_ self le_comb in let v, c_v = le_comb_to_singleton_ self le_comb in
declare_term_to_cc ~sub:false v; declare_term_to_cc ~sub:false v;
LE_.Comb.iter (fun v _ -> declare_term_to_cc ~sub:true v) le_comb; LE_.Comb.iter
(fun v _ ->
declare_term_to_cc ~sub:true v;
add_lra_var_to_th_combination si v)
le_comb;
(* turn into simplex constraint. For example, (* turn into simplex constraint. For example,
[c . v <= const] becomes a direct simplex constraint [v <= const/c] [c . v <= const] becomes a direct simplex constraint [v <= const/c]
@ -568,7 +577,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* evaluate a linear expression *) (* evaluate a linear expression *)
let eval_le_in_subst_ subst (le : LE.t) = LE.eval (eval_in_subst_ subst) le let eval_le_in_subst_ subst (le : LE.t) = LE.eval (eval_in_subst_ subst) le
(* FIXME: rename, this is more "provide_model_to_cc" *) (* FIXME: rework into model creation
let do_th_combination (self : state) _si _acts : _ Iter.t = let do_th_combination (self : state) _si _acts : _ Iter.t =
Log.debug 1 "(lra.do-th-combinations)"; Log.debug 1 "(lra.do-th-combinations)";
let model = let model =
@ -603,6 +612,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* return whole model *) (* return whole model *)
Term.Tbl.to_iter vals |> Iter.map (fun (t, v) -> t, t_const self v) Term.Tbl.to_iter vals |> Iter.map (fun (t, v) -> t, t_const self v)
*)
(* partial checks is where we add literals from the trail to the (* partial checks is where we add literals from the trail to the
simplex. *) simplex. *)
@ -714,7 +724,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
SI.on_partial_check si (partial_check_ st); SI.on_partial_check si (partial_check_ st);
SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st); SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st);
SI.on_cc_is_subterm si (fun (_, _, t) -> SI.on_cc_is_subterm si (fun (_, _, t) ->
on_subterm st t; on_subterm st si t;
[]); []);
SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) -> SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) ->
match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with
@ -725,7 +735,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
E_node.pp n2); E_node.pp n2);
Error (CC.Handler_action.Conflict expl) Error (CC.Handler_action.Conflict expl)
| _ -> Ok []); | _ -> Ok []);
SI.on_th_combination si (do_th_combination st);
st st
let theory = let theory =