feat(model): add theory hooks to "complete" models

these hooks are allowed to add terms to the model, that are not in the
congruence closure (for example in LRA, terms that were preprocessed
away).
This commit is contained in:
Simon Cruanes 2022-02-03 14:00:43 -05:00
parent c4bbaddc06
commit a98132ed0c
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
4 changed files with 63 additions and 14 deletions

View file

@ -997,10 +997,12 @@ module type SOLVER_INTERNAL = sig
(** {3 Model production} *)
type model_hook =
type model_ask_hook =
recurse:(t -> CC.N.t -> term) ->
t -> CC.N.t -> term option
(** A model-production hook. It takes the solver, a class, and returns
(** A model-production hook to query values from a theory.
It takes the solver, a class, and returns
a term for this class. For example, an arithmetic theory
might detect that a class contains a numeric constant, and return
this constant as a model value.
@ -1008,8 +1010,15 @@ module type SOLVER_INTERNAL = sig
If no hook assigns a value to a class, a fake value is created for it.
*)
val on_model_gen : t -> model_hook -> unit
(** Add a hook that will be called when a model is being produced *)
type model_completion_hook =
t -> add:(term -> term -> unit) -> unit
(** A model production hook, for the theory to add values.
The hook is given a [add] function to add bindings to the model. *)
val on_model :
?ask:model_ask_hook -> ?complete:model_completion_hook ->
t -> unit
(** Add model production/completion hooks. *)
end
(** User facing view of the solver.

View file

@ -170,6 +170,7 @@ module Make(A : ARG) : S with module A = A = struct
proof: SI.P.t;
simps: T.t T.Tbl.t; (* cache *)
gensym: A.Gensym.t;
in_model: T.t Vec.t; (* terms to add to model *)
encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *)
needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *)
mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *)
@ -182,6 +183,7 @@ module Make(A : ARG) : S with module A = A = struct
let create ?(stat=Stat.create()) proof tst ty_st : state =
{ tst; ty_st;
proof;
in_model=Vec.create();
simps=T.Tbl.create 128;
gensym=A.Gensym.create tst;
encoded_eqs=T.Tbl.create 8;
@ -283,6 +285,9 @@ module Make(A : ARG) : S with module A = A = struct
(* preprocess subterm *)
let preproc_t ~steps t =
let u, pr = SI.preprocess_term si (module PA) t in
if t != u then (
Vec.push self.in_model t;
);
CCOpt.iter (fun s -> steps := s :: !steps) pr;
u
in
@ -674,17 +679,34 @@ module Make(A : ARG) : S with module A = A = struct
T.Tbl.add self.needs_th_combination t ()
)
let to_rat_t_ self q = A.mk_lra self.tst (LRA_const q)
(* help generating model *)
let model_gen_ (self:state) ~recurse:_ _si n : _ option =
let t = N.term n in
begin match self.last_res with
| Some (SimpSolver.Sat m) ->
Log.debugf 50 (fun k->k "lra: model ask %a" T.pp t);
let to_rat q = A.mk_lra self.tst (LRA_const q) in
SimpSolver.V_map.get t m |> CCOpt.map to_rat
SimpSolver.V_map.get t m |> CCOpt.map (to_rat_t_ self)
| _ -> None
end
(* help generating model *)
let model_complete_ (self:state) _si ~add : unit =
begin match self.last_res with
| Some (SimpSolver.Sat m) ->
Log.debugf 50 (fun k->k "lra: model complete");
let add_t t =
match SimpSolver.V_map.get t m with
| None -> ()
| Some u -> add t (to_rat_t_ self u)
in
Vec.iter add_t self.in_model
| _ -> ()
end
let k_state = SI.Registry.create_key ()
let create_and_setup si =
@ -696,7 +718,7 @@ module Make(A : ARG) : S with module A = A = struct
SI.on_preprocess si (preproc_lra st);
SI.on_final_check si (final_check_ st);
SI.on_partial_check si (partial_check_ st);
SI.on_model_gen si (model_gen_ st);
SI.on_model si ~ask:(model_gen_ st) ~complete:(model_complete_ st);
SI.on_cc_is_subterm si (on_subterm st);
SI.on_cc_post_merge si
(fun _ _ n1 n2 ->

View file

@ -255,7 +255,8 @@ module Make(A : ARG)
mutable on_progress: unit -> unit;
simp: Simplify.t;
mutable preprocess: preprocess_hook list;
mutable mk_model: model_hook list;
mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list;
preprocess_cache: (Term.t * proof_step Bag.t) Term.Tbl.t;
mutable t_defs : (term*term) list; (* term definitions *)
mutable th_states : th_states; (** Set of theories *)
@ -270,10 +271,13 @@ module Make(A : ARG)
preprocess_actions ->
term -> (term * proof_step Iter.t) option
and model_hook =
and model_ask_hook =
recurse:(t -> CC.N.t -> term) ->
t -> CC.N.t -> term option
and model_completion_hook =
t -> add:(term -> term -> unit) -> unit
type solver = t
module Proof = P
@ -292,7 +296,10 @@ module Make(A : ARG)
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
let on_preprocess self f = self.preprocess <- f :: self.preprocess
let on_model_gen self f = self.mk_model <- f :: self.mk_model
let on_model ?ask ?complete self =
CCOpt.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
CCOpt.iter (fun f -> self.model_complete <- f :: self.model_complete) complete;
()
let push_decision (_self:t) (acts:theory_actions) (lit:lit) : unit =
let (module A) = acts in
@ -654,7 +661,8 @@ module Make(A : ARG)
simp=Simplify.create tst ty_st ~proof;
on_progress=(fun () -> ());
preprocess=[];
mk_model=[];
model_ask=[];
model_complete=[];
registry=Registry.create();
preprocess_cache=Term.Tbl.create 32;
count_axiom = Stat.mk_int stat "solver.th-axioms";
@ -883,7 +891,7 @@ module Make(A : ARG)
Profile.with_ "smt-solver.mk-model" @@ fun () ->
let module M = Term.Tbl in
let model = M.create 128 in
let {Solver_internal.tst; cc=lazy cc; mk_model=model_hooks; _} = self.si in
let {Solver_internal.tst; cc=lazy cc; model_ask; model_complete; _} = self.si in
(* first, add all literals to the model using the given propositional model
[lits]. *)
@ -892,6 +900,16 @@ module Make(A : ARG)
let t, sign = Lit.signed_term lit in
M.replace model t (Term.bool tst sign));
(* complete model with theory specific values *)
let complete_with f =
f self.si
~add:(fun t u ->
if not (M.mem model t) then (
M.replace model t u
));
in
List.iter complete_with model_complete;
(* compute a value for [n]. *)
let rec val_for_class (n:N.t) : term =
let repr = CC.find cc n in
@ -911,7 +929,7 @@ module Make(A : ARG)
end
in
let t_val = aux model_hooks in
let t_val = aux model_ask in
M.replace model (N.term repr) t_val; (* be sure to cache the value *)
t_val
in

View file

@ -716,7 +716,7 @@ module Make(A : ARG) : S with module A = A = struct
SI.on_cc_new_term solver (on_new_term self);
SI.on_cc_pre_merge solver (on_pre_merge self);
SI.on_final_check solver (on_final_check self);
SI.on_model_gen solver (on_model_gen self);
SI.on_model solver ~ask:(on_model_gen self);
self
let theory =