feat(cc): flag some explanations as being theory-induced

This commit is contained in:
Simon Cruanes 2020-01-17 18:48:17 -06:00
parent 5bcecfd68c
commit e21dea4780
4 changed files with 59 additions and 41 deletions

View file

@ -94,6 +94,7 @@ module Make (A: CC_ARG)
| E_merge_t of term * term
| E_congruence of node * node (* caused by normal congruence *)
| E_and of explanation * explanation
| E_theory of explanation
type repr = node
@ -166,6 +167,7 @@ module Make (A: CC_ARG)
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
| E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" Term.pp a Term.pp b
| E_theory e -> Fmt.fprintf out "(@[th@ %a@])" pp e
| E_and (a,b) ->
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
@ -174,6 +176,7 @@ module Make (A: CC_ARG)
let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b)
let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b)
let[@inline] mk_lit l : t = E_lit l
let mk_theory e = E_theory e
let rec mk_list l =
match l with
@ -275,7 +278,7 @@ module Make (A: CC_ARG)
and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> lit list -> unit
and ev_on_conflict = t -> th:bool -> lit list -> unit
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
let[@inline] size_ (r:repr) = r.n_size
@ -368,11 +371,11 @@ module Make (A: CC_ARG)
n.n_expl <- FL_none;
end
let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ =
let raise_conflict (cc:t) ~th (acts:actions) (e:lit list) : _ =
(* clear tasks queue *)
Vec.clear cc.pending;
Vec.clear cc.combine;
List.iter (fun f -> f cc e) cc.on_conflict;
List.iter (fun f -> f cc ~th e) cc.on_conflict;
Stat.incr cc.count_conflict;
Actions.raise_conflict acts e P.default
@ -429,7 +432,7 @@ module Make (A: CC_ARG)
n
(* decompose explanation [e] into a list of literals added to [acc] *)
let rec explain_decompose cc (acc:lit list) (e:explanation) : _ list =
let rec explain_decompose cc ~th (acc:lit list) (e:explanation) : _ list =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
match e with
| E_reduction -> acc
@ -438,49 +441,51 @@ module Make (A: CC_ARG)
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
List.fold_left2 (explain_pair cc) acc a1 a2
List.fold_left2 (explain_pair cc ~th) acc a1 a2
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
let acc = explain_pair cc acc f1 f2 in
List.fold_left2 (explain_pair cc) acc a1 a2
let acc = explain_pair cc ~th acc f1 f2 in
List.fold_left2 (explain_pair cc ~th) acc a1 a2
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
let acc = explain_pair cc acc a1 a2 in
let acc = explain_pair cc acc b1 b2 in
explain_pair cc acc c1 c2
let acc = explain_pair cc ~th acc a1 a2 in
let acc = explain_pair cc ~th acc b1 b2 in
explain_pair cc ~th acc c1 c2
| _ ->
assert false
end
| E_lit lit -> lit :: acc
| E_merge (a,b) -> explain_pair cc acc a b
| E_theory e' ->
th := true; explain_decompose cc ~th acc e'
| E_merge (a,b) -> explain_pair cc ~th acc a b
| E_merge_t (a,b) ->
(* find nodes for [a] and [b] on the fly *)
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
| a, b -> explain_pair cc acc a b
| a, b -> explain_pair cc ~th acc a b
| exception Not_found ->
Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b
end
| E_and (a,b) ->
let acc = explain_decompose cc acc a in
explain_decompose cc acc b
let acc = explain_decompose cc ~th acc a in
explain_decompose cc ~th acc b
and explain_pair (cc:t) (acc:lit list) (a:node) (b:node) : _ list =
and explain_pair (cc:t) ~th (acc:lit list) (a:node) (b:node) : _ list =
Log.debugf 5
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
assert (N.equal (find_ a) (find_ b));
let ancestor = find_common_ancestor cc a b in
let acc = explain_along_path cc acc a ancestor in
explain_along_path cc acc b ancestor
let acc = explain_along_path cc ~th acc a ancestor in
explain_along_path cc ~th acc b ancestor
(* explain why [a = parent_a], where [a -> ... -> target] in the
proof forest *)
and explain_along_path cc acc (a:node) (target:node) : _ list =
and explain_along_path cc ~th acc (a:node) (target:node) : _ list =
let rec aux acc n =
if n == target then acc
else (
match n.n_expl with
| FL_none -> assert false
| FL_some {next=next_n; expl=expl} ->
let acc = explain_decompose cc acc expl in
let acc = explain_decompose cc ~th acc expl in
(* now prove [next_n = target] *)
aux acc next_n
)
@ -631,10 +636,11 @@ module Make (A: CC_ARG)
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ \
@[:r1 %a@ :t1 %a@]@ @[:r2 %a@ :t2 %a@]@ :e_ab %a@])"
N.pp ra N.pp a N.pp rb N.pp b Expl.pp e_ab);
let lits = explain_decompose cc [] e_ab in
let lits = explain_pair cc lits a ra in
let lits = explain_pair cc lits b rb in
raise_conflict cc acts (List.rev_map Lit.neg lits)
let th = ref false in
let lits = explain_decompose cc ~th [] e_ab in
let lits = explain_pair cc ~th lits a ra in
let lits = explain_pair cc ~th lits b rb in
raise_conflict cc ~th:!th acts (List.rev_map Lit.neg lits)
);
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always
@ -727,8 +733,9 @@ module Make (A: CC_ARG)
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
(* explanation for [t1 =e= t2 = r2] *)
let half_expl = lazy (
let lits = explain_decompose cc [] e_12 in
explain_pair cc lits r2 t2
let th = ref false in
let lits = explain_decompose cc ~th [] e_12 in
th, explain_pair cc ~th lits r2 t2
) in
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
contains at least one lit *)
@ -745,7 +752,10 @@ module Make (A: CC_ARG)
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *)
let reason =
let e = lazy (explain_pair cc (Lazy.force half_expl) u1 t1) in
let e = lazy (
let lazy (th, acc) = half_expl in
explain_pair cc ~th acc u1 t1
) in
fun () -> Lazy.force e
in
List.iter (fun f -> f cc lit reason) cc.on_propagate;
@ -808,9 +818,10 @@ module Make (A: CC_ARG)
let raise_conflict_from_expl cc (acts:actions) expl =
Log.debugf 5
(fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
let lits = explain_decompose cc [] expl in
let th = ref true in
let lits = explain_decompose cc ~th [] expl in
let lits = List.rev_map Lit.neg lits in
raise_conflict cc acts lits
raise_conflict cc ~th:!th acts lits
let merge cc n1 n2 expl =
Log.debugf 5

View file

@ -192,6 +192,7 @@ module type CC_S = sig
val mk_merge_t : term -> term -> t
val mk_lit : lit -> t
val mk_list : t list -> t
val mk_theory : t -> t (* TODO: indicate what theory, or even provide a lemma *)
end
type node = N.t
@ -216,7 +217,7 @@ module type CC_S = sig
type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> lit list -> unit
type ev_on_conflict = t -> th:bool -> lit list -> unit
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
val create :
@ -445,7 +446,7 @@ module type SOLVER_INTERNAL = sig
(** Callback to add data on terms when they are added to the congruence
closure *)
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit
val on_cc_conflict : t -> (CC.t -> th:bool -> lit list -> unit) -> unit
(** Callback called on every CC conflict *)
val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit

View file

@ -71,9 +71,11 @@ module Check_cc = struct
~create_and_setup:(fun si ->
let n_calls = Stat.mk_int (Solver.Solver_internal.stats si) "check-cc.call" in
Solver.Solver_internal.on_cc_conflict si
(fun c ->
Stat.incr n_calls;
check_conflict si c))
(fun cc ~th c ->
if not th then (
Stat.incr n_calls;
check_conflict si cc c
)))
()
end

