feat(smt): produce better model, with eval function

This commit is contained in:
Simon Cruanes 2022-10-15 23:11:27 -04:00
parent 08541613af
commit 4546b7cff2
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
8 changed files with 51 additions and 19 deletions

View file

@ -11,6 +11,7 @@ module Model_builder = Model_builder
module Registry = Registry module Registry = Registry
module Solver_internal = Solver_internal module Solver_internal = Solver_internal
module Solver = Solver module Solver = Solver
module Model = Model
module Theory = Theory module Theory = Theory
module Theory_id = Theory_id module Theory_id = Theory_id
module Preprocess = Preprocess module Preprocess = Preprocess

13
src/smt/model.ml Normal file
View file

@ -0,0 +1,13 @@
(** SMT models.
The solver models are partially evaluated; the frontend might ask
for values for terms not explicitly present in them.
*)
open Sigs
type t = { eval: Term.t -> value option; map: value Term.Map.t }
let eval (self : t) (t : Term.t) : value option =
try Some (Term.Map.find t self.map) with Not_found -> self.eval t

View file

@ -41,17 +41,26 @@ let add (self : t) ?(subs = []) t v : unit =
type eval_cache = Term.Internal_.cache type eval_cache = Term.Internal_.cache
let create_cache = Term.Internal_.create_cache
let eval_opt ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) =
match TM.get t self.m with
| None -> None
| Some t ->
Some
(T.Internal_.replace_ ~cache self.tst ~recursive:true t
~f:(fun ~recurse:_ u -> TM.get u self.m))
let eval ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) = let eval ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) =
let t = TM.get t self.m |> Option.value ~default:t in let t = TM.get t self.m |> Option.value ~default:t in
T.Internal_.replace_ ~cache self.tst ~recursive:true t ~f:(fun ~recurse:_ u -> T.Internal_.replace_ ~cache self.tst ~recursive:true t ~f:(fun ~recurse:_ u ->
TM.get u self.m) TM.get u self.m)
let to_map (self : t) : _ TM.t = let to_map ?(cache = T.Internal_.create_cache 8) (self : t) : _ TM.t =
(* ensure we evaluate each term only once *) (* ensure we evaluate each term only once by using a cache *)
let cache = T.Internal_.create_cache 8 in let map =
let m =
TM.keys self.m TM.keys self.m
|> Iter.map (fun t -> t, eval ~cache self t) |> Iter.map (fun t -> t, eval ~cache self t)
|> Iter.fold (fun m (t, v) -> TM.add t v m) TM.empty |> Iter.fold (fun m (t, v) -> TM.add t v m) TM.empty
in in
m map

View file

@ -28,9 +28,11 @@ val gensym : t -> pre:string -> ty:Term.t -> Term.t
type eval_cache = Term.Internal_.cache type eval_cache = Term.Internal_.cache
val create_cache : int -> eval_cache
val eval : ?cache:eval_cache -> t -> Term.t -> value val eval : ?cache:eval_cache -> t -> Term.t -> value
val eval_opt : ?cache:eval_cache -> t -> Term.t -> value option
val pop_required : t -> Term.t option val pop_required : t -> Term.t option
(** gives the next subterm that is required but has no value yet *) (** gives the next subterm that is required but has no value yet *)
val to_map : t -> Term.t Term.Map.t val to_map : ?cache:eval_cache -> t -> value Term.Map.t

View file

