feat(model): add theory hooks to "complete" models

these hooks are allowed to add terms to the model, that are not in the
congruence closure (for example in LRA, terms that were preprocessed
away).
This commit is contained in:
Simon Cruanes 2022-02-03 14:00:43 -05:00
parent c4bbaddc06
commit a98132ed0c
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
4 changed files with 63 additions and 14 deletions

View file

@ -997,10 +997,12 @@ module type SOLVER_INTERNAL = sig
(** {3 Model production} *) (** {3 Model production} *)
type model_hook = type model_ask_hook =
recurse:(t -> CC.N.t -> term) -> recurse:(t -> CC.N.t -> term) ->
t -> CC.N.t -> term option t -> CC.N.t -> term option
(** A model-production hook. It takes the solver, a class, and returns (** A model-production hook to query values from a theory.
It takes the solver, a class, and returns
a term for this class. For example, an arithmetic theory a term for this class. For example, an arithmetic theory
might detect that a class contains a numeric constant, and return might detect that a class contains a numeric constant, and return
this constant as a model value. this constant as a model value.
@ -1008,8 +1010,15 @@ module type SOLVER_INTERNAL = sig
If no hook assigns a value to a class, a fake value is created for it. If no hook assigns a value to a class, a fake value is created for it.
*) *)
val on_model_gen : t -> model_hook -> unit type model_completion_hook =
(** Add a hook that will be called when a model is being produced *) t -> add:(term -> term -> unit) -> unit
(** A model production hook, for the theory to add values.
The hook is given a [add] function to add bindings to the model. *)
val on_model :
?ask:model_ask_hook -> ?complete:model_completion_hook ->
t -> unit
(** Add model production/completion hooks. *)
end end
(** User facing view of the solver. (** User facing view of the solver.

View file

@ -170,6 +170,7 @@ module Make(A : ARG) : S with module A = A = struct
proof: SI.P.t; proof: SI.P.t;
simps: T.t T.Tbl.t; (* cache *) simps: T.t T.Tbl.t; (* cache *)
gensym: A.Gensym.t; gensym: A.Gensym.t;
in_model: T.t Vec.t; (* terms to add to model *)
encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *)
needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *) needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *)
mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *) mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *)
@ -182,6 +183,7 @@ module Make(A : ARG) : S with module A = A = struct
let create ?(stat=Stat.create()) proof tst ty_st : state = let create ?(stat=Stat.create()) proof tst ty_st : state =
{ tst; ty_st; { tst; ty_st;
proof; proof;
in_model=Vec.create();
simps=T.Tbl.create 128; simps=T.Tbl.create 128;
gensym=A.Gensym.create tst; gensym=A.Gensym.create tst;
encoded_eqs=T.Tbl.create 8; encoded_eqs=T.Tbl.create 8;
@ -283,6 +285,9 @@ module Make(A : ARG) : S with module A = A = struct
(* preprocess subterm *) (* preprocess subterm *)
let preproc_t ~steps t = let preproc_t ~steps t =
let u, pr = SI.preprocess_term si (module PA) t in let u, pr = SI.preprocess_term si (module PA) t in
if t != u then (
Vec.push self.in_model t;
);
CCOpt.iter (fun s -> steps := s :: !steps) pr; CCOpt.iter (fun s -> steps := s :: !steps) pr;
u u
in in
@ -674,17 +679,34 @@ module Make(A : ARG) : S with module A = A = struct
T.Tbl.add self.needs_th_combination t () T.Tbl.add self.needs_th_combination t ()
) )
let to_rat_t_ self q = A.mk_lra self.tst (LRA_const q)
(* help generating model *) (* help generating model *)
let model_gen_ (self:state) ~recurse:_ _si n : _ option = let model_gen_ (self:state) ~recurse:_ _si n : _ option =
let t = N.term n in let t = N.term n in
begin match self.last_res with begin match self.last_res with
| Some (SimpSolver.Sat m) -> | Some (SimpSolver.Sat m) ->
Log.debugf 50 (fun k->k "lra: model ask %a" T.pp t); Log.debugf 50 (fun k->k "lra: model ask %a" T.pp t);
let to_rat q = A.mk_lra self.tst (LRA_const q) in SimpSolver.V_map.get t m |> CCOpt.map (to_rat_t_ self)
SimpSolver.V_map.get t m |> CCOpt.map to_rat
| _ -> None | _ -> None
end end
(* help generating model *)
let model_complete_ (self:state) _si ~add : unit =
begin match self.last_res with
| Some (SimpSolver.Sat m) ->
Log.debugf 50 (fun k->k "lra: model complete");
let add_t t =
match SimpSolver.V_map.get t m with
| None -> ()
| Some u -> add t (to_rat_t_ self u)
in
Vec.iter add_t self.in_model
| _ -> ()
end
let k_state = SI.Registry.create_key () let k_state = SI.Registry.create_key ()
let create_and_setup si = let create_and_setup si =
@ -696,7 +718,7 @@ module Make(A : ARG) : S with module A = A = struct
SI.on_preprocess si (preproc_lra st); SI.on_preprocess si (preproc_lra st);
SI.on_final_check si (final_check_ st); SI.on_final_check si (final_check_ st);
SI.on_partial_check si (partial_check_ st); SI.on_partial_check si (partial_check_ st);
SI.on_model_gen si (model_gen_ st); SI.on_model si ~ask:(model_gen_ st) ~complete:(model_complete_ st);
SI.on_cc_is_subterm si (on_subterm st); SI.on_cc_is_subterm si (on_subterm st);
SI.on_cc_post_merge si SI.on_cc_post_merge si
(fun _ _ n1 n2 -> (fun _ _ n1 n2 ->

View file

@ -255,7 +255,8 @@ module Make(A : ARG)
mutable on_progress: unit -> unit; mutable on_progress: unit -> unit;
simp: Simplify.t; simp: Simplify.t;
mutable preprocess: preprocess_hook list; mutable preprocess: preprocess_hook list;
mutable mk_model: model_hook list; mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list;
preprocess_cache: (Term.t * proof_step Bag.t) Term.Tbl.t; preprocess_cache: (Term.t * proof_step Bag.t) Term.Tbl.t;
mutable t_defs : (term*term) list; (* term definitions *) mutable t_defs : (term*term) list; (* term definitions *)
mutable th_states : th_states; (** Set of theories *) mutable th_states : th_states; (** Set of theories *)
@ -270,10 +271,13 @@ module Make(A : ARG)
preprocess_actions -> preprocess_actions ->
term -> (term * proof_step Iter.t) option term -> (term * proof_step Iter.t) option
and model_hook = and model_ask_hook =
recurse:(t -> CC.N.t -> term) -> recurse:(t -> CC.N.t -> term) ->
t -> CC.N.t -> term option t -> CC.N.t -> term option
and model_completion_hook =
t -> add:(term -> term -> unit) -> unit
type solver = t type solver = t
module Proof = P module Proof = P
@ -292,7 +296,10 @@ module Make(A : ARG)
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
let on_preprocess self f = self.preprocess <- f :: self.preprocess let on_preprocess self f = self.preprocess <- f :: self.preprocess
let on_model_gen self f = self.mk_model <- f :: self.mk_model let on_model ?ask ?complete self =
CCOpt.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
CCOpt.iter (fun f -> self.model_complete <- f :: self.model_complete) complete;
()
let push_decision (_self:t) (acts:theory_actions) (lit:lit) : unit = let push_decision (_self:t) (acts:theory_actions) (lit:lit) : unit =
let (module A) = acts in let (module A) = acts in
@ -654,7 +661,8 @@ module Make(A : ARG)
simp=Simplify.create tst ty_st ~proof; simp=Simplify.create tst ty_st ~proof;
on_progress=(fun () -> ()); on_progress=(fun () -> ());
preprocess=[]; preprocess=[];
mk_model=[]; model_ask=[];
model_complete=[];
registry=Registry.create(); registry=Registry.create();
preprocess_cache=Term.Tbl.create 32; preprocess_cache=Term.Tbl.create 32;
count_axiom = Stat.mk_int stat "solver.th-axioms"; count_axiom = Stat.mk_int stat "solver.th-axioms";
@ -883,7 +891,7 @@ module Make(A : ARG)
Profile.with_ "smt-solver.mk-model" @@ fun () -> Profile.with_ "smt-solver.mk-model" @@ fun () ->
let module M = Term.Tbl in let module M = Term.Tbl in
let model = M.create 128 in let model = M.create 128 in
let {Solver_internal.tst; cc=lazy cc; mk_model=model_hooks; _} = self.si in let {Solver_internal.tst; cc=lazy cc; model_ask; model_complete; _} = self.si in
(* first, add all literals to the model using the given propositional model (* first, add all literals to the model using the given propositional model
[lits]. *) [lits]. *)
@ -892,6 +900,16 @@ module Make(A : ARG)
let t, sign = Lit.signed_term lit in let t, sign = Lit.signed_term lit in
M.replace model t (Term.bool tst sign)); M.replace model t (Term.bool tst sign));
(* complete model with theory specific values *)
let complete_with f =
f self.si
~add:(fun t u ->
if not (M.mem model t) then (
M.replace model t u
));
in
List.iter complete_with model_complete;
(* compute a value for [n]. *) (* compute a value for [n]. *)
let rec val_for_class (n:N.t) : term = let rec val_for_class (n:N.t) : term =
let repr = CC.find cc n in let repr = CC.find cc n in
@ -911,7 +929,7 @@ module Make(A : ARG)
end end
in in
let t_val = aux model_hooks in let t_val = aux model_ask in
M.replace model (N.term repr) t_val; (* be sure to cache the value *) M.replace model (N.term repr) t_val; (* be sure to cache the value *)
t_val t_val
in in

View file

@ -716,7 +716,7 @@ module Make(A : ARG) : S with module A = A = struct
SI.on_cc_new_term solver (on_new_term self); SI.on_cc_new_term solver (on_new_term self);
SI.on_cc_pre_merge solver (on_pre_merge self); SI.on_cc_pre_merge solver (on_pre_merge self);
SI.on_final_check solver (on_final_check self); SI.on_final_check solver (on_final_check self);
SI.on_model_gen solver (on_model_gen self); SI.on_model solver ~ask:(on_model_gen self);
self self
let theory = let theory =