feat(cc): remove same-val explanations and model mode

This commit is contained in:
Simon Cruanes 2022-07-21 23:29:07 -04:00
parent dc68a60151
commit e37f66c394
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -126,7 +126,6 @@ module Make (A : ARG) :
| E_congruence of e_node * e_node (* caused by normal congruence *)
| E_and of explanation * explanation
| E_theory of term * term * (term * term * explanation list) list * step_id
| E_same_val of e_node * e_node
type repr = e_node
@ -212,8 +211,6 @@ module Make (A : ARG) :
(Util.pp_list @@ Fmt.Dump.triple Term.pp Term.pp (Fmt.Dump.list pp))
es
| E_and (a, b) -> Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
| E_same_val (n1, n2) ->
Fmt.fprintf out "(@[same-value@ %a@ %a@])" E_node.pp n1 E_node.pp n2
let mk_trivial : t = E_trivial
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2)
@ -233,12 +230,6 @@ module Make (A : ARG) :
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr)
let[@inline] mk_same_value t u =
if E_node.equal t u then
mk_trivial
else
E_same_val (t, u)
let rec mk_list l =
match l with
| [] -> mk_trivial
@ -251,28 +242,10 @@ module Make (A : ARG) :
end
module Resolved_expl = struct
type t = {
lits: lit list;
same_value: (E_node.t * E_node.t) list;
pr: proof_trace -> step_id;
}
let[@inline] is_semantic (self : t) : bool =
match self.same_value with
| [] -> false
| _ :: _ -> true
type t = { lits: lit list; pr: proof_trace -> step_id }
let pp out (self : t) =
if not (is_semantic self) then
Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp)
self.lits
else (
let { lits; same_value; pr = _ } = self in
Fmt.fprintf out "(@[resolved-expl@ (@[%a@])@ :same-val (@[%a@])@])"
(Util.pp_list Lit.pp) lits
(Util.pp_list @@ Fmt.Dump.pair E_node.pp E_node.pp)
same_value
)
Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits
end
type propagation_reason = unit -> lit list * step_id
@ -345,11 +318,9 @@ module Make (A : ARG) :
module Sig_tbl = CCHashtbl.Make (Signature)
module T_tbl = CCHashtbl.Make (Term)
module T_b_tbl = Backtrackable_tbl.Make (Term)
type combine_task =
| CT_merge of e_node * e_node * explanation
| CT_set_val of e_node * value
| CT_act of action
type t = {
@ -366,17 +337,12 @@ module Make (A : ARG) :
have the same signature *)
pending: e_node Vec.t;
combine: combine_task Vec.t;
t_to_val: (e_node * value) T_b_tbl.t;
(* TODO: remove this, make it a plugin/EGG instead *)
(* [repr -> (t,val)] where [repr = t] and [t := val] in the model *)
val_to_t: e_node T_b_tbl.t; (* [val -> t] where [t := val] in the model *)
undo: (unit -> unit) Backtrack_stack.t;
bitgen: Bits.bitfield_gen;
field_marked_explain: Bits.field;
(* used to mark traversed nodes when looking for a common ancestor *)
true_: e_node lazy_t;
false_: e_node lazy_t;
mutable model_mode: bool;
mutable in_loop: bool; (* currently being modified? *)
res_acts: action Vec.t; (* to return *)
on_pre_merge:
@ -389,7 +355,6 @@ module Make (A : ARG) :
count_conflict: int Stat.counter;
count_props: int Stat.counter;
count_merge: int Stat.counter;
count_semantic_conflict: int Stat.counter;
}
(* TODO: an additional union-find to keep track, for each term,
of the terms they are known to be equal to, according
@ -581,31 +546,20 @@ module Make (A : ARG) :
module Expl_state = struct
type t = {
mutable lits: Lit.t list;
mutable same_val: (E_node.t * E_node.t) list;
mutable th_lemmas: (Lit.t * (Lit.t * Lit.t list) list * step_id) list;
}
let create () : t = { lits = []; same_val = []; th_lemmas = [] }
let create () : t = { lits = []; th_lemmas = [] }
let[@inline] copy self : t = { self with lits = self.lits }
let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits
let[@inline] add_th (self : t) lit hyps pr : unit =
self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas
let[@inline] add_same_val (self : t) n1 n2 : unit =
self.same_val <- (n1, n2) :: self.same_val
(** Does this explanation contain at least one merge caused by
"same value"? *)
let[@inline] is_semantic (self : t) : bool = self.same_val <> []
let merge self other =
let { lits = o_lits; th_lemmas = o_lemmas; same_val = o_same_val } =
other
in
let { lits = o_lits; th_lemmas = o_lemmas } = other in
self.lits <- List.rev_append o_lits self.lits;
self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas;
self.same_val <- List.rev_append o_same_val self.same_val;
()
(* proof of [\/_i ¬lits[i]] *)
@ -643,10 +597,10 @@ module Make (A : ARG) :
let to_resolved_expl (self : t) : Resolved_expl.t =
(* FIXME: package the th lemmas too *)
let { lits; same_val; th_lemmas = _ } = self in
let { lits; th_lemmas = _ } = self in
let s2 = copy self in
let pr proof = proof_of_th_lemmas s2 proof in
{ Resolved_expl.lits; same_value = same_val; pr }
{ Resolved_expl.lits; pr }
end
(* decompose explanation [e] into a list of literals added to [acc] *)
@ -670,7 +624,6 @@ module Make (A : ARG) :
explain_equal_rec_ self st c1 c2
| _ -> assert false)
| E_lit lit -> Expl_state.add_lit st lit
| E_same_val (n1, n2) -> Expl_state.add_same_val st n1 n2
| E_theory (t, u, expl_sets, pr) ->
let sub_proofs =
List.map
@ -752,8 +705,7 @@ module Make (A : ARG) :
if Option.is_some sig0 then
(* [n] might be merged with other equiv classes *)
push_pending self n;
if not self.model_mode then
Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self);
Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self);
n
(* compute the initial signature of the given e_node *)
@ -766,7 +718,7 @@ module Make (A : ARG) :
(* add [n] to [sub.root]'s parent list *)
(let sub_r = find_ sub in
let old_parents = sub_r.n_parents in
if Bag.is_empty old_parents && not self.model_mode then
if Bag.is_empty old_parents then
(* first time it has parents: tell watchers that this is a subterm *)
Event.emit_iter self.on_is_subterm (self, sub, u)
~f:(push_action_l self);
@ -816,8 +768,7 @@ module Make (A : ARG) :
merges. *)
let lits_and_proof_of_expl (self : t) (st : Expl_state.t) :
Lit.t list * step_id =
let { Expl_state.lits; th_lemmas = _; same_val } = st in
assert (same_val = []);
let { Expl_state.lits; th_lemmas = _ } = st in
let pr = Expl_state.proof_of_th_lemmas st self.proof in
lits, pr
@ -873,51 +824,11 @@ module Make (A : ARG) :
and task_combine_ self = function
| CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab
| CT_set_val (n, v) -> task_set_val_ self n v
| CT_act (Act_merge (t, u, e)) -> task_merge_ self t u e
| CT_act (Act_propagate _ as a) ->
(* will return this propagation to the caller *)
Vec.push self.res_acts a
and task_set_val_ self n v =
let repr_n = find_ n in
(* - if repr(n) has value [v], do nothing
- else if repr(n) has value [v'], semantic conflict
- else add [repr(n) -> (n,v)] to cc.t_to_val *)
(match T_b_tbl.get self.t_to_val repr_n.n_term with
| Some (n', v') when not (Term.equal v v') ->
(* semantic conflict *)
let expl = [ Expl.mk_merge n n' ] in
let expl_st = explain_expls self expl in
let lits = expl_st.lits in
let tuples =
List.rev_map (fun (t, u) -> true, t.n_term, u.n_term) expl_st.same_val
in
let tuples = (false, n.n_term, n'.n_term) :: tuples in
Log.debugf 5 (fun k ->
k
"(@[cc.semantic-conflict.set-val@ (@[set-val %a@ := %a@])@ \
(@[existing-val %a@ := %a@])@])"
E_node.pp n Term.pp v E_node.pp n' Term.pp v');
Stat.incr self.count_semantic_conflict;
(* FIXME
raise (E_confl(Conflict lits))
let (module A) = acts in
A.raise_semantic_conflict lits tuples
*)
assert false
| Some _ -> ()
| None -> T_b_tbl.add self.t_to_val repr_n.n_term (n, v));
(* now for the reverse map, look in self.val_to_t for [v].
- if present, push a merge command with Expl.mk_same_value
- if not, add [v -> n] *)
match T_b_tbl.get self.val_to_t v with
| None -> T_b_tbl.add self.val_to_t v n
| Some n' when not (same_class n n') ->
merge_classes self n n' (Expl.mk_same_value n n')
| Some _ -> ()
(* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *)
and task_merge_ self a b e_ab : unit =
@ -948,25 +859,9 @@ module Make (A : ARG) :
explain_equal_rec_ self expl_st a ra;
explain_equal_rec_ self expl_st b rb;
if Expl_state.is_semantic expl_st then (
(* conflict involving some semantic values *)
let lits = expl_st.lits in
let same_val =
expl_st.same_val
|> List.rev_map (fun (t, u) -> true, E_node.term t, E_node.term u)
in
assert (same_val <> []);
Stat.incr self.count_semantic_conflict;
(* FIXME
let (module A) = acts in
A.raise_semantic_conflict lits same_val
*)
assert false
) else (
(* regular conflict *)
let lits, pr = lits_and_proof_of_expl self expl_st in
raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr
)
(* regular conflict *)
let lits, pr = lits_and_proof_of_expl self expl_st in
raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr
);
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always
@ -989,10 +884,8 @@ module Make (A : ARG) :
propagate_bools self r2 t2 r1 t1 e_ab false
in
if not self.model_mode then (
merge_bool ra a rb b;
merge_bool rb b ra a
);
merge_bool ra a rb b;
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k ->
@ -1000,16 +893,14 @@ module Make (A : ARG) :
r_into);
(* call [on_pre_merge] functions, and merge theory data items *)
if not self.model_mode then (
(* explanation is [a=ra & e_ab & b=rb] *)
let expl =
Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ]
in
Event.emit_iter self.on_pre_merge (self, r_into, r_from, expl)
~f:(function
| Ok l -> push_action_l self l
| Error c -> raise (E_confl c))
);
(* explanation is [a=ra & e_ab & b=rb] *)
let expl =
Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ]
in
Event.emit_iter self.on_pre_merge (self, r_into, r_from, expl)
~f:(function
| Ok l -> push_action_l self l
| Error c -> raise (E_confl c));
(* TODO: merge plugin data here, _after_ the pre-merge hooks are called,
so they have a chance of observing pre-merge plugin data *)
@ -1044,41 +935,6 @@ module Make (A : ARG) :
E_node.iter_class_ r_from (fun u -> u.n_root <- r_from);
r_into.n_size <- r_into.n_size - r_from.n_size));
(* check for semantic values, update the one of [r_into]
if [r_from] has a value *)
(match T_b_tbl.get self.t_to_val r_from.n_term with
| None -> ()
| Some (n_from, v_from) ->
(match T_b_tbl.get self.t_to_val r_into.n_term with
| None -> T_b_tbl.add self.t_to_val r_into.n_term (n_from, v_from)
| Some (n_into, v_into) when not (Term.equal v_from v_into) ->
(* semantic conflict, including [n_from != n_into] in model *)
let expl =
[ e_ab; Expl.mk_merge r_from n_from; Expl.mk_merge r_into n_into ]
in
let expl_st = explain_expls self expl in
let lits = expl_st.lits in
let tuples =
List.rev_map
(fun (t, u) -> true, t.n_term, u.n_term)
expl_st.same_val
in
let tuples = (false, n_from.n_term, n_into.n_term) :: tuples in
Log.debugf 5 (fun k ->
k
"(@[cc.semantic-conflict.post-merge@ (@[n-from %a@ := %a@])@ \
(@[n-into %a@ := %a@])@])"
E_node.pp n_from Term.pp v_from E_node.pp n_into Term.pp v_into);
Stat.incr self.count_semantic_conflict;
(* FIXME
let (module A) = acts in
A.raise_semantic_conflict lits tuples
*)
assert false
| Some _ -> ()));
(* update explanations (a -> b), arbitrarily.
Note that here we merge the classes by adding a bridge between [a]
and [b], not their roots. *)
@ -1093,9 +949,8 @@ module Make (A : ARG) :
| _ -> assert false);
a.n_expl <- FL_some { next = b; expl = e_ab };
(* call [on_post_merge] *)
if not self.model_mode then
Event.emit_iter self.on_post_merge (self, r_into, r_from)
~f:(push_action_l self)
Event.emit_iter self.on_post_merge (self, r_into, r_from)
~f:(push_action_l self)
)
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
@ -1136,29 +991,25 @@ module Make (A : ARG) :
explain_equal_rec_ self st u1 t1;
(* propagate only if this doesn't depend on some semantic values *)
if not (Expl_state.is_semantic st) then (
let reason () =
(* true literals explaining why t1=t2 *)
let guard = st.lits in
(* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *)
Expl_state.add_lit st (Lit.neg lit);
let _, pr = lits_and_proof_of_expl self st in
guard, pr
in
push_action self (Act_propagate { lit; reason });
Event.emit_iter self.on_propagate (self, lit, reason)
~f:(push_action_l self);
Stat.incr self.count_props
)
let reason () =
(* true literals explaining why t1=t2 *)
let guard = st.lits in
(* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *)
Expl_state.add_lit st (Lit.neg lit);
let _, pr = lits_and_proof_of_expl self st in
guard, pr
in
push_action self (Act_propagate { lit; reason });
Event.emit_iter self.on_propagate (self, lit, reason)
~f:(push_action_l self);
Stat.incr self.count_props
| _ -> ())
let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t)
let push_level (self : t) : unit =
assert (not self.in_loop);
Backtrack_stack.push_level self.undo;
T_b_tbl.push_level self.t_to_val;
T_b_tbl.push_level self.val_to_t
Backtrack_stack.push_level self.undo
let pop_levels (self : t) n : unit =
assert (not self.in_loop);
@ -1168,28 +1019,8 @@ module Make (A : ARG) :
k "(@[cc.pop-levels %d@ :n-lvls %d@])" n
(Backtrack_stack.n_levels self.undo));
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f ());
T_b_tbl.pop_levels self.t_to_val n;
T_b_tbl.pop_levels self.val_to_t n;
()
(* FIXME: remove *)
(* run [f] in a local congruence closure level *)
let with_model_mode self f =
assert (not self.model_mode);
self.model_mode <- true;
push_level self;
CCFun.protect f ~finally:(fun () ->
pop_levels self 1;
self.model_mode <- false)
let get_model_for_each_class self : _ Iter.t =
assert self.model_mode;
all_classes self
|> Iter.filter_map (fun repr ->
match T_b_tbl.get self.t_to_val repr.n_term with
| Some (_, v) -> Some (repr, E_node.iter_class repr, v)
| None -> None)
let assert_eq self t u expl : unit =
assert (not self.in_loop);
let t = add_term self t in
@ -1246,14 +1077,6 @@ module Make (A : ARG) :
let merge_t self t1 t2 expl =
merge self (add_term self t1) (add_term self t2) expl
let set_model_value (self : t) (t : term) (v : value) : unit =
assert (not self.in_loop);
assert self.model_mode;
(* only valid in model mode *)
match T_tbl.find_opt self.tbl t with
| None -> () (* ignore, th combination not needed *)
| Some n -> Vec.push self.combine (CT_set_val (n, v))
let explain_eq self n1 n2 : Resolved_expl.t =
let st = Expl_state.create () in
explain_equal_rec_ self st n1 n2;
@ -1283,9 +1106,6 @@ module Make (A : ARG) :
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
bitgen;
t_to_val = T_b_tbl.create ~size:32 ();
val_to_t = T_b_tbl.create ~size:32 ();
model_mode = false;
on_pre_merge = Event.Emitter.create ();
on_post_merge = Event.Emitter.create ();
on_new_term = Event.Emitter.create ();
@ -1303,7 +1123,6 @@ module Make (A : ARG) :
count_conflict = Stat.mk_int stat "cc.conflicts";
count_props = Stat.mk_int stat "cc.propagations";
count_merge = Stat.mk_int stat "cc.merges";
count_semantic_conflict = Stat.mk_int stat "cc.semantic-conflicts";
}
and true_ = lazy (add_term cc (Term.bool tst true))
and false_ = lazy (add_term cc (Term.bool tst false)) in