@ -219,12 +219,8 @@ let solve ?(on_exit = []) ?(on_progress = fun _ -> ())
not @@ Term.is_pi (Term.ty @@ E_node.term repr)) not @@ Term.is_pi (Term.ty @@ E_node.term repr))
|> Iter.map (fun repr -> |> Iter.map (fun repr ->
let v = let v =
match (* find value for this class *)
(* find value for this class *) match Model.eval m (E_node.term repr) with
Iter.find_map
(fun en -> Term.Map.get (E_node.term en) m)
(E_node.iter_class repr)
with
| None -> | None ->
Error.errorf Error.errorf
"(@[solver.mk-model.no-value-for-repr@ %a@ :ty %a@])" "(@[solver.mk-model.no-value-for-repr@ %a@ :ty %a@])"
@ -248,7 +244,7 @@ let solve ?(on_exit = []) ?(on_progress = fun _ -> ())
do_on_exit (); do_on_exit ();
Sat Sat
{ {
get_value = (fun t -> Term.Map.get t m); get_value = Model.eval m;
iter_classes; iter_classes;
eval_lit = eval_lit =
(fun l -> (fun l ->

View file

@ -53,7 +53,7 @@ type t = {
mutable model_complete: model_completion_hook list; mutable model_complete: model_completion_hook list;
simp: Simplify.t; simp: Simplify.t;
delayed_actions: delayed_action Queue.t; delayed_actions: delayed_action Queue.t;
mutable last_model: Term.t Term.Map.t option; mutable last_model: Model.t option;
mutable th_states: th_states; (** Set of theories *) mutable th_states: th_states; (** Set of theories *)
mutable level: int; mutable level: int;
mutable complete: bool; mutable complete: bool;
@ -327,12 +327,13 @@ let rec pop_lvls_theories_ n = function
(** {2 Model construction and theory combination} *) (** {2 Model construction and theory combination} *)
(* make model from the congruence closure *) (* make model from the congruence closure *)
let mk_model_ (self : t) (lits : lit Iter.t) : Term.t Term.Map.t = let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
let@ () = Profile.with_ "smt-solver.mk-model" in let@ () = Profile.with_ "smt-solver.mk-model" in
Log.debug 1 "(smt.solver.mk-model)"; Log.debug 1 "(smt.solver.mk-model)";
let module MB = Model_builder in let module MB = Model_builder in
let { cc; tst; model_ask = model_ask_hooks; model_complete; _ } = self in let { cc; tst; model_ask = model_ask_hooks; model_complete; _ } = self in
let cache = Model_builder.create_cache 8 in
let model = Model_builder.create tst in let model = Model_builder.create tst in
Model_builder.add model (Term.true_ tst) (Term.true_ tst); Model_builder.add model (Term.true_ tst) (Term.true_ tst);
@ -395,7 +396,18 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Term.t Term.Map.t =
in in
compute_fixpoint (); compute_fixpoint ();
MB.to_map model
let map = MB.to_map ~cache model in
let eval (t : Term.t) : value option =
try Some (Term.Map.find t map)
with Not_found ->
MB.require_eval model t;
compute_fixpoint ();
MB.eval_opt ~cache model t
in
{ Model.map; eval }
(* call congruence closure, perform the actions it scheduled *) (* call congruence closure, perform the actions it scheduled *)
let check_cc_with_acts_ (self : t) (acts : theory_actions) = let check_cc_with_acts_ (self : t) (acts : theory_actions) =

View file

@ -270,7 +270,7 @@ val on_progress : t -> (unit, unit) Event.t
val is_complete : t -> bool val is_complete : t -> bool
(** Are we still in a complete logic fragment? *) (** Are we still in a complete logic fragment? *)
val last_model : t -> Term.t Term.Map.t option val last_model : t -> Model.t option
(** {2 Delayed actions} *) (** {2 Delayed actions} *)

View file

@ -22,8 +22,7 @@ let build (self : t) (sat : Solver.sat_result) : Model.t =
match List.map (fun t -> sat.get_value t |> Option.get) args with match List.map (fun t -> sat.get_value t |> Option.get) args with
| exception _ -> | exception _ ->
Log.debugf 1 (fun k -> Log.debugf 1 (fun k ->
k "(@[build-model.warn@ :no-entry-for %a@])" Term.pp t); k "(@[build-model.warn@ :no-entry-for %a@])" Term.pp t)
() (* TODO: warning? *)
| v_args -> | v_args ->
(* see if [v_args] already maps to a value *) (* see if [v_args] already maps to a value *)
let other_v = Model.get_fun_entry f v_args !m in let other_v = Model.get_fun_entry f v_args !m in