View file

@ -169,7 +169,7 @@ module Make(A : ARG) : S with module A = A = struct
(* build full explanation of why the constructor terms are equal *)
(* TODO: have a sort of lemma (injectivity) here to justify this in the proof *)
let expl =
Expl.mk_list [
Expl.mk_theory @@ Expl.mk_list [
e_n1_n2;
Expl.mk_merge n1 c1.c_n;
Expl.mk_merge n2 c2.c_n;
@ -331,7 +331,7 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 5
(fun k->k "(@[%s.on-new-term.is-a.reduce@ :t %a@ :to %B@ :n %a@ :sub-cstor %a@])"
name T.pp t is_true N.pp n Monoid_cstor.pp cstor);
SI.CC.merge cc n (SI.CC.n_bool cc is_true) (Expl.mk_merge n_u repr_u)
SI.CC.merge cc n (SI.CC.n_bool cc is_true) (Expl.mk_theory @@ Expl.mk_merge n_u repr_u)
end
| T_select (c_t, i, u) ->
let n_u = SI.CC.add_term cc u in
@ -344,7 +344,7 @@ module Make(A : ARG) : S with module A = A = struct
assert (i < IArray.length cstor.c_args);
let u_i = IArray.get cstor.c_args i in
let n_u_i = SI.CC.add_term cc u_i in
SI.CC.merge cc n n_u_i (Expl.mk_merge n_u repr_u)
SI.CC.merge cc n n_u_i (Expl.mk_theory @@ Expl.mk_merge n_u repr_u)
| Some _ -> ()
| None ->
N_tbl.add self.to_decide repr_u (); (* needs to be decided *)
@ -364,7 +364,7 @@ module Make(A : ARG) : S with module A = A = struct
name Monoid_parents.pp_is_a is_a2 is_true N.pp n1 N.pp n2 Monoid_cstor.pp c1);
SI.CC.merge cc is_a2.is_a_n (SI.CC.n_bool cc is_true)
Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2;
mk_merge_t (N.term n2) is_a2.is_a_arg])
mk_merge_t (N.term n2) is_a2.is_a_arg] |> mk_theory)
in
let merge_select n1 (c1:Monoid_cstor.t) n2 (sel2:Monoid_parents.select) =
if A.Cstor.equal c1.c_cstor sel2.sel_cstor then (
@ -376,7 +376,7 @@ module Make(A : ARG) : S with module A = A = struct
let n_u_i = SI.CC.add_term cc u_i in
SI.CC.merge cc sel2.sel_n n_u_i
Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2;
mk_merge_t (N.term n2) sel2.sel_arg]);
mk_merge_t (N.term n2) sel2.sel_arg] |> mk_theory);
)
in
let merge_c_p n1 n2 =
@ -409,7 +409,7 @@ module Make(A : ARG) : S with module A = A = struct
| Current of parent_uplink option
let pp_st out st =
Fmt.fprintf out "(@[st :cstor %a@ :flag %s@])"
Fmt.fprintf out "(@[acycl.st :cstor %a@ :flag %s@])"
Monoid_cstor.pp st.cstor
(match st.flag with
| New -> "new" | Done -> "done"
@ -454,7 +454,7 @@ module Make(A : ARG) : S with module A = A = struct
| None -> c0
| Some parent -> mk_path c0 n parent
in
SI.CC.raise_conflict_from_expl cc acts (Expl.mk_list c)
SI.CC.raise_conflict_from_expl cc acts (Expl.mk_list c |> Expl.mk_theory)
(* traverse constructor arguments *)
and traverse_sub n st: unit =
IArray.iter
@ -468,6 +468,10 @@ module Make(A : ARG) : S with module A = A = struct
st.cstor.Monoid_cstor.c_args;
in
begin
(* TODO: instead, create whole graph here (repr -> cstor*repr list)
and then just look for cycles in the graph using DFS.
Be sure to annotate edges with all info for conflicts, so that the
full conflict is just the cycle itself. *)
(* populate tbl with [repr->cstor] *)
ST_cstors.iter_all self.cstors
(fun (n, cstor) ->
@ -495,7 +499,7 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 50
(fun k->k"(@[%s.assign-is-a@ :lhs %a@ :rhs %a@ :lit %a@])"
name T.pp u T.pp rhs SI.Lit.pp lit);
SI.cc_merge_t solver acts u rhs (Expl.mk_lit lit)
SI.cc_merge_t solver acts u rhs (Expl.mk_theory @@ Expl.mk_lit lit)
| _ -> ()
in
Iter.iter check_lit trail