diff --git a/src/cc/CC.ml b/src/cc/CC.ml index 64786b2f..56718a49 100644 --- a/src/cc/CC.ml +++ b/src/cc/CC.ml @@ -302,41 +302,48 @@ module Expl_state = struct let { lits = o_lits; th_lemmas = o_lemmas } = other in self.lits <- List.rev_append o_lits self.lits; self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas; - () + let emit_cc_eq_proof (self : t) (_tracer : Proof.Tracer.t) : Proof.Pterm.t = + let neg_lits = List.rev_map Lit.neg self.lits in + Proof.Pterm.apply_rule ~lits:neg_lits "core.cc-eq-proof" + (* proof of [\/_i ¬lits[i]] *) let proof_of_th_lemmas (self : t) (tracer : Proof.Tracer.t) : Proof.Pterm.delayed = Proof.Pterm.delay @@ fun () -> - (* Emit each sub-proof immediately; use its offset (Step.id) as a P_ref. *) - let bind (t : Proof.Pterm.t) : Proof.Step.id = - Proof.Tracer.add_step tracer (Proof.Pterm.delay (fun () -> t)) - in - - let p_lits1 = List.rev_map Lit.neg self.lits in - let p_lits2 = - self.th_lemmas |> List.rev_map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u) - in - let p_cc = Proof.Core_rules.lemma_cc (List.rev_append p_lits1 p_lits2) in - let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) = - let pr_th = pr_th () in - let pr_th = - List.fold_left - (fun pr_th (lit_i, hyps_i) -> - let lemma_i = - bind - @@ Proof.Core_rules.lemma_cc (lit_i :: List.rev_map Lit.neg hyps_i) - in - Proof.Core_rules.proof_res ~pivot:(Lit.term lit_i) lemma_i - (bind pr_th)) - pr_th sub_proofs + if self.th_lemmas = [] then + emit_cc_eq_proof self tracer + else ( + let bind (t : Proof.Pterm.t) : Proof.Step.id = + Proof.Tracer.add_step tracer (Proof.Pterm.delay (fun () -> t)) in - Proof.Core_rules.proof_res ~pivot:(Lit.term lit_t_u) (bind pr_th) - (bind pr) - in - let body = List.fold_left resolve_with_th_proof p_cc self.th_lemmas in - body + + let p_lits1 = List.rev_map Lit.neg self.lits in + let p_lits2 = + self.th_lemmas |> List.rev_map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u) + in + let p_cc = Proof.Core_rules.lemma_cc (List.rev_append p_lits1 p_lits2) in + let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) = + let pr_th = pr_th () in + let pr_th = + List.fold_left + (fun pr_th (lit_i, hyps_i) -> + let lemma_i = + bind + @@ Proof.Core_rules.lemma_cc + (lit_i :: List.rev_map Lit.neg hyps_i) + in + Proof.Core_rules.proof_res ~pivot:(Lit.term lit_i) lemma_i + (bind pr_th)) + pr_th sub_proofs + in + Proof.Core_rules.proof_res ~pivot:(Lit.term lit_t_u) (bind pr_th) + (bind pr) + in + let body = List.fold_left resolve_with_th_proof p_cc self.th_lemmas in + body + ) let to_resolved_expl (self : t) (tracer : Proof.Tracer.t) : Resolved_expl.t = let { lits; th_lemmas = _ } = self in diff --git a/src/proof_minidag/proof_encoder.ml b/src/proof_minidag/proof_encoder.ml index 69950aeb..e2f1a110 100644 --- a/src/proof_minidag/proof_encoder.ml +++ b/src/proof_minidag/proof_encoder.ml @@ -34,7 +34,7 @@ let emit_seq self ~hyps ~concls = (** Emit [p.hyp] with the given conclusion offsets and no hypotheses. *) let emit_hyp self concls = let seq = emit_seq self ~hyps:[] ~concls in - nd self "p.hyp" (fun e -> E.ref e seq) + nd self "hol.hypothesis" (fun e -> E.ref e seq) (** Emit [sk.sorry] with a descriptive message. *) let emit_sorry self msg = nd self "sk.sorry" (fun e -> E.string e msg) @@ -60,10 +60,158 @@ let emit_sat_rup self hyp_sids = let dag_offs = List.map (step_off self) hyp_sids in nd self "sk.sat_rup" (fun e -> List.iter (E.ref e) dag_offs) -(** CC conflict: oracle step referencing all conflicting lits. *) -let emit_cc_conflict self lits = - let lit_offs = List.map (encode_lit' self) lits in - nd self "sk.cc_conflict" (fun e -> List.iter (E.ref e) lit_offs) +(** Extract equality pairs [(a, b)] from positive equality literals. A positive + literal [a = b] (encoded as [(= a b)]) yields [(a, b)]. A negative literal + [not(a = b)] is skipped. *) +let eq_pairs_of_lits (_tst : Term.store) lits = + List.filter_map + (fun lit -> + if Lit.sign lit then ( + let t = Lit.term lit in + match Term.view t with + | Term.E_app (eq, a) -> + (match Term.view eq with + | Term.E_app (_, b) -> Some (a, b) + | _ -> None) + | _ -> None + ) else + None) + lits + +(** Compute congruence closure steps from a set of equalities. Uses a union-find + over [Term.t] (by identity/physical equality). Returns a list of [(t, u)] + pairs for [eq.c] steps in an order that respects dependencies (congruence + steps use terms whose sub-terms are already equal via prior unions or + congruences). *) +let compute_cc_steps eq_pairs = + let uf = Term.Tbl.create 16 in + let rec find t = + match Term.Tbl.find_opt uf t with + | None -> t + | Some r -> + if Term.equal r t then + t + else ( + let r = find r in + Term.Tbl.replace uf t r; + r + ) + in + let union a b = + let a = find a and b = find b in + if not (Term.equal a b) then Term.Tbl.replace uf a b + in + (* Collect all application terms reachable from both sides of equalities *) + let all_terms = + let seen = Term.Tbl.create 16 in + let acc = ref [] in + let rec collect t = + if Term.Tbl.mem seen t then + () + else ( + Term.Tbl.add seen t (); + acc := t :: !acc; + match Term.view t with + | Term.E_app (f, x) -> + collect f; + collect x + | _ -> () + ) + in + List.iter + (fun (a, b) -> + collect a; + collect b) + eq_pairs; + !acc + in + let app_terms = + List.filter_map + (fun t -> + match Term.view t with + | Term.E_app (f, x) -> Some (t, f, x) + | _ -> None) + all_terms + in + (* Step 1: do all initial unions *) + List.iter (fun (a, b) -> union a b) eq_pairs; + (* Step 2: fixed-point congruence detection *) + let module PairKey = struct + type t = Term.t * Term.t + + let equal (a1, b1) (a2, b2) = Term.equal a1 a2 && Term.equal b1 b2 + let hash (a, b) = Hash.combine2 (Term.hash a) (Term.hash b) + end in + let module PairTbl = CCHashtbl.Make (PairKey) in + let congr_pairs = ref [] in + let changed = ref true in + while !changed do + changed := false; + let app_sig = PairTbl.create 16 in + List.iter + (fun (t, f, x) -> + let key = find f, find x in + match PairTbl.get app_sig key with + | None -> PairTbl.add app_sig key t + | Some other -> + if not (Term.equal (find other) (find t)) then ( + congr_pairs := (other, t) :: !congr_pairs; + union other t; + changed := true + )) + app_terms + done; + eq_pairs, List.rev !congr_pairs + +(** Emit a [p.eq] proof from negated conflict literals. Extracts equalities, + computes congruence closure, emits [eq.u]/[eq.c] steps. *) +let emit_cc_eq_proof self neg_lits = + let true_off = encode_term' self (Term.true_ self.tst) in + let false_off = encode_term' self (Term.false_ self.tst) in + let lit_offs = List.map (encode_lit' self) neg_lits in + (* Build dag: one hypothesis per negated literal *) + let dag_offs = + Array.of_list + (List.map + (fun lit_off -> + let seq = emit_seq self ~hyps:[] ~concls:[ lit_off ] in + nd self "hol.hypothesis" (fun e -> E.ref e seq)) + lit_offs) + in + (* Extract equalities and compute congruences *) + let eq_pairs, congr_pairs = + compute_cc_steps (eq_pairs_of_lits self.tst neg_lits) + in + (* Emit eq.u for each equality, using dag[i] for the i-th equality literal *) + let eq_step_offs = ref [] in + List.iteri + (fun i (a, b) -> + let eq_lit = encode_term' self (Term.eq self.tst a b) in + let step = + nd self "eq.u" (fun e -> + E.ref e dag_offs.(i); + E.ref e eq_lit) + in + eq_step_offs := !eq_step_offs @ [ step ]) + eq_pairs; + (* Emit eq.c for each congruence pair *) + List.iter + (fun (t, u) -> + let t_off = encode_term' self t in + let u_off = encode_term' self u in + let step = + nd self "eq.c" (fun e -> + E.ref e t_off; + E.ref e u_off) + in + eq_step_offs := !eq_step_offs @ [ step ]) + congr_pairs; + nd self "p.eq" (fun e -> + E.ref e true_off; + E.ref e false_off; + List.iter (E.ref e) !eq_step_offs; + E.null e; + Array.iter (E.ref e) dag_offs) (** Boolean axiom: any [bool.*] rule name. *) let emit_bool_ax self name term_args = @@ -111,7 +259,10 @@ let rec encode_rule self (r : Pterm.rule_apply) : E.offset = E.ref e o1; E.ref e o2) | _ -> emit_sorry self "core.p1: bad args") - | "core.lemma-cc" -> emit_cc_conflict self lit_args + | "core.lemma-cc" -> + let lit_offs = List.map (encode_lit' self) lit_args in + nd self "sk.cc_conflict" (fun e -> List.iter (E.ref e) lit_offs) + | "core.cc-eq-proof" -> emit_cc_eq_proof self lit_args | "core.define-term" -> (match term_args with | [ c; rhs ] -> emit_hyp self [ encode_term' self (Term.eq self.tst c rhs) ]