diff --git a/src/cc/CC.ml b/src/cc/CC.ml index 0f8cfe5b..cc68e9a0 100644 --- a/src/cc/CC.ml +++ b/src/cc/CC.ml @@ -286,11 +286,14 @@ let find_common_ancestor self (a : e_node) (b : e_node) : e_node = module Expl_state = struct type t = { mutable lits: Lit.t list; + proven_eq: Small_uf.t; mutable th_lemmas: (Lit.t * (Lit.t * Lit.t list) list * Proof_term.step_id) list; } - let create () : t = { lits = []; th_lemmas = [] } + let create () : t = + { lits = []; th_lemmas = []; proven_eq = Small_uf.create () } + let[@inline] copy self : t = { self with lits = self.lits } let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits @@ -298,7 +301,7 @@ module Expl_state = struct self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas let merge self other = - let { lits = o_lits; th_lemmas = o_lemmas } = other in + let { lits = o_lits; th_lemmas = o_lemmas; proven_eq = _ } = other in self.lits <- List.rev_append o_lits self.lits; self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas; () @@ -339,7 +342,7 @@ module Expl_state = struct let to_resolved_expl (self : t) : Resolved_expl.t = (* FIXME: package the th lemmas too *) - let { lits; th_lemmas = _ } = self in + let { lits; th_lemmas = _l; proven_eq = _ } = self in let s2 = copy self in let pr proof = proof_of_th_lemmas s2 proof in { Resolved_expl.lits; pr } @@ -396,15 +399,18 @@ and explain_expls self (es : explanation list) : Expl_state.t = List.iter (explain_decompose_expl self st) es; st +(* explain why [a =_E b] *) and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) : unit = - if a != b then ( + if a != b && not (Small_uf.same_class st.proven_eq a.n_term b.n_term) then ( Log.debugf 5 (fun k -> k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b); assert (E_node.equal (find_ a) (find_ b)); let ancestor = find_common_ancestor cc a b in explain_along_path cc st a ancestor; - explain_along_path cc st b ancestor + explain_along_path cc st b ancestor; + (* we now know that [a=b]. *) + Small_uf.merge st.proven_eq a.n_term b.n_term ) (* explain why [a = target], where [a -> ... -> target] in the @@ -415,11 +421,11 @@ and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) : if n != target then ( match n.n_expl with | FL_none -> assert false - | FL_some { next = next_a; expl } -> - (* prove [a = next_n] *) + | FL_some { next = next_n; expl } -> + (* prove [n = next_n] *) explain_decompose_expl self st expl; - (* now prove [next_a = target] *) - aux next_a + (* now prove [next_n = target] *) + aux next_n ) in aux a @@ -510,7 +516,7 @@ let n_is_bool_value (self : t) n : bool = merges. *) let lits_and_proof_of_expl (self : t) (st : Expl_state.t) : Lit.t list * Proof_term.step_id = - let { Expl_state.lits; th_lemmas = _ } = st in + let { Expl_state.lits; th_lemmas = _; proven_eq = _ } = st in let pr = Expl_state.proof_of_th_lemmas st self.proof in lits, pr diff --git a/src/cc/small_uf.ml b/src/cc/small_uf.ml new file mode 100644 index 00000000..a4c59e25 --- /dev/null +++ b/src/cc/small_uf.ml @@ -0,0 +1,60 @@ +open Sidekick_core +module T = Term +module T_tbl = Term.Tbl + +type node = { + mutable n_next: node; (* next in class *) + mutable n_size: int; (* size of class *) +} + +module Node = struct + type t = node + + let[@inline] equal (n1 : t) n2 = n1 == n2 + let[@inline] size (n : t) = n.n_size + let[@inline] is_root n = n == n.n_next + + let[@unroll 2] rec root n = + if n.n_next == n then + n + else ( + let r = root n.n_next in + n.n_next <- r; + r + ) + + let make () : t = + let rec n = { n_size = 1; n_next = n } in + n +end + +type t = { tbl: node T_tbl.t } + +let create () : t = { tbl = T_tbl.create 8 } +let clear (self : t) : unit = T_tbl.clear self.tbl + +let add_ (self : t) (t : T.t) : node = + try T_tbl.find self.tbl t + with Not_found -> + let n = Node.make () in + T_tbl.add self.tbl t n; + n + +let merge (self : t) (t1 : T.t) (t2 : T.t) = + let n1 = add_ self t1 |> Node.root in + let n2 = add_ self t2 |> Node.root in + if n1 != n2 then ( + let n1, n2 = + if Node.size n1 > Node.size n2 then + n1, n2 + else + n2, n1 + in + n2.n_next <- n1; + n1.n_size <- n1.n_size + n2.n_size + ) + +let same_class (self : t) (t1 : T.t) (t2 : T.t) : bool = + let n1 = add_ self t1 |> Node.root in + let n2 = add_ self t2 |> Node.root in + n1 == n2 diff --git a/src/cc/small_uf.mli b/src/cc/small_uf.mli new file mode 100644 index 00000000..d14e7912 --- /dev/null +++ b/src/cc/small_uf.mli @@ -0,0 +1,17 @@ +(** Small union find. + + No backtracking nor explanations. +*) + +open Sidekick_core + +type t +(** An instance of the congruence closure. Mutable *) + +val create : unit -> t + +val clear : t -> unit +(** Fully reset the state *) + +val merge : t -> Term.t -> Term.t -> unit +val same_class : t -> Term.t -> Term.t -> bool