add with_model_mode to congruence closure

this mode:
- enables `set_mode_value`
- disables all callbacks
- can only be used locally with a push/pop wrapper
This commit is contained in:
Simon Cruanes 2022-02-17 18:08:55 -05:00
parent 95f84b4854
commit 6e941683a2
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
3 changed files with 52 additions and 27 deletions

View file

@ -299,6 +299,7 @@ module Make (A: CC_ARG)
field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *)
true_ : node lazy_t; true_ : node lazy_t;
false_ : node lazy_t; false_ : node lazy_t;
mutable model_mode: bool;
mutable on_pre_merge: ev_on_pre_merge list; mutable on_pre_merge: ev_on_pre_merge list;
mutable on_post_merge: ev_on_post_merge list; mutable on_post_merge: ev_on_post_merge list;
@ -642,7 +643,9 @@ module Make (A: CC_ARG)
(* [n] might be merged with other equiv classes *) (* [n] might be merged with other equiv classes *)
push_pending cc n; push_pending cc n;
); );
if not cc.model_mode then (
List.iter (fun f -> f cc n t) cc.on_new_term; List.iter (fun f -> f cc n t) cc.on_new_term;
);
n n
(* compute the initial signature of the given node *) (* compute the initial signature of the given node *)
@ -656,7 +659,7 @@ module Make (A: CC_ARG)
begin begin
let sub_r = find_ sub in let sub_r = find_ sub in
let old_parents = sub_r.n_parents in let old_parents = sub_r.n_parents in
if Bag.is_empty old_parents then ( if Bag.is_empty old_parents && not self.model_mode then (
(* first time it has parents: tell watchers that this is a subterm *) (* first time it has parents: tell watchers that this is a subterm *)
List.iter (fun f -> f sub u) self.on_is_subterm; List.iter (fun f -> f sub u) self.on_is_subterm;
); );
@ -785,7 +788,7 @@ module Make (A: CC_ARG)
List.rev_map (fun (t,u) -> true, t.n_term, u.n_term) expl_st.same_val List.rev_map (fun (t,u) -> true, t.n_term, u.n_term) expl_st.same_val
in in
let tuples = (false, n.n_term, n'.n_term) :: tuples in let tuples = (false, n.n_term, n'.n_term) :: tuples in
Log.debugf 20 Log.debugf 5
(fun k->k "(@[cc.semantic-conflict.set-val@ (@[set-val %a@ := %a@])@ \ (fun k->k "(@[cc.semantic-conflict.set-val@ (@[set-val %a@ := %a@])@ \
(@[existing-val %a@ := %a@])@])" (@[existing-val %a@ := %a@])@])"
N.pp n Term.pp v N.pp n' Term.pp v'); N.pp n Term.pp v N.pp n' Term.pp v');
@ -872,18 +875,21 @@ module Make (A: CC_ARG)
propagate_bools cc acts r2 t2 r1 t1 e_ab false propagate_bools cc acts r2 t2 r1 t1 e_ab false
) )
in in
if not cc.model_mode then (
merge_bool ra a rb b; merge_bool ra a rb b;
merge_bool rb b ra a; merge_bool rb b ra a;
);
(* perform [union r_from r_into] *) (* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call [on_pre_merge] functions, and merge theory data items *) (* call [on_pre_merge] functions, and merge theory data items *)
begin if not cc.model_mode then (
(* explanation is [a=ra & e_ab & b=rb] *) (* explanation is [a=ra & e_ab & b=rb] *)
let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in
List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge; List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge;
end; );
begin begin
(* parents might have a different signature, check for collisions *) (* parents might have a different signature, check for collisions *)
@ -944,7 +950,7 @@ module Make (A: CC_ARG)
in in
let tuples = (false, n_from.n_term, n_into.n_term) :: tuples in let tuples = (false, n_from.n_term, n_into.n_term) :: tuples in
Log.debugf 20 Log.debugf 5
(fun k->k "(@[cc.semantic-conflict.post-merge@ \ (fun k->k "(@[cc.semantic-conflict.post-merge@ \
(@[n-from %a@ := %a@])@ (@[n-into %a@ := %a@])@])" (@[n-from %a@ := %a@])@ (@[n-into %a@ := %a@])@])"
N.pp n_from Term.pp v_from N.pp n_into Term.pp v_into); N.pp n_from Term.pp v_from N.pp n_into Term.pp v_into);
@ -973,9 +979,9 @@ module Make (A: CC_ARG)
a.n_expl <- FL_some {next=b; expl=e_ab}; a.n_expl <- FL_some {next=b; expl=e_ab};
end; end;
(* call [on_post_merge] *) (* call [on_post_merge] *)
begin if not cc.model_mode then (
List.iter (fun f -> f cc acts r_into r_from) cc.on_post_merge; List.iter (fun f -> f cc acts r_into r_from) cc.on_post_merge;
end; );
) )
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
@ -1043,9 +1049,19 @@ module Make (A: CC_ARG)
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f()); Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
() ()
(* TODO:
CC.set_as_lit cc n (Lit.abs lit); (* run [f] in a local congruence closure level *)
*) let with_model_mode cc f =
assert (not cc.model_mode);
assert (not cc.new_merges);
cc.model_mode <- true;
push_level cc;
CCFun.protect f
~finally:(fun() ->
pop_levels cc 1;
cc.model_mode <- false;
cc.new_merges <- false;
)
(* assert that this boolean literal holds. (* assert that this boolean literal holds.
if a lit is [= a b], merge [a] and [b]; if a lit is [= a b], merge [a] and [b];
@ -1096,6 +1112,7 @@ module Make (A: CC_ARG)
merge cc (add_term cc t1) (add_term cc t2) expl merge cc (add_term cc t1) (add_term cc t2) expl
let set_model_value (self:t) (t:term) (v:value) : unit = let set_model_value (self:t) (t:term) (v:value) : unit =
assert (self.model_mode); (* only valid there *)
let n = add_term self t in let n = add_term self t in
Vec.push self.combine (CT_set_val (n,v)) Vec.push self.combine (CT_set_val (n,v))
@ -1127,6 +1144,7 @@ module Make (A: CC_ARG)
bitgen; bitgen;
t_to_val=T_tbl.create 32; t_to_val=T_tbl.create 32;
val_to_t=T_tbl.create 32; val_to_t=T_tbl.create 32;
model_mode=false;
on_pre_merge; on_pre_merge;
on_post_merge; on_post_merge;
on_new_term; on_new_term;

View file

@ -732,6 +732,9 @@ module type CC_S = sig
val set_model_value : t -> term -> value -> unit val set_model_value : t -> term -> value -> unit
(** Set the value of a term in the model. *) (** Set the value of a term in the model. *)
val with_model_mode : t -> (unit -> 'a) -> 'a
(** Enter model combination mode. *)
val check : t -> actions -> unit val check : t -> actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the {!actions} to propagate literals, declare conflicts, etc. *) Will use the {!actions} to propagate literals, declare conflicts, etc. *)

View file

@ -564,17 +564,13 @@ module Make(A : ARG)
CC.pop_levels (cc self) n; CC.pop_levels (cc self) n;
pop_lvls_ n self.th_states pop_lvls_ n self.th_states
(* run [f] in a local congruence closure level *)
let with_cc_level_ cc f =
CC.push_level cc;
CCFun.protect ~finally:(fun() -> CC.pop_levels cc 1) f
(* do theory combination using the congruence closure. Each theory (* do theory combination using the congruence closure. Each theory
can merge classes, *) can merge classes, *)
let check_th_combination_ let check_th_combination_
(self:t) (acts:theory_actions) : (unit, th_combination_conflict) result = (self:t) (acts:theory_actions) : (unit, th_combination_conflict) result =
let cc = cc self in let cc = cc self in
with_cc_level_ cc @@ fun () -> (* entier model mode, disabling most of congruence closure *)
CC.with_model_mode cc @@ fun () ->
let set_val (t,v) : unit = let set_val (t,v) : unit =
Log.debugf 50 Log.debugf 50
@ -622,13 +618,16 @@ module Make(A : ARG)
done; done;
CC.check cc acts; CC.check cc acts;
let new_merges_in_cc = CC.new_merges cc in let more_work_to_do = CC.new_merges cc || has_delayed_actions self in
if not more_work_to_do then (
match check_th_combination_ self acts with
| Ok () ->
continue := false;
begin match check_th_combination_ self acts with
| Ok () -> ()
| Error {lits; semantic} -> | Error {lits; semantic} ->
(* bad model, we add a clause to remove it *) (* bad model, we add a clause to remove it *)
Log.debugf 10 Log.debugf 5
(fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \ (fun k->k "(@[solver.th-comb.conflict@ :lits (@[%a@])@ \
:same-val (@[%a@])@])" :same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits (Util.pp_list Lit.pp) lits
@ -639,7 +638,11 @@ module Make(A : ARG)
semantic semantic
|> List.rev_map |> List.rev_map
(fun (sign,t,u) -> (fun (sign,t,u) ->
Lit.atom ~sign:(not sign) self.tst @@ A.mk_eq self.tst t u) let eqn = A.mk_eq self.tst t u in
let lit = Lit.atom ~sign:(not sign) self.tst eqn in
(* make sure to consider the new lit *)
add_lit self acts lit;
lit)
in in
let c = List.rev_append c1 c2 in let c = List.rev_append c1 c2 in
@ -650,11 +653,12 @@ module Make(A : ARG)
(Util.pp_list Lit.pp) c); (Util.pp_list Lit.pp) c);
(* will add a delayed action *) (* will add a delayed action *)
add_clause_temp self acts c pr; add_clause_temp self acts c pr;
end;
if not new_merges_in_cc && not (has_delayed_actions self) then ( continue := false; (* FIXME *)
continue := false;
); );
Perform_delayed_th.top self acts;
(* FIXME: give a chance to the SAT solver to run again? *)
done; done;
) else ( ) else (
List.iter (fun f -> f self acts lits) self.on_partial_check; List.iter (fun f -> f self acts lits) self.on_partial_check;