diff --git a/src/smt/model_builder.ml b/src/smt/model_builder.ml index 2b07f902..4905fed4 100644 --- a/src/smt/model_builder.ml +++ b/src/smt/model_builder.ml @@ -5,7 +5,7 @@ module TM = Term.Map type t = { tst: Term.store; mutable m: Term.t TM.t; - required: Term.t Queue.t; + required: (Term.t * Term.t Iter.t) Queue.t; gensym: Gensym.t; } @@ -18,25 +18,19 @@ let pp out (self : t) : unit = in Fmt.fprintf out "(@[model-builder@ :model %a@ :q (@[%a@])@])" (Util.pp_iter pp_kv) (TM.to_iter self.m) (Util.pp_iter T.pp) - (Iter.of_queue self.required) + (Iter.of_queue self.required |> Iter.map fst) let gensym self ~pre ~ty : Term.t = Gensym.fresh_term self.gensym ~pre ty -let rec pop_required (self : t) : _ option = - match Queue.take_opt self.required with - | None -> None - | Some t when TM.mem t self.m -> pop_required self - | Some t -> Some t - -let require_eval (self : t) t : unit = - if not @@ TM.mem t self.m then Queue.push t self.required +let require_eval (self : t) t ~cls : unit = + if not @@ TM.mem t self.m then Queue.push (t, cls) self.required let[@inline] mem self t : bool = TM.mem t self.m let add (self : t) ?(subs = []) t v : unit = if not @@ mem self t then ( self.m <- TM.add t v self.m; - List.iter (fun u -> require_eval self u) subs + List.iter (fun u -> require_eval self u ~cls:Iter.empty) subs ) type eval_cache = Term.Internal_.cache @@ -56,6 +50,15 @@ let eval ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) = T.Internal_.replace_ ~cache self.tst ~recursive:true t ~f:(fun ~recurse:_ u -> TM.get u self.m) +let rec pop_required (self : t) : _ option = + match Queue.take_opt self.required with + | None -> None + | Some (t, cls) when TM.mem t self.m -> + (* make sure we also map [cls] to [t]'s value *) + cls (fun u -> add self u (TM.find t self.m)); + pop_required self + | Some pair -> Some pair + let to_map ?(cache = T.Internal_.create_cache 8) (self : t) : _ TM.t = (* ensure we evaluate each term only once by using a cache *) let map = diff --git a/src/smt/model_builder.mli b/src/smt/model_builder.mli index 23948711..0934d4d1 100644 --- a/src/smt/model_builder.mli +++ b/src/smt/model_builder.mli @@ -15,8 +15,9 @@ include Sidekick_sigs.PRINT with type t := t val create : Term.store -> t val mem : t -> Term.t -> bool -val require_eval : t -> Term.t -> unit -(** Require that this term gets a value. *) +val require_eval : t -> Term.t -> cls:Term.t Iter.t -> unit +(** Require that this term gets a value, and assign it to all terms + in the given class. *) val add : t -> ?subs:Term.t list -> Term.t -> value -> unit (** Add a value to the model. @@ -32,7 +33,7 @@ val create_cache : int -> eval_cache val eval : ?cache:eval_cache -> t -> Term.t -> value val eval_opt : ?cache:eval_cache -> t -> Term.t -> value option -val pop_required : t -> Term.t option +val pop_required : t -> (Term.t * Term.t Iter.t) option (** gives the next subterm that is required but has no value yet *) val to_map : ?cache:eval_cache -> t -> value Term.Map.t diff --git a/src/smt/solver_internal.ml b/src/smt/solver_internal.ml index bf5bfff1..fc2326d1 100644 --- a/src/smt/solver_internal.ml +++ b/src/smt/solver_internal.ml @@ -360,39 +360,44 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = (* require a value for each class that doesn't already have one *) CC.all_classes cc (fun repr -> + Log.debugf 5 (fun k -> + k "(@[model.fixpoint.require-cls@ %a@])" E_node.pp repr); let t = E_node.term repr in - MB.require_eval model t); + let ts = E_node.iter_class repr |> Iter.map E_node.term in + MB.require_eval model t ~cls:ts); (* now for the fixpoint. This is typically where composite theories such as arrays and datatypes contribute their skeleton values. *) - let rec compute_fixpoint () = - match MB.pop_required model with - | None -> () - | Some t when Term.is_pi (Term.ty t) -> - (* TODO: when we support lambdas? *) - () - | Some t -> - (* compute a value for [t] *) - Log.debugf 5 (fun k -> - k "(@[model.fixpoint.compute-for-required@ %a@])" Term.pp t); + let compute_fixpoint () = + let continue = ref true in + while !continue do + match MB.pop_required model with + | None -> continue := false + | Some (t, _cls) when Term.is_pi (Term.ty t) -> + (* TODO: when we support lambdas? *) + () + | Some (t, cls) -> + (* compute a value for [t] *) + Log.debugf 5 (fun k -> + k "(@[model.fixpoint.compute-for-required@ %a@])" Term.pp t); - (* try each model hook *) - let rec try_hooks_ = function - | [] -> - (* should not happen *) - Error.errorf "cannot build a value for term@ `%a`@ of type `%a`" - Term.pp t Term.pp (Term.ty t) - | h :: hooks -> - (match h self model t with - | None -> try_hooks_ hooks - | Some (v, subs) -> - MB.add model ~subs t v; - ()) - in + (* try each model hook *) + let rec try_hooks_ = function + | h :: hooks -> + (match h self model t with + | None -> try_hooks_ hooks + | Some (v, subs) -> + MB.add model ~subs t v; + cls (fun u -> MB.add model ~subs:[] u v); + ()) + | [] -> + (* should not happen *) + Error.errorf "cannot build a value for term@ `%a`@ of type `%a`" + Term.pp t Term.pp (Term.ty t) + in - try_hooks_ model_ask_hooks; - (* continue to next value *) - (compute_fixpoint [@tailcall]) () + try_hooks_ model_ask_hooks + done in compute_fixpoint (); @@ -402,7 +407,7 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = let eval (t : Term.t) : value option = try Some (Term.Map.find t map) with Not_found -> - MB.require_eval model t; + MB.require_eval model t ~cls:Iter.empty; compute_fixpoint (); MB.eval_opt ~cache model t in