diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index a64b2454..81937507 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -94,6 +94,7 @@ module Make (A: CC_ARG) | E_merge_t of term * term | E_congruence of node * node (* caused by normal congruence *) | E_and of explanation * explanation + | E_theory of explanation type repr = node @@ -166,6 +167,7 @@ module Make (A: CC_ARG) | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" Term.pp a Term.pp b + | E_theory e -> Fmt.fprintf out "(@[th@ %a@])" pp e | E_and (a,b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b @@ -174,6 +176,7 @@ module Make (A: CC_ARG) let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) let[@inline] mk_merge_t a b : t = if Term.equal a b then mk_reduction else E_merge_t (a,b) let[@inline] mk_lit l : t = E_lit l + let mk_theory e = E_theory e let rec mk_list l = match l with @@ -275,7 +278,7 @@ module Make (A: CC_ARG) and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit and ev_on_new_term = t -> N.t -> term -> unit - and ev_on_conflict = t -> lit list -> unit + and ev_on_conflict = t -> th:bool -> lit list -> unit and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit let[@inline] size_ (r:repr) = r.n_size @@ -368,11 +371,11 @@ module Make (A: CC_ARG) n.n_expl <- FL_none; end - let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ = + let raise_conflict (cc:t) ~th (acts:actions) (e:lit list) : _ = (* clear tasks queue *) Vec.clear cc.pending; Vec.clear cc.combine; - List.iter (fun f -> f cc e) cc.on_conflict; + List.iter (fun f -> f cc ~th e) cc.on_conflict; Stat.incr cc.count_conflict; Actions.raise_conflict acts e P.default @@ -429,7 +432,7 @@ module Make (A: CC_ARG) n (* decompose explanation [e] into a list of literals added to [acc] *) - let rec explain_decompose cc (acc:lit list) (e:explanation) : _ list = + let rec explain_decompose cc ~th (acc:lit list) (e:explanation) : _ list = Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); match e with | E_reduction -> acc @@ -438,49 +441,51 @@ module Make (A: CC_ARG) | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> assert (Fun.equal f1 f2); assert (List.length a1 = List.length a2); - List.fold_left2 (explain_pair cc) acc a1 a2 + List.fold_left2 (explain_pair cc ~th) acc a1 a2 | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> assert (List.length a1 = List.length a2); - let acc = explain_pair cc acc f1 f2 in - List.fold_left2 (explain_pair cc) acc a1 a2 + let acc = explain_pair cc ~th acc f1 f2 in + List.fold_left2 (explain_pair cc ~th) acc a1 a2 | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> - let acc = explain_pair cc acc a1 a2 in - let acc = explain_pair cc acc b1 b2 in - explain_pair cc acc c1 c2 + let acc = explain_pair cc ~th acc a1 a2 in + let acc = explain_pair cc ~th acc b1 b2 in + explain_pair cc ~th acc c1 c2 | _ -> assert false end | E_lit lit -> lit :: acc - | E_merge (a,b) -> explain_pair cc acc a b + | E_theory e' -> + th := true; explain_decompose cc ~th acc e' + | E_merge (a,b) -> explain_pair cc ~th acc a b | E_merge_t (a,b) -> (* find nodes for [a] and [b] on the fly *) begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with - | a, b -> explain_pair cc acc a b + | a, b -> explain_pair cc ~th acc a b | exception Not_found -> Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b end | E_and (a,b) -> - let acc = explain_decompose cc acc a in - explain_decompose cc acc b + let acc = explain_decompose cc ~th acc a in + explain_decompose cc ~th acc b - and explain_pair (cc:t) (acc:lit list) (a:node) (b:node) : _ list = + and explain_pair (cc:t) ~th (acc:lit list) (a:node) (b:node) : _ list = Log.debugf 5 (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); assert (N.equal (find_ a) (find_ b)); let ancestor = find_common_ancestor cc a b in - let acc = explain_along_path cc acc a ancestor in - explain_along_path cc acc b ancestor + let acc = explain_along_path cc ~th acc a ancestor in + explain_along_path cc ~th acc b ancestor (* explain why [a = parent_a], where [a -> ... -> target] in the proof forest *) - and explain_along_path cc acc (a:node) (target:node) : _ list = + and explain_along_path cc ~th acc (a:node) (target:node) : _ list = let rec aux acc n = if n == target then acc else ( match n.n_expl with | FL_none -> assert false | FL_some {next=next_n; expl=expl} -> - let acc = explain_decompose cc acc expl in + let acc = explain_decompose cc ~th acc expl in (* now prove [next_n = target] *) aux acc next_n ) @@ -631,10 +636,11 @@ module Make (A: CC_ARG) (fun k->k "(@[cc.merge.true_false_conflict@ \ @[:r1 %a@ :t1 %a@]@ @[:r2 %a@ :t2 %a@]@ :e_ab %a@])" N.pp ra N.pp a N.pp rb N.pp b Expl.pp e_ab); - let lits = explain_decompose cc [] e_ab in - let lits = explain_pair cc lits a ra in - let lits = explain_pair cc lits b rb in - raise_conflict cc acts (List.rev_map Lit.neg lits) + let th = ref false in + let lits = explain_decompose cc ~th [] e_ab in + let lits = explain_pair cc ~th lits a ra in + let lits = explain_pair cc ~th lits b rb in + raise_conflict cc ~th:!th acts (List.rev_map Lit.neg lits) ); (* We will merge [r_from] into [r_into]. we try to ensure that [size ra <= size rb] in general, but always @@ -727,8 +733,9 @@ module Make (A: CC_ARG) and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = (* explanation for [t1 =e= t2 = r2] *) let half_expl = lazy ( - let lits = explain_decompose cc [] e_12 in - explain_pair cc lits r2 t2 + let th = ref false in + let lits = explain_decompose cc ~th [] e_12 in + th, explain_pair cc ~th lits r2 t2 ) in (* TODO: flag per class, `or`-ed on merge, to indicate if the class contains at least one lit *) @@ -745,7 +752,10 @@ module Make (A: CC_ARG) Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); (* complete explanation with the [u1=t1] chunk *) let reason = - let e = lazy (explain_pair cc (Lazy.force half_expl) u1 t1) in + let e = lazy ( + let lazy (th, acc) = half_expl in + explain_pair cc ~th acc u1 t1 + ) in fun () -> Lazy.force e in List.iter (fun f -> f cc lit reason) cc.on_propagate; @@ -808,9 +818,10 @@ module Make (A: CC_ARG) let raise_conflict_from_expl cc (acts:actions) expl = Log.debugf 5 (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); - let lits = explain_decompose cc [] expl in + let th = ref true in + let lits = explain_decompose cc ~th [] expl in let lits = List.rev_map Lit.neg lits in - raise_conflict cc acts lits + raise_conflict cc ~th:!th acts lits let merge cc n1 n2 expl = Log.debugf 5 diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 0f554612..55c4d657 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -192,6 +192,7 @@ module type CC_S = sig val mk_merge_t : term -> term -> t val mk_lit : lit -> t val mk_list : t list -> t + val mk_theory : t -> t (* TODO: indicate what theory, or even provide a lemma *) end type node = N.t @@ -216,7 +217,7 @@ module type CC_S = sig type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit type ev_on_new_term = t -> N.t -> term -> unit - type ev_on_conflict = t -> lit list -> unit + type ev_on_conflict = t -> th:bool -> lit list -> unit type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit val create : @@ -445,7 +446,7 @@ module type SOLVER_INTERNAL = sig (** Callback to add data on terms when they are added to the congruence closure *) - val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit + val on_cc_conflict : t -> (CC.t -> th:bool -> lit list -> unit) -> unit (** Callback called on every CC conflict *) val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index 1a8b9e18..19e2c274 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -71,9 +71,11 @@ module Check_cc = struct ~create_and_setup:(fun si -> let n_calls = Stat.mk_int (Solver.Solver_internal.stats si) "check-cc.call" in Solver.Solver_internal.on_cc_conflict si - (fun c -> - Stat.incr n_calls; - check_conflict si c)) + (fun cc ~th c -> + if not th then ( + Stat.incr n_calls; + check_conflict si cc c + ))) () end diff --git a/src/th-data/Sidekick_th_data.ml b/src/th-data/Sidekick_th_data.ml index e023ed6e..235fc3e4 100644 --- a/src/th-data/Sidekick_th_data.ml +++ b/src/th-data/Sidekick_th_data.ml @@ -169,7 +169,7 @@ module Make(A : ARG) : S with module A = A = struct (* build full explanation of why the constructor terms are equal *) (* TODO: have a sort of lemma (injectivity) here to justify this in the proof *) let expl = - Expl.mk_list [ + Expl.mk_theory @@ Expl.mk_list [ e_n1_n2; Expl.mk_merge n1 c1.c_n; Expl.mk_merge n2 c2.c_n; @@ -331,7 +331,7 @@ module Make(A : ARG) : S with module A = A = struct Log.debugf 5 (fun k->k "(@[%s.on-new-term.is-a.reduce@ :t %a@ :to %B@ :n %a@ :sub-cstor %a@])" name T.pp t is_true N.pp n Monoid_cstor.pp cstor); - SI.CC.merge cc n (SI.CC.n_bool cc is_true) (Expl.mk_merge n_u repr_u) + SI.CC.merge cc n (SI.CC.n_bool cc is_true) (Expl.mk_theory @@ Expl.mk_merge n_u repr_u) end | T_select (c_t, i, u) -> let n_u = SI.CC.add_term cc u in @@ -344,7 +344,7 @@ module Make(A : ARG) : S with module A = A = struct assert (i < IArray.length cstor.c_args); let u_i = IArray.get cstor.c_args i in let n_u_i = SI.CC.add_term cc u_i in - SI.CC.merge cc n n_u_i (Expl.mk_merge n_u repr_u) + SI.CC.merge cc n n_u_i (Expl.mk_theory @@ Expl.mk_merge n_u repr_u) | Some _ -> () | None -> N_tbl.add self.to_decide repr_u (); (* needs to be decided *) @@ -364,7 +364,7 @@ module Make(A : ARG) : S with module A = A = struct name Monoid_parents.pp_is_a is_a2 is_true N.pp n1 N.pp n2 Monoid_cstor.pp c1); SI.CC.merge cc is_a2.is_a_n (SI.CC.n_bool cc is_true) Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2; - mk_merge_t (N.term n2) is_a2.is_a_arg]) + mk_merge_t (N.term n2) is_a2.is_a_arg] |> mk_theory) in let merge_select n1 (c1:Monoid_cstor.t) n2 (sel2:Monoid_parents.select) = if A.Cstor.equal c1.c_cstor sel2.sel_cstor then ( @@ -376,7 +376,7 @@ module Make(A : ARG) : S with module A = A = struct let n_u_i = SI.CC.add_term cc u_i in SI.CC.merge cc sel2.sel_n n_u_i Expl.(mk_list [mk_merge n1 c1.c_n; mk_merge n1 n2; - mk_merge_t (N.term n2) sel2.sel_arg]); + mk_merge_t (N.term n2) sel2.sel_arg] |> mk_theory); ) in let merge_c_p n1 n2 = @@ -409,7 +409,7 @@ module Make(A : ARG) : S with module A = A = struct | Current of parent_uplink option let pp_st out st = - Fmt.fprintf out "(@[st :cstor %a@ :flag %s@])" + Fmt.fprintf out "(@[acycl.st :cstor %a@ :flag %s@])" Monoid_cstor.pp st.cstor (match st.flag with | New -> "new" | Done -> "done" @@ -454,7 +454,7 @@ module Make(A : ARG) : S with module A = A = struct | None -> c0 | Some parent -> mk_path c0 n parent in - SI.CC.raise_conflict_from_expl cc acts (Expl.mk_list c) + SI.CC.raise_conflict_from_expl cc acts (Expl.mk_list c |> Expl.mk_theory) (* traverse constructor arguments *) and traverse_sub n st: unit = IArray.iter @@ -468,6 +468,10 @@ module Make(A : ARG) : S with module A = A = struct st.cstor.Monoid_cstor.c_args; in begin + (* TODO: instead, create whole graph here (repr -> cstor*repr list) + and then just look for cycles in the graph using DFS. + Be sure to annotate edges with all info for conflicts, so that the + full conflict is just the cycle itself. *) (* populate tbl with [repr->cstor] *) ST_cstors.iter_all self.cstors (fun (n, cstor) -> @@ -495,7 +499,7 @@ module Make(A : ARG) : S with module A = A = struct Log.debugf 50 (fun k->k"(@[%s.assign-is-a@ :lhs %a@ :rhs %a@ :lit %a@])" name T.pp u T.pp rhs SI.Lit.pp lit); - SI.cc_merge_t solver acts u rhs (Expl.mk_lit lit) + SI.cc_merge_t solver acts u rhs (Expl.mk_theory @@ Expl.mk_lit lit) | _ -> () in Iter.iter check_lit trail