From 4546b7cff2879427f18ff4d8cdd210d1bf6bafa9 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 15 Oct 2022 23:11:27 -0400 Subject: [PATCH] feat(smt): produce better model, with eval function --- src/smt/Sidekick_smt_solver.ml | 1 + src/smt/model.ml | 13 +++++++++++++ src/smt/model_builder.ml | 19 ++++++++++++++----- src/smt/model_builder.mli | 4 +++- src/smt/solver.ml | 10 +++------- src/smt/solver_internal.ml | 18 +++++++++++++++--- src/smt/solver_internal.mli | 2 +- src/smtlib/build_model.ml | 3 +-- 8 files changed, 51 insertions(+), 19 deletions(-) create mode 100644 src/smt/model.ml diff --git a/src/smt/Sidekick_smt_solver.ml b/src/smt/Sidekick_smt_solver.ml index 61914980..588c81f2 100644 --- a/src/smt/Sidekick_smt_solver.ml +++ b/src/smt/Sidekick_smt_solver.ml @@ -11,6 +11,7 @@ module Model_builder = Model_builder module Registry = Registry module Solver_internal = Solver_internal module Solver = Solver +module Model = Model module Theory = Theory module Theory_id = Theory_id module Preprocess = Preprocess diff --git a/src/smt/model.ml b/src/smt/model.ml new file mode 100644 index 00000000..259d9c53 --- /dev/null +++ b/src/smt/model.ml @@ -0,0 +1,13 @@ +(** SMT models. + + The solver models are partially evaluated; the frontend might ask + for values for terms not explicitly present in them. + +*) + +open Sigs + +type t = { eval: Term.t -> value option; map: value Term.Map.t } + +let eval (self : t) (t : Term.t) : value option = + try Some (Term.Map.find t self.map) with Not_found -> self.eval t diff --git a/src/smt/model_builder.ml b/src/smt/model_builder.ml index 28a3dfc6..2b07f902 100644 --- a/src/smt/model_builder.ml +++ b/src/smt/model_builder.ml @@ -41,17 +41,26 @@ let add (self : t) ?(subs = []) t v : unit = type eval_cache = Term.Internal_.cache +let create_cache = Term.Internal_.create_cache + +let eval_opt ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) = + match TM.get t self.m with + | None -> None + | Some t -> + Some + (T.Internal_.replace_ ~cache self.tst ~recursive:true t + ~f:(fun ~recurse:_ u -> TM.get u self.m)) + let eval ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) = let t = TM.get t self.m |> Option.value ~default:t in T.Internal_.replace_ ~cache self.tst ~recursive:true t ~f:(fun ~recurse:_ u -> TM.get u self.m) -let to_map (self : t) : _ TM.t = - (* ensure we evaluate each term only once *) - let cache = T.Internal_.create_cache 8 in - let m = +let to_map ?(cache = T.Internal_.create_cache 8) (self : t) : _ TM.t = + (* ensure we evaluate each term only once by using a cache *) + let map = TM.keys self.m |> Iter.map (fun t -> t, eval ~cache self t) |> Iter.fold (fun m (t, v) -> TM.add t v m) TM.empty in - m + map diff --git a/src/smt/model_builder.mli b/src/smt/model_builder.mli index 965ad71f..23948711 100644 --- a/src/smt/model_builder.mli +++ b/src/smt/model_builder.mli @@ -28,9 +28,11 @@ val gensym : t -> pre:string -> ty:Term.t -> Term.t type eval_cache = Term.Internal_.cache +val create_cache : int -> eval_cache val eval : ?cache:eval_cache -> t -> Term.t -> value +val eval_opt : ?cache:eval_cache -> t -> Term.t -> value option val pop_required : t -> Term.t option (** gives the next subterm that is required but has no value yet *) -val to_map : t -> Term.t Term.Map.t +val to_map : ?cache:eval_cache -> t -> value Term.Map.t diff --git a/src/smt/solver.ml b/src/smt/solver.ml index 5f05640b..49e297e7 100644 --- a/src/smt/solver.ml +++ b/src/smt/solver.ml @@ -219,12 +219,8 @@ let solve ?(on_exit = []) ?(on_progress = fun _ -> ()) not @@ Term.is_pi (Term.ty @@ E_node.term repr)) |> Iter.map (fun repr -> let v = - match - (* find value for this class *) - Iter.find_map - (fun en -> Term.Map.get (E_node.term en) m) - (E_node.iter_class repr) - with + (* find value for this class *) + match Model.eval m (E_node.term repr) with | None -> Error.errorf "(@[solver.mk-model.no-value-for-repr@ %a@ :ty %a@])" @@ -248,7 +244,7 @@ let solve ?(on_exit = []) ?(on_progress = fun _ -> ()) do_on_exit (); Sat { - get_value = (fun t -> Term.Map.get t m); + get_value = Model.eval m; iter_classes; eval_lit = (fun l -> diff --git a/src/smt/solver_internal.ml b/src/smt/solver_internal.ml index 34351ee6..bf5bfff1 100644 --- a/src/smt/solver_internal.ml +++ b/src/smt/solver_internal.ml @@ -53,7 +53,7 @@ type t = { mutable model_complete: model_completion_hook list; simp: Simplify.t; delayed_actions: delayed_action Queue.t; - mutable last_model: Term.t Term.Map.t option; + mutable last_model: Model.t option; mutable th_states: th_states; (** Set of theories *) mutable level: int; mutable complete: bool; @@ -327,12 +327,13 @@ let rec pop_lvls_theories_ n = function (** {2 Model construction and theory combination} *) (* make model from the congruence closure *) -let mk_model_ (self : t) (lits : lit Iter.t) : Term.t Term.Map.t = +let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = let@ () = Profile.with_ "smt-solver.mk-model" in Log.debug 1 "(smt.solver.mk-model)"; let module MB = Model_builder in let { cc; tst; model_ask = model_ask_hooks; model_complete; _ } = self in + let cache = Model_builder.create_cache 8 in let model = Model_builder.create tst in Model_builder.add model (Term.true_ tst) (Term.true_ tst); @@ -395,7 +396,18 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Term.t Term.Map.t = in compute_fixpoint (); - MB.to_map model + + let map = MB.to_map ~cache model in + + let eval (t : Term.t) : value option = + try Some (Term.Map.find t map) + with Not_found -> + MB.require_eval model t; + compute_fixpoint (); + MB.eval_opt ~cache model t + in + + { Model.map; eval } (* call congruence closure, perform the actions it scheduled *) let check_cc_with_acts_ (self : t) (acts : theory_actions) = diff --git a/src/smt/solver_internal.mli b/src/smt/solver_internal.mli index 09f0fe74..479d3b1c 100644 --- a/src/smt/solver_internal.mli +++ b/src/smt/solver_internal.mli @@ -270,7 +270,7 @@ val on_progress : t -> (unit, unit) Event.t val is_complete : t -> bool (** Are we still in a complete logic fragment? *) -val last_model : t -> Term.t Term.Map.t option +val last_model : t -> Model.t option (** {2 Delayed actions} *) diff --git a/src/smtlib/build_model.ml b/src/smtlib/build_model.ml index 2a5bb1ea..67f6c953 100644 --- a/src/smtlib/build_model.ml +++ b/src/smtlib/build_model.ml @@ -22,8 +22,7 @@ let build (self : t) (sat : Solver.sat_result) : Model.t = match List.map (fun t -> sat.get_value t |> Option.get) args with | exception _ -> Log.debugf 1 (fun k -> - k "(@[build-model.warn@ :no-entry-for %a@])" Term.pp t); - () (* TODO: warning? *) + k "(@[build-model.warn@ :no-entry-for %a@])" Term.pp t) | v_args -> (* see if [v_args] already maps to a value *) let other_v = Model.get_fun_entry f v_args !m in