diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 9f48086b..dec46ea5 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -139,6 +139,17 @@ module Make(CC_A: ARG) = struct module N_tbl = CCHashtbl.Make(N) + (* non-recursive, inlinable function for [find] *) + let[@inline] find_ (n:node) : repr = + let n2 = n.n_root in + assert (N.is_root n2); + n2 + + let[@inline] same_class (n1:node)(n2:node): bool = + N.equal (find_ n1) (find_ n2) + + let[@inline] find _ n = find_ n + module Expl = struct type t = explanation @@ -153,7 +164,9 @@ module Make(CC_A: ARG) = struct let mk_reduction : t = E_reduction let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) - let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) + let[@inline] mk_merge a b : t = + assert (same_class a b); + if N.equal a b then mk_reduction else E_merge (a,b) let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b) let[@inline] mk_lit l : t = E_lit l @@ -271,16 +284,6 @@ module Make(CC_A: ARG) = struct Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t - (* non-recursive, inlinable function for [find] *) - let[@inline] find_ (n:node) : repr = - let n2 = n.n_root in - assert (N.is_root n2); - n2 - - let[@inline] same_class (n1:node)(n2:node): bool = - N.equal (find_ n1) (find_ n2) - - let[@inline] find _ n = find_ n (* print full state *) let pp_full out (cc:t) : unit =