diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 7bd2bdd2..c89a9c14 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -997,10 +997,12 @@ module type SOLVER_INTERNAL = sig (** {3 Model production} *) - type model_hook = + type model_ask_hook = recurse:(t -> CC.N.t -> term) -> t -> CC.N.t -> term option - (** A model-production hook. It takes the solver, a class, and returns + (** A model-production hook to query values from a theory. + + It takes the solver, a class, and returns a term for this class. For example, an arithmetic theory might detect that a class contains a numeric constant, and return this constant as a model value. @@ -1008,8 +1010,15 @@ module type SOLVER_INTERNAL = sig If no hook assigns a value to a class, a fake value is created for it. *) - val on_model_gen : t -> model_hook -> unit - (** Add a hook that will be called when a model is being produced *) + type model_completion_hook = + t -> add:(term -> term -> unit) -> unit + (** A model production hook, for the theory to add values. + The hook is given a [add] function to add bindings to the model. *) + + val on_model : + ?ask:model_ask_hook -> ?complete:model_completion_hook -> + t -> unit + (** Add model production/completion hooks. *) end (** User facing view of the solver. diff --git a/src/lra/sidekick_arith_lra.ml b/src/lra/sidekick_arith_lra.ml index a10ccf0c..ea941200 100644 --- a/src/lra/sidekick_arith_lra.ml +++ b/src/lra/sidekick_arith_lra.ml @@ -170,6 +170,7 @@ module Make(A : ARG) : S with module A = A = struct proof: SI.P.t; simps: T.t T.Tbl.t; (* cache *) gensym: A.Gensym.t; + in_model: T.t Vec.t; (* terms to add to model *) encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *) mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *) @@ -182,6 +183,7 @@ module Make(A : ARG) : S with module A = A = struct let create ?(stat=Stat.create()) proof tst ty_st : state = { tst; ty_st; proof; + in_model=Vec.create(); simps=T.Tbl.create 128; gensym=A.Gensym.create tst; encoded_eqs=T.Tbl.create 8; @@ -283,6 +285,9 @@ module Make(A : ARG) : S with module A = A = struct (* preprocess subterm *) let preproc_t ~steps t = let u, pr = SI.preprocess_term si (module PA) t in + if t != u then ( + Vec.push self.in_model t; + ); CCOpt.iter (fun s -> steps := s :: !steps) pr; u in @@ -674,17 +679,34 @@ module Make(A : ARG) : S with module A = A = struct T.Tbl.add self.needs_th_combination t () ) + let to_rat_t_ self q = A.mk_lra self.tst (LRA_const q) + (* help generating model *) let model_gen_ (self:state) ~recurse:_ _si n : _ option = let t = N.term n in begin match self.last_res with | Some (SimpSolver.Sat m) -> Log.debugf 50 (fun k->k "lra: model ask %a" T.pp t); - let to_rat q = A.mk_lra self.tst (LRA_const q) in - SimpSolver.V_map.get t m |> CCOpt.map to_rat + SimpSolver.V_map.get t m |> CCOpt.map (to_rat_t_ self) | _ -> None end + (* help generating model *) + let model_complete_ (self:state) _si ~add : unit = + begin match self.last_res with + | Some (SimpSolver.Sat m) -> + Log.debugf 50 (fun k->k "lra: model complete"); + + let add_t t = + match SimpSolver.V_map.get t m with + | None -> () + | Some u -> add t (to_rat_t_ self u) + in + Vec.iter add_t self.in_model + + | _ -> () + end + let k_state = SI.Registry.create_key () let create_and_setup si = @@ -696,7 +718,7 @@ module Make(A : ARG) : S with module A = A = struct SI.on_preprocess si (preproc_lra st); SI.on_final_check si (final_check_ st); SI.on_partial_check si (partial_check_ st); - SI.on_model_gen si (model_gen_ st); + SI.on_model si ~ask:(model_gen_ st) ~complete:(model_complete_ st); SI.on_cc_is_subterm si (on_subterm st); SI.on_cc_post_merge si (fun _ _ n1 n2 -> diff --git a/src/smt-solver/Sidekick_smt_solver.ml b/src/smt-solver/Sidekick_smt_solver.ml index fdbd3334..36954b9e 100644 --- a/src/smt-solver/Sidekick_smt_solver.ml +++ b/src/smt-solver/Sidekick_smt_solver.ml @@ -255,7 +255,8 @@ module Make(A : ARG) mutable on_progress: unit -> unit; simp: Simplify.t; mutable preprocess: preprocess_hook list; - mutable mk_model: model_hook list; + mutable model_ask: model_ask_hook list; + mutable model_complete: model_completion_hook list; preprocess_cache: (Term.t * proof_step Bag.t) Term.Tbl.t; mutable t_defs : (term*term) list; (* term definitions *) mutable th_states : th_states; (** Set of theories *) @@ -270,10 +271,13 @@ module Make(A : ARG) preprocess_actions -> term -> (term * proof_step Iter.t) option - and model_hook = + and model_ask_hook = recurse:(t -> CC.N.t -> term) -> t -> CC.N.t -> term option + and model_completion_hook = + t -> add:(term -> term -> unit) -> unit + type solver = t module Proof = P @@ -292,7 +296,10 @@ module Make(A : ARG) let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f let on_preprocess self f = self.preprocess <- f :: self.preprocess - let on_model_gen self f = self.mk_model <- f :: self.mk_model + let on_model ?ask ?complete self = + CCOpt.iter (fun f -> self.model_ask <- f :: self.model_ask) ask; + CCOpt.iter (fun f -> self.model_complete <- f :: self.model_complete) complete; + () let push_decision (_self:t) (acts:theory_actions) (lit:lit) : unit = let (module A) = acts in @@ -654,7 +661,8 @@ module Make(A : ARG) simp=Simplify.create tst ty_st ~proof; on_progress=(fun () -> ()); preprocess=[]; - mk_model=[]; + model_ask=[]; + model_complete=[]; registry=Registry.create(); preprocess_cache=Term.Tbl.create 32; count_axiom = Stat.mk_int stat "solver.th-axioms"; @@ -883,7 +891,7 @@ module Make(A : ARG) Profile.with_ "smt-solver.mk-model" @@ fun () -> let module M = Term.Tbl in let model = M.create 128 in - let {Solver_internal.tst; cc=lazy cc; mk_model=model_hooks; _} = self.si in + let {Solver_internal.tst; cc=lazy cc; model_ask; model_complete; _} = self.si in (* first, add all literals to the model using the given propositional model [lits]. *) @@ -892,6 +900,16 @@ module Make(A : ARG) let t, sign = Lit.signed_term lit in M.replace model t (Term.bool tst sign)); + (* complete model with theory specific values *) + let complete_with f = + f self.si + ~add:(fun t u -> + if not (M.mem model t) then ( + M.replace model t u + )); + in + List.iter complete_with model_complete; + (* compute a value for [n]. *) let rec val_for_class (n:N.t) : term = let repr = CC.find cc n in @@ -911,7 +929,7 @@ module Make(A : ARG) end in - let t_val = aux model_hooks in + let t_val = aux model_ask in M.replace model (N.term repr) t_val; (* be sure to cache the value *) t_val in diff --git a/src/th-data/Sidekick_th_data.ml b/src/th-data/Sidekick_th_data.ml index aac8d061..ffd2e2a6 100644 --- a/src/th-data/Sidekick_th_data.ml +++ b/src/th-data/Sidekick_th_data.ml @@ -716,7 +716,7 @@ module Make(A : ARG) : S with module A = A = struct SI.on_cc_new_term solver (on_new_term self); SI.on_cc_pre_merge solver (on_pre_merge self); SI.on_final_check solver (on_final_check self); - SI.on_model_gen solver (on_model_gen self); + SI.on_model solver ~ask:(on_model_gen self); self let theory =