mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
feat: construct model from congruence closure after th-combination
we already obtain a model from theories, and saturate the congruence closure with it, it's a shame not to use it.
This commit is contained in:
parent
a388c96fe3
commit
1946a5e7cf
3 changed files with 132 additions and 112 deletions
|
|
@ -1061,6 +1061,15 @@ module Make (A: CC_ARG)
|
||||||
cc.model_mode <- false;
|
cc.model_mode <- false;
|
||||||
)
|
)
|
||||||
|
|
||||||
|
let get_model_for_each_class self : _ Iter.t =
|
||||||
|
assert self.model_mode;
|
||||||
|
all_classes self
|
||||||
|
|> Iter.filter_map
|
||||||
|
(fun repr ->
|
||||||
|
match T_b_tbl.get self.t_to_val repr.n_term with
|
||||||
|
| Some (_,v) -> Some (repr, N.iter_class repr, v)
|
||||||
|
| None -> None)
|
||||||
|
|
||||||
(* assert that this boolean literal holds.
|
(* assert that this boolean literal holds.
|
||||||
if a lit is [= a b], merge [a] and [b];
|
if a lit is [= a b], merge [a] and [b];
|
||||||
otherwise merge the atom with true/false *)
|
otherwise merge the atom with true/false *)
|
||||||
|
|
|
||||||
|
|
@ -735,6 +735,9 @@ module type CC_S = sig
|
||||||
val with_model_mode : t -> (unit -> 'a) -> 'a
|
val with_model_mode : t -> (unit -> 'a) -> 'a
|
||||||
(** Enter model combination mode. *)
|
(** Enter model combination mode. *)
|
||||||
|
|
||||||
|
val get_model_for_each_class : t -> (repr * N.t Iter.t * value) Iter.t
|
||||||
|
(** In model combination mode, obtain classes with their values. *)
|
||||||
|
|
||||||
val check : t -> actions -> unit
|
val check : t -> actions -> unit
|
||||||
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
|
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
|
||||||
Will use the {!actions} to propagate literals, declare conflicts, etc. *)
|
Will use the {!actions} to propagate literals, declare conflicts, etc. *)
|
||||||
|
|
|
||||||
|
|
@ -142,6 +142,28 @@ module Make(A : ARG)
|
||||||
module CC = Sidekick_cc.Make(CC_actions)
|
module CC = Sidekick_cc.Make(CC_actions)
|
||||||
module N = CC.N
|
module N = CC.N
|
||||||
|
|
||||||
|
module Model = struct
|
||||||
|
type t =
|
||||||
|
| Empty
|
||||||
|
| Map of term Term.Tbl.t
|
||||||
|
let empty = Empty
|
||||||
|
let mem = function
|
||||||
|
| Empty -> fun _ -> false
|
||||||
|
| Map tbl -> Term.Tbl.mem tbl
|
||||||
|
let find = function
|
||||||
|
| Empty -> fun _ -> None
|
||||||
|
| Map tbl -> Term.Tbl.get tbl
|
||||||
|
let eval = find
|
||||||
|
let pp out = function
|
||||||
|
| Empty -> Fmt.string out "(model)"
|
||||||
|
| Map tbl ->
|
||||||
|
let pp_pair out (t,v) =
|
||||||
|
Fmt.fprintf out "(@[<1>%a@ := %a@])" Term.pp t Term.pp v
|
||||||
|
in
|
||||||
|
Fmt.fprintf out "(@[<hv>model@ %a@])"
|
||||||
|
(Util.pp_iter pp_pair) (Term.Tbl.to_iter tbl)
|
||||||
|
end
|
||||||
|
|
||||||
(* delayed actions. We avoid doing them on the spot because, when
|
(* delayed actions. We avoid doing them on the spot because, when
|
||||||
triggered by a theory, they might go back to the theory "too early". *)
|
triggered by a theory, they might go back to the theory "too early". *)
|
||||||
type delayed_action =
|
type delayed_action =
|
||||||
|
|
@ -285,6 +307,7 @@ module Make(A : ARG)
|
||||||
simp: Simplify.t;
|
simp: Simplify.t;
|
||||||
preprocessed: unit Term.Tbl.t;
|
preprocessed: unit Term.Tbl.t;
|
||||||
delayed_actions: delayed_action Queue.t;
|
delayed_actions: delayed_action Queue.t;
|
||||||
|
mutable last_model: Model.t option;
|
||||||
|
|
||||||
mutable t_defs : (term*term) list; (* term definitions *)
|
mutable t_defs : (term*term) list; (* term definitions *)
|
||||||
mutable th_states : th_states; (** Set of theories *)
|
mutable th_states : th_states; (** Set of theories *)
|
||||||
|
|
@ -560,14 +583,95 @@ module Make(A : ARG)
|
||||||
push_lvl_ self.th_states
|
push_lvl_ self.th_states
|
||||||
|
|
||||||
let pop_levels (self:t) n : unit =
|
let pop_levels (self:t) n : unit =
|
||||||
|
self.last_model <- None;
|
||||||
self.level <- self.level - n;
|
self.level <- self.level - n;
|
||||||
CC.pop_levels (cc self) n;
|
CC.pop_levels (cc self) n;
|
||||||
pop_lvls_ n self.th_states
|
pop_lvls_ n self.th_states
|
||||||
|
|
||||||
|
(** {2 Model construction and theory combination} *)
|
||||||
|
|
||||||
|
(* make model from the congruence closure *)
|
||||||
|
let mk_model_ (self:t) : Model.t =
|
||||||
|
Log.debug 1 "(smt.solver.mk-model)";
|
||||||
|
Profile.with_ "smt-solver.mk-model" @@ fun () ->
|
||||||
|
let module M = Term.Tbl in
|
||||||
|
let {cc=lazy cc;
|
||||||
|
model_ask=model_ask_hooks; model_complete; _} = self in
|
||||||
|
|
||||||
|
let model = M.create 128 in
|
||||||
|
(* populate with information from the CC *)
|
||||||
|
CC.get_model_for_each_class cc
|
||||||
|
(fun (_, ts, v) ->
|
||||||
|
Iter.iter
|
||||||
|
(fun n ->
|
||||||
|
let t = N.term n in
|
||||||
|
M.replace model t v) ts);
|
||||||
|
|
||||||
|
(* complete model with theory specific values *)
|
||||||
|
let complete_with f =
|
||||||
|
f self
|
||||||
|
~add:(fun t u ->
|
||||||
|
if not (M.mem model t) then (
|
||||||
|
Log.debugf 20 (fun k->k "(@[smt.model-complete@ %a@ :with-val %a@])" Term.pp t Term.pp u);
|
||||||
|
M.replace model t u
|
||||||
|
));
|
||||||
|
in
|
||||||
|
List.iter complete_with model_complete;
|
||||||
|
|
||||||
|
(* compute a value for [n]. *)
|
||||||
|
let rec val_for_class (n:N.t) : term =
|
||||||
|
Log.debugf 5 (fun k->k "val-for-term %a" N.pp n);
|
||||||
|
let repr = CC.find cc n in
|
||||||
|
Log.debugf 5 (fun k->k "val-for-term.repr %a" N.pp repr);
|
||||||
|
|
||||||
|
(* see if a value is found already (always the case if it's a boolean) *)
|
||||||
|
match M.get model (N.term repr) with
|
||||||
|
| Some t_val ->
|
||||||
|
Log.debugf 5 (fun k->k "cached val is %a" Term.pp t_val);
|
||||||
|
t_val
|
||||||
|
| None ->
|
||||||
|
|
||||||
|
(* try each model hook *)
|
||||||
|
let rec try_hooks_ = function
|
||||||
|
| [] -> N.term repr
|
||||||
|
| h :: hooks ->
|
||||||
|
begin match h ~recurse:(fun _ n -> val_for_class n) self repr with
|
||||||
|
| None -> try_hooks_ hooks
|
||||||
|
| Some t -> t
|
||||||
|
end
|
||||||
|
in
|
||||||
|
|
||||||
|
let t_val =
|
||||||
|
match
|
||||||
|
(* look for a value in the model for any term in the class *)
|
||||||
|
N.iter_class repr
|
||||||
|
|> Iter.find_map (fun n -> M.get model (N.term n))
|
||||||
|
with
|
||||||
|
| Some v -> v
|
||||||
|
| None -> try_hooks_ model_ask_hooks
|
||||||
|
in
|
||||||
|
|
||||||
|
M.replace model (N.term repr) t_val; (* be sure to cache the value *)
|
||||||
|
Log.debugf 5 (fun k->k "val is %a" Term.pp t_val);
|
||||||
|
t_val
|
||||||
|
in
|
||||||
|
|
||||||
|
(* map terms of each CC class to the value computed for their class. *)
|
||||||
|
CC.all_classes cc
|
||||||
|
(fun repr ->
|
||||||
|
let t_val = val_for_class repr in (* value for this class *)
|
||||||
|
N.iter_class repr
|
||||||
|
(fun u ->
|
||||||
|
let t_u = N.term u in
|
||||||
|
if not (N.equal u repr) && not (Term.equal t_u t_val) then (
|
||||||
|
M.replace model t_u t_val;
|
||||||
|
)));
|
||||||
|
Model.Map model
|
||||||
|
|
||||||
(* do theory combination using the congruence closure. Each theory
|
(* do theory combination using the congruence closure. Each theory
|
||||||
can merge classes, *)
|
can merge classes, *)
|
||||||
let check_th_combination_
|
let check_th_combination_
|
||||||
(self:t) (acts:theory_actions) : (unit, th_combination_conflict) result =
|
(self:t) (acts:theory_actions) : (Model.t, th_combination_conflict) result =
|
||||||
let cc = cc self in
|
let cc = cc self in
|
||||||
(* entier model mode, disabling most of congruence closure *)
|
(* entier model mode, disabling most of congruence closure *)
|
||||||
CC.with_model_mode cc @@ fun () ->
|
CC.with_model_mode cc @@ fun () ->
|
||||||
|
|
@ -579,16 +683,17 @@ module Make(A : ARG)
|
||||||
CC.set_model_value cc t v
|
CC.set_model_value cc t v
|
||||||
in
|
in
|
||||||
|
|
||||||
(* obtain classes of equal terms from the hook, and merge them *)
|
(* obtain assignments from the hook, and communicate them to the CC *)
|
||||||
let add_th_equalities f : unit =
|
let add_th_values f : unit =
|
||||||
let vals = f self acts in
|
let vals = f self acts in
|
||||||
Iter.iter set_val vals
|
Iter.iter set_val vals
|
||||||
in
|
in
|
||||||
|
|
||||||
try
|
try
|
||||||
List.iter add_th_equalities self.on_th_combination;
|
List.iter add_th_values self.on_th_combination;
|
||||||
CC.check cc acts;
|
CC.check cc acts;
|
||||||
Ok ()
|
let m = mk_model_ self in
|
||||||
|
Ok m
|
||||||
with Semantic_conflict c -> Error c
|
with Semantic_conflict c -> Error c
|
||||||
|
|
||||||
(* handle a literal assumed by the SAT solver *)
|
(* handle a literal assumed by the SAT solver *)
|
||||||
|
|
@ -608,11 +713,9 @@ module Make(A : ARG)
|
||||||
List.iter (fun f -> f self acts lits) self.on_final_check;
|
List.iter (fun f -> f self acts lits) self.on_final_check;
|
||||||
CC.check cc acts;
|
CC.check cc acts;
|
||||||
|
|
||||||
let new_work = has_delayed_actions self in
|
begin match check_th_combination_ self acts with
|
||||||
(* do actual theory combination if nothing changed by pure "final check" *)
|
| Ok m ->
|
||||||
if not new_work then (
|
self.last_model <- Some m
|
||||||
match check_th_combination_ self acts with
|
|
||||||
| Ok () -> ()
|
|
||||||
|
|
||||||
| Error {lits; semantic} ->
|
| Error {lits; semantic} ->
|
||||||
(* bad model, we add a clause to remove it *)
|
(* bad model, we add a clause to remove it *)
|
||||||
|
|
@ -642,7 +745,7 @@ module Make(A : ARG)
|
||||||
(Util.pp_list Lit.pp) c);
|
(Util.pp_list Lit.pp) c);
|
||||||
(* will add a delayed action *)
|
(* will add a delayed action *)
|
||||||
add_clause_temp self acts c pr;
|
add_clause_temp self acts c pr;
|
||||||
);
|
end;
|
||||||
|
|
||||||
Perform_delayed_th.top self acts;
|
Perform_delayed_th.top self acts;
|
||||||
) else (
|
) else (
|
||||||
|
|
@ -691,6 +794,7 @@ module Make(A : ARG)
|
||||||
th_states=Ths_nil;
|
th_states=Ths_nil;
|
||||||
stat;
|
stat;
|
||||||
simp=Simplify.create tst ty_st ~proof;
|
simp=Simplify.create tst ty_st ~proof;
|
||||||
|
last_model=None;
|
||||||
on_progress=(fun () -> ());
|
on_progress=(fun () -> ());
|
||||||
preprocess=[];
|
preprocess=[];
|
||||||
model_ask=[];
|
model_ask=[];
|
||||||
|
|
@ -745,28 +849,6 @@ module Make(A : ARG)
|
||||||
| U_asked_to_stop -> Fmt.string out {|"asked to stop by callback"|}
|
| U_asked_to_stop -> Fmt.string out {|"asked to stop by callback"|}
|
||||||
end [@@ocaml.warning "-37"]
|
end [@@ocaml.warning "-37"]
|
||||||
|
|
||||||
module Model = struct
|
|
||||||
type t =
|
|
||||||
| Empty
|
|
||||||
| Map of term Term.Tbl.t
|
|
||||||
let empty = Empty
|
|
||||||
let mem = function
|
|
||||||
| Empty -> fun _ -> false
|
|
||||||
| Map tbl -> Term.Tbl.mem tbl
|
|
||||||
let find = function
|
|
||||||
| Empty -> fun _ -> None
|
|
||||||
| Map tbl -> Term.Tbl.get tbl
|
|
||||||
let eval = find
|
|
||||||
let pp out = function
|
|
||||||
| Empty -> Fmt.string out "(model)"
|
|
||||||
| Map tbl ->
|
|
||||||
let pp_pair out (t,v) =
|
|
||||||
Fmt.fprintf out "(@[<1>%a@ := %a@])" Term.pp t Term.pp v
|
|
||||||
in
|
|
||||||
Fmt.fprintf out "(@[<hv>model@ %a@])"
|
|
||||||
(Util.pp_iter pp_pair) (Term.Tbl.to_iter tbl)
|
|
||||||
end
|
|
||||||
|
|
||||||
type res =
|
type res =
|
||||||
| Sat of Model.t
|
| Sat of Model.t
|
||||||
| Unsat of {
|
| Unsat of {
|
||||||
|
|
@ -901,82 +983,6 @@ module Make(A : ARG)
|
||||||
|
|
||||||
let assert_term self t = assert_terms self [t]
|
let assert_term self t = assert_terms self [t]
|
||||||
|
|
||||||
let mk_model (self:t) (lits:lit Iter.t) : Model.t =
|
|
||||||
Log.debug 1 "(smt.solver.mk-model)";
|
|
||||||
Profile.with_ "smt-solver.mk-model" @@ fun () ->
|
|
||||||
let module M = Term.Tbl in
|
|
||||||
let model = M.create 128 in
|
|
||||||
let {Solver_internal.tst; cc=lazy cc;
|
|
||||||
model_ask=model_ask_hooks; model_complete; _} = self.si in
|
|
||||||
|
|
||||||
(* first, add all literals to the model using the given propositional model
|
|
||||||
[lits]. *)
|
|
||||||
lits
|
|
||||||
(fun lit ->
|
|
||||||
let t, sign = Lit.signed_term lit in
|
|
||||||
M.replace model t (Term.bool tst sign));
|
|
||||||
|
|
||||||
(* complete model with theory specific values *)
|
|
||||||
let complete_with f =
|
|
||||||
f self.si
|
|
||||||
~add:(fun t u ->
|
|
||||||
if not (M.mem model t) then (
|
|
||||||
Log.debugf 20 (fun k->k "(@[smt.model-complete@ %a@ :with-val %a@])" Term.pp t Term.pp u);
|
|
||||||
M.replace model t u
|
|
||||||
));
|
|
||||||
in
|
|
||||||
List.iter complete_with model_complete;
|
|
||||||
|
|
||||||
(* compute a value for [n]. *)
|
|
||||||
let rec val_for_class (n:N.t) : term =
|
|
||||||
Log.debugf 5 (fun k->k "val-for-term %a" N.pp n);
|
|
||||||
let repr = CC.find cc n in
|
|
||||||
Log.debugf 5 (fun k->k "val-for-term.repr %a" N.pp repr);
|
|
||||||
|
|
||||||
(* see if a value is found already (always the case if it's a boolean) *)
|
|
||||||
match M.get model (N.term repr) with
|
|
||||||
| Some t_val ->
|
|
||||||
Log.debugf 5 (fun k->k "cached val is %a" Term.pp t_val);
|
|
||||||
t_val
|
|
||||||
| None ->
|
|
||||||
|
|
||||||
(* try each model hook *)
|
|
||||||
let rec try_hooks_ = function
|
|
||||||
| [] -> N.term repr
|
|
||||||
| h :: hooks ->
|
|
||||||
begin match h ~recurse:(fun _ n -> val_for_class n) self.si repr with
|
|
||||||
| None -> try_hooks_ hooks
|
|
||||||
| Some t -> t
|
|
||||||
end
|
|
||||||
in
|
|
||||||
|
|
||||||
let t_val =
|
|
||||||
match
|
|
||||||
(* look for a value in the model for any term in the class *)
|
|
||||||
N.iter_class repr
|
|
||||||
|> Iter.find_map (fun n -> M.get model (N.term n))
|
|
||||||
with
|
|
||||||
| Some v -> v
|
|
||||||
| None -> try_hooks_ model_ask_hooks
|
|
||||||
in
|
|
||||||
|
|
||||||
M.replace model (N.term repr) t_val; (* be sure to cache the value *)
|
|
||||||
Log.debugf 5 (fun k->k "val is %a" Term.pp t_val);
|
|
||||||
t_val
|
|
||||||
in
|
|
||||||
|
|
||||||
(* map terms of each CC class to the value computed for their class. *)
|
|
||||||
Solver_internal.CC.all_classes (Solver_internal.cc self.si)
|
|
||||||
(fun repr ->
|
|
||||||
let t_val = val_for_class repr in (* value for this class *)
|
|
||||||
N.iter_class repr
|
|
||||||
(fun u ->
|
|
||||||
let t_u = N.term u in
|
|
||||||
if not (N.equal u repr) && not (Term.equal t_u t_val) then (
|
|
||||||
M.replace model t_u t_val;
|
|
||||||
)));
|
|
||||||
Model.Map model
|
|
||||||
|
|
||||||
exception Resource_exhausted = Sidekick_sat.Resource_exhausted
|
exception Resource_exhausted = Sidekick_sat.Resource_exhausted
|
||||||
|
|
||||||
let solve
|
let solve
|
||||||
|
|
@ -1011,15 +1017,17 @@ module Make(A : ARG)
|
||||||
| Sat_solver.Sat (module SAT) ->
|
| Sat_solver.Sat (module SAT) ->
|
||||||
Log.debug 1 "(sidekick.smt-solver: SAT)";
|
Log.debug 1 "(sidekick.smt-solver: SAT)";
|
||||||
|
|
||||||
Log.debugf 50
|
Log.debugf 5
|
||||||
(fun k->
|
(fun k->
|
||||||
let ppc out n =
|
let ppc out n =
|
||||||
Fmt.fprintf out "{@[<hv>class@ %a@]}" (Util.pp_iter N.pp) (N.iter_class n) in
|
Fmt.fprintf out "{@[<hv>class@ %a@]}" (Util.pp_iter N.pp) (N.iter_class n) in
|
||||||
k "(@[sidekick.smt-solver.classes@ (@[%a@])@])"
|
k "(@[sidekick.smt-solver.classes@ (@[%a@])@])"
|
||||||
(Util.pp_iter ppc) (CC.all_classes @@ Solver_internal.cc self.si));
|
(Util.pp_iter ppc) (CC.all_classes @@ Solver_internal.cc self.si));
|
||||||
|
|
||||||
let _lits f = SAT.iter_trail f in
|
let m = match self.si.last_model with
|
||||||
let m = mk_model self _lits in
|
| Some m -> m
|
||||||
|
| None -> assert false
|
||||||
|
in
|
||||||
(* TODO: check model *)
|
(* TODO: check model *)
|
||||||
let _ = check in
|
let _ = check in
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue