diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index d0b572b0..4913da09 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -532,6 +532,16 @@ module type SOLVER_INTERNAL = sig *) val add_preprocess : t -> preprocess_hook -> unit + + (** {3 Model production} *) + + type model_hook = + recurse:(t -> CC.N.t -> term) -> + t -> CC.N.t -> term option + (** A model-production hook. It takes the solver, a class, and returns + a term for this class. *) + + val add_model_hook : t -> model_hook -> unit end (** Public view of the solver *) @@ -616,6 +626,7 @@ module type SOLVER = sig val sign : t -> bool end + (* FIXME: just use terms instead? *) (** {3 Semantic values} *) module Value : sig type t diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index 49f5a0fa..a180c7f4 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -170,6 +170,7 @@ module Make(A : ARG) mutable on_progress: unit -> unit; simp: Simplify.t; mutable preprocess: preprocess_hook list; + mutable mk_model: model_hook list; preprocess_cache: Term.t Term.Tbl.t; mutable th_states : th_states; (** Set of theories *) mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list; @@ -183,6 +184,10 @@ module Make(A : ARG) add_clause:(lit list -> unit) -> term -> term option + and model_hook = + recurse:(t -> CC.N.t -> term) -> + t -> CC.N.t -> term option + type solver = t module Formula = struct @@ -206,6 +211,7 @@ module Make(A : ARG) let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f let add_preprocess self f = self.preprocess <- f :: self.preprocess + let add_model_hook self f = self.mk_model <- f :: self.mk_model let push_decision (_self:t) (acts:actions) (lit:lit) : unit = let sign = Lit.sign lit in @@ -384,17 +390,6 @@ module Make(A : ARG) let[@inline] final_check (self:t) (acts:_ Msat.acts) : unit = check_ ~final:true self acts - (* TODO - let mk_model (self:t) lits : Model.t = - let m = - Iter.fold - (fun m (Th_state ((module Th),st)) -> Th.mk_model st lits m) - Model.empty (theories self) - in - (* now complete model using CC *) - CC.mk_model (cc self) m - *) - let create ~stat (tst:Term.state) (ty_st:Ty.state) () : t = let rec self = { tst; @@ -408,6 +403,7 @@ module Make(A : ARG) simp=Simplify.create tst ty_st; on_progress=(fun () -> ()); preprocess=[]; + mk_model=[]; preprocess_cache=Term.Tbl.create 32; count_axiom = Stat.mk_int stat "solver.th-axioms"; count_preprocess_clause = Stat.mk_int stat "solver.preprocess-clause"; @@ -613,27 +609,66 @@ module Make(A : ARG) let add_clause_l self c = add_clause self (IArray.of_list c) + + (* TODO + let mk_model (self:t) lits : Model.t = + let m = + Iter.fold + (fun m (Th_state ((module Th),st)) -> Th.mk_model st lits m) + Model.empty (theories self) + in + (* now complete model using CC *) + CC.mk_model (cc self) m + *) + let mk_model (self:t) (lits:lit Iter.t) : Model.t = Log.debug 1 "(smt.solver.mk-model)"; Profile.with_ "msat-solver.mk-model" @@ fun () -> let module M = Term.Tbl in - let m = M.create 128 in - let tst = self.si.tst in - (* first, add all boolean *) + let model = M.create 128 in + let {Solver_internal.tst; cc=lazy cc; mk_model=model_hooks; _} = self.si in + + (* first, add all literals to the model using the given propositional model + [lits]. *) lits (fun {Lit.lit_term=t;lit_sign=sign} -> - M.replace m t (Term.bool tst sign)); - (* then add CC classes *) + M.replace model t (Term.bool tst sign)); + + (* compute a value for [n]. *) + let rec val_for_class (n:N.t) : term = + let repr = CC.find cc n in + + (* see if a value is found already (always the case if it's a boolean) *) + match M.get model (N.term repr) with + | Some t_val -> t_val + | None -> + + (* try each model hook *) + let rec aux = function + | [] -> N.term repr + | h :: hooks -> + begin match h ~recurse:(fun _ n -> val_for_class n) self.si repr with + | None -> aux hooks + | Some t -> t + end + in + + let t_val = aux model_hooks in + M.replace model (N.term repr) t_val; (* be sure to cache the value *) + t_val + in + + (* map terms of each CC class to the value computed for their class. *) Solver_internal.CC.all_classes (Solver_internal.cc self.si) (fun repr -> + let t_val = val_for_class repr in (* value for this class *) N.iter_class repr (fun u -> let t_u = N.term u in - if not (N.equal repr u && M.mem m t_u) then ( - M.replace m t_u (N.term repr); + if not (N.equal u repr) && not (Term.equal t_u t_val) then ( + M.replace model t_u t_val; ))); - (* TODO: theory combination *) - Model.Map m + Model.Map model let solve ?(on_exit=[]) ?(check=true) ?(on_progress=fun _ -> ()) ~assumptions (self:t) : res =