mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
feat(cc): add small union-find on the side to make expls smaller
This commit is contained in:
parent
9c57dad3f1
commit
cca2c48f07
3 changed files with 93 additions and 10 deletions
26
src/cc/CC.ml
26
src/cc/CC.ml
|
|
@ -286,11 +286,14 @@ let find_common_ancestor self (a : e_node) (b : e_node) : e_node =
|
|||
module Expl_state = struct
|
||||
type t = {
|
||||
mutable lits: Lit.t list;
|
||||
proven_eq: Small_uf.t;
|
||||
mutable th_lemmas:
|
||||
(Lit.t * (Lit.t * Lit.t list) list * Proof_term.step_id) list;
|
||||
}
|
||||
|
||||
let create () : t = { lits = []; th_lemmas = [] }
|
||||
let create () : t =
|
||||
{ lits = []; th_lemmas = []; proven_eq = Small_uf.create () }
|
||||
|
||||
let[@inline] copy self : t = { self with lits = self.lits }
|
||||
let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits
|
||||
|
||||
|
|
@ -298,7 +301,7 @@ module Expl_state = struct
|
|||
self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas
|
||||
|
||||
let merge self other =
|
||||
let { lits = o_lits; th_lemmas = o_lemmas } = other in
|
||||
let { lits = o_lits; th_lemmas = o_lemmas; proven_eq = _ } = other in
|
||||
self.lits <- List.rev_append o_lits self.lits;
|
||||
self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas;
|
||||
()
|
||||
|
|
@ -339,7 +342,7 @@ module Expl_state = struct
|
|||
|
||||
let to_resolved_expl (self : t) : Resolved_expl.t =
|
||||
(* FIXME: package the th lemmas too *)
|
||||
let { lits; th_lemmas = _ } = self in
|
||||
let { lits; th_lemmas = _l; proven_eq = _ } = self in
|
||||
let s2 = copy self in
|
||||
let pr proof = proof_of_th_lemmas s2 proof in
|
||||
{ Resolved_expl.lits; pr }
|
||||
|
|
@ -396,15 +399,18 @@ and explain_expls self (es : explanation list) : Expl_state.t =
|
|||
List.iter (explain_decompose_expl self st) es;
|
||||
st
|
||||
|
||||
(* explain why [a =_E b] *)
|
||||
and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) :
|
||||
unit =
|
||||
if a != b then (
|
||||
if a != b && not (Small_uf.same_class st.proven_eq a.n_term b.n_term) then (
|
||||
Log.debugf 5 (fun k ->
|
||||
k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b);
|
||||
assert (E_node.equal (find_ a) (find_ b));
|
||||
let ancestor = find_common_ancestor cc a b in
|
||||
explain_along_path cc st a ancestor;
|
||||
explain_along_path cc st b ancestor
|
||||
explain_along_path cc st b ancestor;
|
||||
(* we now know that [a=b]. *)
|
||||
Small_uf.merge st.proven_eq a.n_term b.n_term
|
||||
)
|
||||
|
||||
(* explain why [a = target], where [a -> ... -> target] in the
|
||||
|
|
@ -415,11 +421,11 @@ and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) :
|
|||
if n != target then (
|
||||
match n.n_expl with
|
||||
| FL_none -> assert false
|
||||
| FL_some { next = next_a; expl } ->
|
||||
(* prove [a = next_n] *)
|
||||
| FL_some { next = next_n; expl } ->
|
||||
(* prove [n = next_n] *)
|
||||
explain_decompose_expl self st expl;
|
||||
(* now prove [next_a = target] *)
|
||||
aux next_a
|
||||
(* now prove [next_n = target] *)
|
||||
aux next_n
|
||||
)
|
||||
in
|
||||
aux a
|
||||
|
|
@ -510,7 +516,7 @@ let n_is_bool_value (self : t) n : bool =
|
|||
merges. *)
|
||||
let lits_and_proof_of_expl (self : t) (st : Expl_state.t) :
|
||||
Lit.t list * Proof_term.step_id =
|
||||
let { Expl_state.lits; th_lemmas = _ } = st in
|
||||
let { Expl_state.lits; th_lemmas = _; proven_eq = _ } = st in
|
||||
let pr = Expl_state.proof_of_th_lemmas st self.proof in
|
||||
lits, pr
|
||||
|
||||
|
|
|
|||
60
src/cc/small_uf.ml
Normal file
60
src/cc/small_uf.ml
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
open Sidekick_core
|
||||
module T = Term
|
||||
module T_tbl = Term.Tbl
|
||||
|
||||
type node = {
|
||||
mutable n_next: node; (* next in class *)
|
||||
mutable n_size: int; (* size of class *)
|
||||
}
|
||||
|
||||
module Node = struct
|
||||
type t = node
|
||||
|
||||
let[@inline] equal (n1 : t) n2 = n1 == n2
|
||||
let[@inline] size (n : t) = n.n_size
|
||||
let[@inline] is_root n = n == n.n_next
|
||||
|
||||
let[@unroll 2] rec root n =
|
||||
if n.n_next == n then
|
||||
n
|
||||
else (
|
||||
let r = root n.n_next in
|
||||
n.n_next <- r;
|
||||
r
|
||||
)
|
||||
|
||||
let make () : t =
|
||||
let rec n = { n_size = 1; n_next = n } in
|
||||
n
|
||||
end
|
||||
|
||||
type t = { tbl: node T_tbl.t }
|
||||
|
||||
let create () : t = { tbl = T_tbl.create 8 }
|
||||
let clear (self : t) : unit = T_tbl.clear self.tbl
|
||||
|
||||
let add_ (self : t) (t : T.t) : node =
|
||||
try T_tbl.find self.tbl t
|
||||
with Not_found ->
|
||||
let n = Node.make () in
|
||||
T_tbl.add self.tbl t n;
|
||||
n
|
||||
|
||||
let merge (self : t) (t1 : T.t) (t2 : T.t) =
|
||||
let n1 = add_ self t1 |> Node.root in
|
||||
let n2 = add_ self t2 |> Node.root in
|
||||
if n1 != n2 then (
|
||||
let n1, n2 =
|
||||
if Node.size n1 > Node.size n2 then
|
||||
n1, n2
|
||||
else
|
||||
n2, n1
|
||||
in
|
||||
n2.n_next <- n1;
|
||||
n1.n_size <- n1.n_size + n2.n_size
|
||||
)
|
||||
|
||||
let same_class (self : t) (t1 : T.t) (t2 : T.t) : bool =
|
||||
let n1 = add_ self t1 |> Node.root in
|
||||
let n2 = add_ self t2 |> Node.root in
|
||||
n1 == n2
|
||||
17
src/cc/small_uf.mli
Normal file
17
src/cc/small_uf.mli
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
(** Small union find.
|
||||
|
||||
No backtracking nor explanations.
|
||||
*)
|
||||
|
||||
open Sidekick_core
|
||||
|
||||
type t
|
||||
(** An instance of the congruence closure. Mutable *)
|
||||
|
||||
val create : unit -> t
|
||||
|
||||
val clear : t -> unit
|
||||
(** Fully reset the state *)
|
||||
|
||||
val merge : t -> Term.t -> Term.t -> unit
|
||||
val same_class : t -> Term.t -> Term.t -> bool
|
||||
Loading…
Add table
Reference in a new issue