diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 0011e306..79adf3f9 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -1061,6 +1061,15 @@ module Make (A: CC_ARG) cc.model_mode <- false; ) + let get_model_for_each_class self : _ Iter.t = + assert self.model_mode; + all_classes self + |> Iter.filter_map + (fun repr -> + match T_b_tbl.get self.t_to_val repr.n_term with + | Some (_,v) -> Some (repr, N.iter_class repr, v) + | None -> None) + (* assert that this boolean literal holds. if a lit is [= a b], merge [a] and [b]; otherwise merge the atom with true/false *) diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 837a5b1b..ec74c46c 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -735,6 +735,9 @@ module type CC_S = sig val with_model_mode : t -> (unit -> 'a) -> 'a (** Enter model combination mode. *) + val get_model_for_each_class : t -> (repr * N.t Iter.t * value) Iter.t + (** In model combination mode, obtain classes with their values. *) + val check : t -> actions -> unit (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. Will use the {!actions} to propagate literals, declare conflicts, etc. *) diff --git a/src/smt-solver/Sidekick_smt_solver.ml b/src/smt-solver/Sidekick_smt_solver.ml index 0549c10c..a667129d 100644 --- a/src/smt-solver/Sidekick_smt_solver.ml +++ b/src/smt-solver/Sidekick_smt_solver.ml @@ -142,6 +142,28 @@ module Make(A : ARG) module CC = Sidekick_cc.Make(CC_actions) module N = CC.N + module Model = struct + type t = + | Empty + | Map of term Term.Tbl.t + let empty = Empty + let mem = function + | Empty -> fun _ -> false + | Map tbl -> Term.Tbl.mem tbl + let find = function + | Empty -> fun _ -> None + | Map tbl -> Term.Tbl.get tbl + let eval = find + let pp out = function + | Empty -> Fmt.string out "(model)" + | Map tbl -> + let pp_pair out (t,v) = + Fmt.fprintf out "(@[<1>%a@ := %a@])" Term.pp t Term.pp v + in + Fmt.fprintf out "(@[model@ %a@])" + (Util.pp_iter pp_pair) (Term.Tbl.to_iter tbl) + end + (* delayed actions. We avoid doing them on the spot because, when triggered by a theory, they might go back to the theory "too early". *) type delayed_action = @@ -285,6 +307,7 @@ module Make(A : ARG) simp: Simplify.t; preprocessed: unit Term.Tbl.t; delayed_actions: delayed_action Queue.t; + mutable last_model: Model.t option; mutable t_defs : (term*term) list; (* term definitions *) mutable th_states : th_states; (** Set of theories *) @@ -560,14 +583,95 @@ module Make(A : ARG) push_lvl_ self.th_states let pop_levels (self:t) n : unit = + self.last_model <- None; self.level <- self.level - n; CC.pop_levels (cc self) n; pop_lvls_ n self.th_states + (** {2 Model construction and theory combination} *) + + (* make model from the congruence closure *) + let mk_model_ (self:t) : Model.t = + Log.debug 1 "(smt.solver.mk-model)"; + Profile.with_ "smt-solver.mk-model" @@ fun () -> + let module M = Term.Tbl in + let {cc=lazy cc; + model_ask=model_ask_hooks; model_complete; _} = self in + + let model = M.create 128 in + (* populate with information from the CC *) + CC.get_model_for_each_class cc + (fun (_, ts, v) -> + Iter.iter + (fun n -> + let t = N.term n in + M.replace model t v) ts); + + (* complete model with theory specific values *) + let complete_with f = + f self + ~add:(fun t u -> + if not (M.mem model t) then ( + Log.debugf 20 (fun k->k "(@[smt.model-complete@ %a@ :with-val %a@])" Term.pp t Term.pp u); + M.replace model t u + )); + in + List.iter complete_with model_complete; + + (* compute a value for [n]. *) + let rec val_for_class (n:N.t) : term = + Log.debugf 5 (fun k->k "val-for-term %a" N.pp n); + let repr = CC.find cc n in + Log.debugf 5 (fun k->k "val-for-term.repr %a" N.pp repr); + + (* see if a value is found already (always the case if it's a boolean) *) + match M.get model (N.term repr) with + | Some t_val -> + Log.debugf 5 (fun k->k "cached val is %a" Term.pp t_val); + t_val + | None -> + + (* try each model hook *) + let rec try_hooks_ = function + | [] -> N.term repr + | h :: hooks -> + begin match h ~recurse:(fun _ n -> val_for_class n) self repr with + | None -> try_hooks_ hooks + | Some t -> t + end + in + + let t_val = + match + (* look for a value in the model for any term in the class *) + N.iter_class repr + |> Iter.find_map (fun n -> M.get model (N.term n)) + with + | Some v -> v + | None -> try_hooks_ model_ask_hooks + in + + M.replace model (N.term repr) t_val; (* be sure to cache the value *) + Log.debugf 5 (fun k->k "val is %a" Term.pp t_val); + t_val + in + + (* map terms of each CC class to the value computed for their class. *) + CC.all_classes cc + (fun repr -> + let t_val = val_for_class repr in (* value for this class *) + N.iter_class repr + (fun u -> + let t_u = N.term u in + if not (N.equal u repr) && not (Term.equal t_u t_val) then ( + M.replace model t_u t_val; + ))); + Model.Map model + (* do theory combination using the congruence closure. Each theory can merge classes, *) let check_th_combination_ - (self:t) (acts:theory_actions) : (unit, th_combination_conflict) result = + (self:t) (acts:theory_actions) : (Model.t, th_combination_conflict) result = let cc = cc self in (* entier model mode, disabling most of congruence closure *) CC.with_model_mode cc @@ fun () -> @@ -579,16 +683,17 @@ module Make(A : ARG) CC.set_model_value cc t v in - (* obtain classes of equal terms from the hook, and merge them *) - let add_th_equalities f : unit = + (* obtain assignments from the hook, and communicate them to the CC *) + let add_th_values f : unit = let vals = f self acts in Iter.iter set_val vals in try - List.iter add_th_equalities self.on_th_combination; + List.iter add_th_values self.on_th_combination; CC.check cc acts; - Ok () + let m = mk_model_ self in + Ok m with Semantic_conflict c -> Error c (* handle a literal assumed by the SAT solver *) @@ -608,11 +713,9 @@ module Make(A : ARG) List.iter (fun f -> f self acts lits) self.on_final_check; CC.check cc acts; - let new_work = has_delayed_actions self in - (* do actual theory combination if nothing changed by pure "final check" *) - if not new_work then ( - match check_th_combination_ self acts with - | Ok () -> () + begin match check_th_combination_ self acts with + | Ok m -> + self.last_model <- Some m | Error {lits; semantic} -> (* bad model, we add a clause to remove it *) @@ -642,7 +745,7 @@ module Make(A : ARG) (Util.pp_list Lit.pp) c); (* will add a delayed action *) add_clause_temp self acts c pr; - ); + end; Perform_delayed_th.top self acts; ) else ( @@ -691,6 +794,7 @@ module Make(A : ARG) th_states=Ths_nil; stat; simp=Simplify.create tst ty_st ~proof; + last_model=None; on_progress=(fun () -> ()); preprocess=[]; model_ask=[]; @@ -745,28 +849,6 @@ module Make(A : ARG) | U_asked_to_stop -> Fmt.string out {|"asked to stop by callback"|} end [@@ocaml.warning "-37"] - module Model = struct - type t = - | Empty - | Map of term Term.Tbl.t - let empty = Empty - let mem = function - | Empty -> fun _ -> false - | Map tbl -> Term.Tbl.mem tbl - let find = function - | Empty -> fun _ -> None - | Map tbl -> Term.Tbl.get tbl - let eval = find - let pp out = function - | Empty -> Fmt.string out "(model)" - | Map tbl -> - let pp_pair out (t,v) = - Fmt.fprintf out "(@[<1>%a@ := %a@])" Term.pp t Term.pp v - in - Fmt.fprintf out "(@[model@ %a@])" - (Util.pp_iter pp_pair) (Term.Tbl.to_iter tbl) - end - type res = | Sat of Model.t | Unsat of { @@ -901,82 +983,6 @@ module Make(A : ARG) let assert_term self t = assert_terms self [t] - let mk_model (self:t) (lits:lit Iter.t) : Model.t = - Log.debug 1 "(smt.solver.mk-model)"; - Profile.with_ "smt-solver.mk-model" @@ fun () -> - let module M = Term.Tbl in - let model = M.create 128 in - let {Solver_internal.tst; cc=lazy cc; - model_ask=model_ask_hooks; model_complete; _} = self.si in - - (* first, add all literals to the model using the given propositional model - [lits]. *) - lits - (fun lit -> - let t, sign = Lit.signed_term lit in - M.replace model t (Term.bool tst sign)); - - (* complete model with theory specific values *) - let complete_with f = - f self.si - ~add:(fun t u -> - if not (M.mem model t) then ( - Log.debugf 20 (fun k->k "(@[smt.model-complete@ %a@ :with-val %a@])" Term.pp t Term.pp u); - M.replace model t u - )); - in - List.iter complete_with model_complete; - - (* compute a value for [n]. *) - let rec val_for_class (n:N.t) : term = - Log.debugf 5 (fun k->k "val-for-term %a" N.pp n); - let repr = CC.find cc n in - Log.debugf 5 (fun k->k "val-for-term.repr %a" N.pp repr); - - (* see if a value is found already (always the case if it's a boolean) *) - match M.get model (N.term repr) with - | Some t_val -> - Log.debugf 5 (fun k->k "cached val is %a" Term.pp t_val); - t_val - | None -> - - (* try each model hook *) - let rec try_hooks_ = function - | [] -> N.term repr - | h :: hooks -> - begin match h ~recurse:(fun _ n -> val_for_class n) self.si repr with - | None -> try_hooks_ hooks - | Some t -> t - end - in - - let t_val = - match - (* look for a value in the model for any term in the class *) - N.iter_class repr - |> Iter.find_map (fun n -> M.get model (N.term n)) - with - | Some v -> v - | None -> try_hooks_ model_ask_hooks - in - - M.replace model (N.term repr) t_val; (* be sure to cache the value *) - Log.debugf 5 (fun k->k "val is %a" Term.pp t_val); - t_val - in - - (* map terms of each CC class to the value computed for their class. *) - Solver_internal.CC.all_classes (Solver_internal.cc self.si) - (fun repr -> - let t_val = val_for_class repr in (* value for this class *) - N.iter_class repr - (fun u -> - let t_u = N.term u in - if not (N.equal u repr) && not (Term.equal t_u t_val) then ( - M.replace model t_u t_val; - ))); - Model.Map model - exception Resource_exhausted = Sidekick_sat.Resource_exhausted let solve @@ -1011,15 +1017,17 @@ module Make(A : ARG) | Sat_solver.Sat (module SAT) -> Log.debug 1 "(sidekick.smt-solver: SAT)"; - Log.debugf 50 + Log.debugf 5 (fun k-> let ppc out n = Fmt.fprintf out "{@[class@ %a@]}" (Util.pp_iter N.pp) (N.iter_class n) in k "(@[sidekick.smt-solver.classes@ (@[%a@])@])" (Util.pp_iter ppc) (CC.all_classes @@ Solver_internal.cc self.si)); - let _lits f = SAT.iter_trail f in - let m = mk_model self _lits in + let m = match self.si.last_model with + | Some m -> m + | None -> assert false + in (* TODO: check model *) let _ = check in