feat(cc): flag some explanations as being theory-induced

This commit is contained in:
Simon Cruanes 2020-01-17 18:48:17 -06:00
parent 5bcecfd68c
commit e21dea4780
4 changed files with 59 additions and 41 deletions

View file

@ -94,6 +94,7 @@ module Make (A: CC_ARG)
| E_merge_t of term * term | E_merge_t of term * term
| E_congruence of node * node (* caused by normal congruence *) | E_congruence of node * node (* caused by normal congruence *)
| E_and of explanation * explanation | E_and of explanation * explanation
| E_theory of explanation
type repr = node type repr = node
@ -166,6 +167,7 @@ module Make (A: CC_ARG)
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
| E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" Term.pp a Term.pp b | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" Term.pp a Term.pp b
| E_theory e -> Fmt.fprintf out "(@[th@ %a@])" pp e
| E_and (a,b) -> | E_and (a,b) ->
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
@ -174,6 +176,7 @@ module Make (A: CC_ARG)
let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b)
let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b) let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b)
let[@inline] mk_lit l : t = E_lit l let[@inline] mk_lit l : t = E_lit l
let mk_theory e = E_theory e
let rec mk_list l = let rec mk_list l =
match l with match l with
@ -275,7 +278,7 @@ module Make (A: CC_ARG)
and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> lit list -> unit and ev_on_conflict = t -> th:bool -> lit list -> unit
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
let[@inline] size_ (r:repr) = r.n_size let[@inline] size_ (r:repr) = r.n_size
@ -368,11 +371,11 @@ module Make (A: CC_ARG)
n.n_expl <- FL_none; n.n_expl <- FL_none;
end end
let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ = let raise_conflict (cc:t) ~th (acts:actions) (e:lit list) : _ =
(* clear tasks queue *) (* clear tasks queue *)
Vec.clear cc.pending; Vec.clear cc.pending;
Vec.clear cc.combine; Vec.clear cc.combine;
List.iter (fun f -> f cc e) cc.on_conflict; List.iter (fun f -> f cc ~th e) cc.on_conflict;
Stat.incr cc.count_conflict; Stat.incr cc.count_conflict;
Actions.raise_conflict acts e P.default Actions.raise_conflict acts e P.default
@ -429,7 +432,7 @@ module Make (A: CC_ARG)
n n
(* decompose explanation [e] into a list of literals added to [acc] *) (* decompose explanation [e] into a list of literals added to [acc] *)
let rec explain_decompose cc (acc:lit list) (e:explanation) : _ list = let rec explain_decompose cc ~th (acc:lit list) (e:explanation) : _ list =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
match e with match e with
| E_reduction -> acc | E_reduction -> acc
@ -438,49 +441,51 @@ module Make (A: CC_ARG)
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2); assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2); assert (List.length a1 = List.length a2);
List.fold_left2 (explain_pair cc) acc a1 a2 List.fold_left2 (explain_pair cc ~th) acc a1 a2
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2); assert (List.length a1 = List.length a2);
let acc = explain_pair cc acc f1 f2 in let acc = explain_pair cc ~th acc f1 f2 in
List.fold_left2 (explain_pair cc) acc a1 a2 List.fold_left2 (explain_pair cc ~th) acc a1 a2
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
let acc = explain_pair cc acc a1 a2 in let acc = explain_pair cc ~th acc a1 a2 in
let acc = explain_pair cc acc b1 b2 in let acc = explain_pair cc ~th acc b1 b2 in
explain_pair cc acc c1 c2 explain_pair cc ~th acc c1 c2
| _ -> | _ ->
assert false assert false
end end
| E_lit lit -> lit :: acc | E_lit lit -> lit :: acc
| E_merge (a,b) -> explain_pair cc acc a b | E_theory e' ->
th := true; explain_decompose cc ~th acc e'
| E_merge (a,b) -> explain_pair cc ~th acc a b
| E_merge_t (a,b) -> | E_merge_t (a,b) ->
(* find nodes for [a] and [b] on the fly *) (* find nodes for [a] and [b] on the fly *)
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
| a, b -> explain_pair cc acc a b | a, b -> explain_pair cc ~th acc a b
| exception Not_found -> | exception Not_found ->
Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b
end end
| E_and (a,b) -> | E_and (a,b) ->
let acc = explain_decompose cc acc a in let acc = explain_decompose cc ~th acc a in
explain_decompose cc acc b explain_decompose cc ~th acc b
and explain_pair (cc:t) (acc:lit list) (a:node) (b:node) : _ list = and explain_pair (cc:t) ~th (acc:lit list) (a:node) (b:node) : _ list =
Log.debugf 5 Log.debugf 5
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
assert (N.equal (find_ a) (find_ b)); assert (N.equal (find_ a) (find_ b));
let ancestor = find_common_ancestor cc a b in let ancestor = find_common_ancestor cc a b in
let acc = explain_along_path cc acc a ancestor in let acc = explain_along_path cc ~th acc a ancestor in
explain_along_path cc acc b ancestor explain_along_path cc ~th acc b ancestor
(* explain why [a = parent_a], where [a -> ... -> target] in the (* explain why [a = parent_a], where [a -> ... -> target] in the
proof forest *) proof forest *)
and explain_along_path cc acc (a:node) (target:node) : _ list = and explain_along_path cc ~th acc (a:node) (target:node) : _ list =
let rec aux acc n = let rec aux acc n =
if n == target then acc if n == target then acc
else ( else (
match n.n_expl with match n.n_expl with
| FL_none -> assert false | FL_none -> assert false
| FL_some {next=next_n; expl=expl} -> | FL_some {next=next_n; expl=expl} ->
let acc = explain_decompose cc acc expl in let acc = explain_decompose cc ~th acc expl in
(* now prove [next_n = target] *) (* now prove [next_n = target] *)
aux acc next_n aux acc next_n
) )
@ -631,10 +636,11 @@ module Make (A: CC_ARG)
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ \ (fun k->k "(@[<hv>cc.merge.true_false_conflict@ \
@[:r1 %a@ :t1 %a@]@ @[:r2 %a@ :t2 %a@]@ :e_ab %a@])" @[:r1 %a@ :t1 %a@]@ @[:r2 %a@ :t2 %a@]@ :e_ab %a@])"
N.pp ra N.pp a N.pp rb N.pp b Expl.pp e_ab); N.pp ra N.pp a N.pp rb N.pp b Expl.pp e_ab);
let lits = explain_decompose cc [] e_ab in let th = ref false in
let lits = explain_pair cc lits a ra in let lits = explain_decompose cc ~th [] e_ab in
let lits = explain_pair cc lits b rb in let lits = explain_pair cc ~th lits a ra in
raise_conflict cc acts (List.rev_map Lit.neg lits) let lits = explain_pair cc ~th lits b rb in
raise_conflict cc ~th:!th acts (List.rev_map Lit.neg lits)
); );
(* We will merge [r_from] into [r_into]. (* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always we try to ensure that [size ra <= size rb] in general, but always
@ -727,8 +733,9 @@ module Make (A: CC_ARG)
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
(* explanation for [t1 =e= t2 = r2] *) (* explanation for [t1 =e= t2 = r2] *)
let half_expl = lazy ( let half_expl = lazy (
let lits = explain_decompose cc [] e_12 in let th = ref false in
explain_pair cc lits r2 t2 let lits = explain_decompose cc ~th [] e_12 in
th, explain_pair cc ~th lits r2 t2
) in ) in
(* TODO: flag per class, `or`-ed on merge, to indicate if the class (* TODO: flag per class, `or`-ed on merge, to indicate if the class
contains at least one lit *) contains at least one lit *)
@ -745,7 +752,10 @@ module Make (A: CC_ARG)
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *) (* complete explanation with the [u1=t1] chunk *)
let reason = let reason =
let e = lazy (explain_pair cc (Lazy.force half_expl) u1 t1) in let e = lazy (
let lazy (th, acc) = half_expl in
explain_pair cc ~th acc u1 t1
) in
fun () -> Lazy.force e fun () -> Lazy.force e
in in
List.iter (fun f -> f cc lit reason) cc.on_propagate; List.iter (fun f -> f cc lit reason) cc.on_propagate;
@ -808,9 +818,10 @@ module Make (A: CC_ARG)
let raise_conflict_from_expl cc (acts:actions) expl = let raise_conflict_from_expl cc (acts:actions) expl =
Log.debugf 5 Log.debugf 5
(fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
let lits = explain_decompose cc [] expl in let th = ref true in
let lits = explain_decompose cc ~th [] expl in
let lits = List.rev_map Lit.neg lits in let lits = List.rev_map Lit.neg lits in
raise_conflict cc acts lits raise_conflict cc ~th:!th acts lits
let merge cc n1 n2 expl = let merge cc n1 n2 expl =
Log.debugf 5 Log.debugf 5

View file

@ -192,6 +192,7 @@ module type CC_S = sig
val mk_merge_t : term -> term -> t val mk_merge_t : term -> term -> t
val mk_lit : lit -> t val mk_lit : lit -> t
val mk_list : t list -> t val mk_list : t list -> t
val mk_theory : t -> t (* TODO: indicate what theory, or even provide a lemma *)
end end
type node = N.t type node = N.t
@ -216,7 +217,7 @@ module type CC_S = sig
type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> lit list -> unit type ev_on_conflict = t -> th:bool -> lit list -> unit
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
val create : val create :
@ -445,7 +446,7 @@ module type SOLVER_INTERNAL = sig
(** Callback to add data on terms when they are added to the congruence (** Callback to add data on terms when they are added to the congruence
closure *) closure *)
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit val on_cc_conflict : t -> (CC.t -> th:bool -> lit list -> unit) -> unit
(** Callback called on every CC conflict *) (** Callback called on every CC conflict *)
val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit

View file

@ -71,9 +71,11 @@ module Check_cc = struct
~create_and_setup:(fun si -> ~create_and_setup:(fun si ->
let n_calls = Stat.mk_int (Solver.Solver_internal.stats si) "check-cc.call" in let n_calls = Stat.mk_int (Solver.Solver_internal.stats si) "check-cc.call" in
Solver.Solver_internal.on_cc_conflict si Solver.Solver_internal.on_cc_conflict si
(fun c -> (fun cc ~th c ->
Stat.incr n_calls; if not th then (
check_conflict si c)) Stat.incr n_calls;
check_conflict si cc c
)))
() ()
end end

View file

@ -169,7 +169,7 @@ module Make(A : ARG) : S with module A = A = struct
(* build full explanation of why the constructor terms are equal *) (* build full explanation of why the constructor terms are equal *)
(* TODO: have a sort of lemma (injectivity) here to justify this in the proof *) (* TODO: have a sort of lemma (injectivity) here to justify this in the proof *)
let expl = let expl =
Expl.mk_list [ Expl.mk_theory @@ Expl.mk_list [
e_n1_n2; e_n1_n2;
Expl.mk_merge n1 c1.c_n; Expl.mk_merge n1 c1.c_n;
Expl.mk_merge n2 c2.c_n; Expl.mk_merge n2 c2.c_n;
@ -331,7 +331,7 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 5 Log.debugf 5
(fun k->k "(@[%s.on-new-term.is-a.reduce@ :t %a@ :to %B@ :n %a@ :sub-cstor %a@])" (fun k->k "(@[%s.on-new-term.is-a.reduce@ :t %a@ :to %B@ :n %a@ :sub-cstor %a@])"
name T.pp t is_true N.pp n Monoid_cstor.pp cstor); name T.pp t is_true N.pp n Monoid_cstor.pp cstor);
SI.CC.merge cc n (SI.CC.n_bool cc is_true) (Expl.mk_merge n_u repr_u) SI.CC.merge cc n (SI.CC.n_bool cc is_true) (Expl.mk_theory @@ Expl.mk_merge n_u repr_u)
end end
| T_select (c_t, i, u) -> | T_select (c_t, i, u) ->
let n_u = SI.CC.add_term cc u in let n_u = SI.CC.add_term cc u in
@ -344,7 +344,7 @@ module Make(A : ARG) : S with module A = A = struct
assert (i < IArray.length cstor.c_args); assert (i < IArray.length cstor.c_args);
let u_i = IArray.get cstor.c_args i in let u_i = IArray.get cstor.c_args i in
let n_u_i = SI.CC.add_term cc u_i in let n_u_i = SI.CC.add_term cc u_i in
SI.CC.merge cc n n_u_i (Expl.mk_merge n_u repr_u) SI.CC.merge cc n n_u_i (Expl.mk_theory @@ Expl.mk_merge n_u repr_u)
| Some _ -> () | Some _ -> ()
| None -> | None ->
N_tbl.add self.to_decide repr_u (); (* needs to be decided *) N_tbl.add self.to_decide repr_u (); (* needs to be decided *)
@ -364,7 +364,7 @@ module Make(A : ARG) : S with module A = A = struct
name Monoid_parents.pp_is_a is_a2 is_true N.pp n1 N.pp n2 Monoid_cstor.pp c1); name Monoid_parents.pp_is_a is_a2 is_true N.pp n1 N.pp n2 Monoid_cstor.pp c1);
SI.CC.merge cc is_a2.is_a_n (SI.CC.n_bool cc is_true) SI.CC.merge cc is_a2.is_a_n (SI.CC.n_bool cc is_true)
Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2; Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2;
mk_merge_t (N.term n2) is_a2.is_a_arg]) mk_merge_t (N.term n2) is_a2.is_a_arg] |> mk_theory)
in in
let merge_select n1 (c1:Monoid_cstor.t) n2 (sel2:Monoid_parents.select) = let merge_select n1 (c1:Monoid_cstor.t) n2 (sel2:Monoid_parents.select) =
if A.Cstor.equal c1.c_cstor sel2.sel_cstor then ( if A.Cstor.equal c1.c_cstor sel2.sel_cstor then (
@ -376,7 +376,7 @@ module Make(A : ARG) : S with module A = A = struct
let n_u_i = SI.CC.add_term cc u_i in let n_u_i = SI.CC.add_term cc u_i in
SI.CC.merge cc sel2.sel_n n_u_i SI.CC.merge cc sel2.sel_n n_u_i
Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2; Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2;
mk_merge_t (N.term n2) sel2.sel_arg]); mk_merge_t (N.term n2) sel2.sel_arg] |> mk_theory);
) )
in in
let merge_c_p n1 n2 = let merge_c_p n1 n2 =
@ -409,7 +409,7 @@ module Make(A : ARG) : S with module A = A = struct
| Current of parent_uplink option | Current of parent_uplink option
let pp_st out st = let pp_st out st =
Fmt.fprintf out "(@[st :cstor %a@ :flag %s@])" Fmt.fprintf out "(@[acycl.st :cstor %a@ :flag %s@])"
Monoid_cstor.pp st.cstor Monoid_cstor.pp st.cstor
(match st.flag with (match st.flag with
| New -> "new" | Done -> "done" | New -> "new" | Done -> "done"
@ -454,7 +454,7 @@ module Make(A : ARG) : S with module A = A = struct
| None -> c0 | None -> c0
| Some parent -> mk_path c0 n parent | Some parent -> mk_path c0 n parent
in in
SI.CC.raise_conflict_from_expl cc acts (Expl.mk_list c) SI.CC.raise_conflict_from_expl cc acts (Expl.mk_list c |> Expl.mk_theory)
(* traverse constructor arguments *) (* traverse constructor arguments *)
and traverse_sub n st: unit = and traverse_sub n st: unit =
IArray.iter IArray.iter
@ -468,6 +468,10 @@ module Make(A : ARG) : S with module A = A = struct
st.cstor.Monoid_cstor.c_args; st.cstor.Monoid_cstor.c_args;
in in
begin begin
(* TODO: instead, create whole graph here (repr -> cstor*repr list)
and then just look for cycles in the graph using DFS.
Be sure to annotate edges with all info for conflicts, so that the
full conflict is just the cycle itself. *)
(* populate tbl with [repr->cstor] *) (* populate tbl with [repr->cstor] *)
ST_cstors.iter_all self.cstors ST_cstors.iter_all self.cstors
(fun (n, cstor) -> (fun (n, cstor) ->
@ -495,7 +499,7 @@ module Make(A : ARG) : S with module A = A = struct
Log.debugf 50 Log.debugf 50
(fun k->k"(@[%s.assign-is-a@ :lhs %a@ :rhs %a@ :lit %a@])" (fun k->k"(@[%s.assign-is-a@ :lhs %a@ :rhs %a@ :lit %a@])"
name T.pp u T.pp rhs SI.Lit.pp lit); name T.pp u T.pp rhs SI.Lit.pp lit);
SI.cc_merge_t solver acts u rhs (Expl.mk_lit lit) SI.cc_merge_t solver acts u rhs (Expl.mk_theory @@ Expl.mk_lit lit)
| _ -> () | _ -> ()
in in
Iter.iter check_lit trail Iter.iter check_lit trail