diff --git a/src/smt-solver/Sidekick_smt_solver.ml b/src/smt-solver/Sidekick_smt_solver.ml index 62034103..8e17923d 100644 --- a/src/smt-solver/Sidekick_smt_solver.ml +++ b/src/smt-solver/Sidekick_smt_solver.ml @@ -838,7 +838,8 @@ module Make(A : ARG) Profile.with_ "smt-solver.mk-model" @@ fun () -> let module M = Term.Tbl in let model = M.create 128 in - let {Solver_internal.tst; cc=lazy cc; model_ask; model_complete; _} = self.si in + let {Solver_internal.tst; cc=lazy cc; + model_ask=model_ask_hooks; model_complete; _} = self.si in (* first, add all literals to the model using the given propositional model [lits]. *) @@ -867,16 +868,25 @@ module Make(A : ARG) | None -> (* try each model hook *) - let rec aux = function + let rec try_hooks_ = function | [] -> N.term repr | h :: hooks -> begin match h ~recurse:(fun _ n -> val_for_class n) self.si repr with - | None -> aux hooks + | None -> try_hooks_ hooks | Some t -> t end in - let t_val = aux model_ask in + let t_val = + match + (* look for a value in the model for any term in the class *) + N.iter_class repr + |> Iter.find_map (fun n -> M.get model (N.term n)) + with + | Some v -> v + | None -> try_hooks_ model_ask_hooks + in + M.replace model (N.term repr) t_val; (* be sure to cache the value *) t_val in