diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 89ef0beb..98792614 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -57,49 +57,16 @@ module Make (A: CC_ARG) module Bit_field : Int_id.S = Int_id.Make() type bitfield = Bit_field.t - (* TODO: remove - module Bits : sig - type t = private int - type field - type bitfield_gen - val empty : t - val equal : t -> t -> bool - val mk_field : bitfield_gen -> field - val mk_gen : unit -> bitfield_gen - val get : field -> t -> bool - val set : field -> bool -> t -> t - val merge : t -> t -> t - end = struct - type bitfield_gen = int ref - let max_width = Sys.word_size - 2 - let mk_gen() = ref 0 - type t = int - type field = int - let empty : t = 0 - let mk_field (gen:bitfield_gen) : field = - let n = !gen in - if n > max_width then Error.errorf "maximum number of CC bitfields reached"; - incr gen; - 1 lsl n - let[@inline] get field x = (x land field) <> 0 - let[@inline] set field b x = - if b then x lor field else x land (lnot field) - let merge = (lor) - let equal : t -> t -> bool = CCEqual.poly - end - *) - (* TODO: sparse vec for n_sig0? *) - (* the node store, holds data for all the nodes *) type node_store = { n_term: term Vec.t; (* term for the node *) - n_sig0: signature Vec.t; (* initial signature, if any *) + n_sig0: signature Vec.t; (* initial signature, if any. to be modified. *) n_parents: node Bag.t Vec.t; (* node -> parents(class(node)) *) n_root: NVec.t; (* node -> repr(class(node)) *) n_next: NVec.t; (* node -> next(class(node)) *) n_size: VecI32.t; (* node -> size(class(node)) *) - n_as_lit: lit Int_tbl.t; (* root -> literal, if any *) + n_as_lit: lit Node0.Tbl.t; (* root -> literal, if any *) n_expl: explanation_forest_link Vec.t; (* proof forest *) n_bitfields: Bitvec.t Vec.t; (* bitfield idx -> atom -> bool *) } @@ -130,7 +97,24 @@ module Make (A: CC_ARG) let[@inline] term self n = Vec.get self.n_term (n:t:>int) let[@inline] pp self out n = Term.pp out (term self n) - let[@inline] as_lit self n = n.n_as_lit + + let[@inline] sig0 self n = Vec.get self.n_sig0 (n:t:>int) + let[@inline] set_sig0 self n s = Vec.set self.n_sig0 (n:t:>int) s + let[@inline] size self n = VecI32.get self.n_size (n:t:>int) + let[@inline] set_size self n sz = VecI32.set self.n_size (n:t:>int) sz + let[@inline] next self n = NVec.get self.n_next (n:t:>int) + let[@inline] set_next self n r = NVec.set self.n_next (n:t:>int) r + let[@inline] root self n = NVec.get self.n_root (n:t:>int) + let[@inline] set_root self n r = NVec.set self.n_root (n:t:>int) r + let[@inline] expl self n = Vec.get self.n_expl (n:t:>int) + let[@inline] set_expl self n e = Vec.set self.n_expl (n:t:>int) e + let[@inline] parents self n = Vec.get self.n_parents (n:t:>int) + let[@inline] set_parents self n b = Vec.set self.n_parents (n:t:>int) b + let[@inline] upd_parents ~f self n = set_parents self n (f (parents self n)) + + let[@inline] as_lit self n = Tbl.get self.n_as_lit n + let[@inline] set_as_lit self n lit = Tbl.replace self.n_as_lit n lit + let[@inline] clear_as_lit self n = Tbl.remove self.n_as_lit n let alloc (self:store) (t:term) : t = let { @@ -138,13 +122,27 @@ module Make (A: CC_ARG) } = self in let n = Node0.of_int_unsafe (Vec.size n_term) in Vec.push n_term t; - Vec.push n_sig0 (Opaque n); (* to be changed *) + Vec.push n_sig0 (Opaque n); (* will be updated *) Vec.push n_parents Bag.empty; NVec.push n_root n; NVec.push n_next n; VecI32.push n_size 1; n + (* dealloc node. It must be the last node allocated. *) + let dealloc (self:store) (n:t) : unit = + assert ((n:>int) + 1 = Vec.size self.n_term); + let { + n_term; n_sig0; n_parents; n_root; n_next; n_size + } = self in + ignore (Vec.pop_exn n_term : term); + ignore (Vec.pop_exn n_sig0 : signature); + ignore (Vec.pop_exn n_parents : _ Bag.t); + ignore (NVec.pop n_root : t); + ignore (NVec.pop n_next : t); + ignore (VecI32.pop n_size : int); + () + let[@inline] is_root (self:store) (n:node) : bool = let n2 = NVec.get self.n_root (n:t:>int) in equal n n2 @@ -167,9 +165,26 @@ module Make (A: CC_ARG) assert (is_root self n); Bag.to_iter (Vec.get self.n_parents (n:t:>int)) - (* TODO: use a vec of bitvec *) type nonrec bitfield = bitfield + let alloc_bitfield ~descr (self:store) : bitfield = + Log.debugf 5 (fun k->k "(@[cc.allocate-bit-field@ :descr %s@])" descr); + let field = Bit_field.of_int_unsafe (Vec.size self.n_bitfields) in + Vec.push self.n_bitfields (Bitvec.create()); + field + + let create () : store = { + n_term=Vec.create (); + n_sig0=Vec.create (); + n_root=NVec.create ~cap:1024 (); + n_next=NVec.create ~cap:1024 (); + n_parents=Vec.create (); + n_size=VecI32.create ~cap:1024 (); + n_expl=Vec.create (); + n_as_lit=Tbl.create 256; + n_bitfields=Vec.create(); + } + let[@inline] get_field (self:store) (f:bitfield) (n:t) = let bv = Vec.get self.n_bitfields (f:>int) in Bitvec.get bv (n:t:>int) @@ -193,12 +208,17 @@ module Make (A: CC_ARG) module Expl = struct type t = explanation - let rec pp out (e:explanation) = match e with + let rec pp nstore out (e:explanation) : unit = + let ppn = N.pp nstore in + let pp = pp nstore in + match e with | E_reduction -> Fmt.string out "reduction" | E_lit lit -> Lit.pp out lit - | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 - | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b - | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp a Term.pp b + | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" ppn n1 ppn n2 + | E_merge (a,b) -> + Fmt.fprintf out "(@[merge@ %a@ %a@])" ppn a ppn b + | E_merge_t (a,b) -> + Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp a Term.pp b | E_theory e -> Fmt.fprintf out "(@[th@ %a@])" pp e | E_and (a,b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b @@ -254,15 +274,17 @@ module Make (A: CC_ARG) | If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c) | Not u -> H.combine2 70 (N.hash u) - let pp out = function + let pp nstore out s = + let ppn = N.pp nstore in + match s with | Bool b -> Fmt.bool out b | App_fun (f, []) -> Fun.pp out f - | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l - | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f N.pp a - | Opaque t -> N.pp out t - | Not u -> Fmt.fprintf out "(@[not@ %a@])" N.pp u - | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b - | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c + | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list ppn) l + | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" ppn f ppn a + | Opaque t -> ppn out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" ppn u + | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" ppn a ppn b + | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" ppn a ppn b ppn c end module Sig_tbl = CCHashtbl.Make(Signature) @@ -273,6 +295,7 @@ module Make (A: CC_ARG) type t = { tst: term_store; + nstore: N.store; tbl: node T_tbl.t; (* internalization [term -> node] *) signatures_tbl : node Sig_tbl.t; @@ -293,8 +316,7 @@ module Make (A: CC_ARG) mutable on_propagate: ev_on_propagate list; mutable on_is_subterm : ev_on_is_subterm list; mutable new_merges: bool; - bitgen: Bits.bitfield_gen; - field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) + field_marked_explain: N.bitfield; (* used to mark traversed nodes when looking for a common ancestor *) true_ : node lazy_t; false_ : node lazy_t; stat: Stat.t; @@ -315,24 +337,32 @@ module Make (A: CC_ARG) and ev_on_propagate = t -> lit -> (unit -> lit list * (proof -> unit)) -> unit and ev_on_is_subterm = N.t -> term -> unit - let[@inline] size_ (r:repr) = r.n_size let[@inline] n_true cc = Lazy.force cc.true_ let[@inline] n_false cc = Lazy.force cc.false_ let n_bool cc b = if b then n_true cc else n_false cc - let[@inline] term_store cc = cc.tst - let allocate_bitfield ~descr cc = - Log.debugf 5 (fun k->k "(@[cc.allocate-bit-field@ :descr %s@])" descr); - Bits.mk_field cc.bitgen + let[@inline] term_store (cc:t) = cc.tst + let[@inline] n_store (cc:t) = cc.nstore + + (* new bitfield *) + let allocate_bitfield ~descr self = N.alloc_bitfield ~descr self.nstore + + (* iterate on existing bitfields *) + let[@inline] iter_bitfields (self:t) ~(f:Bit_field.t -> unit) : unit = + for i=0 to Vec.size self.nstore.n_bitfields - 1 do + f (Bit_field.of_int_unsafe i) + done let[@inline] on_backtrack cc f : unit = Backtrack_stack.push_if_nonzero_level cc.undo f - let[@inline] get_bitfield _cc field n = N.get_field field n + let[@inline] get_bitfield cc field n = + N.get_field cc.nstore field n + let set_bitfield cc field b n = - let old = N.get_field field n in + let old = N.get_field cc.nstore field n in if old <> b then ( - on_backtrack cc (fun () -> N.set_field field old n); - N.set_field field b n; + on_backtrack cc (fun () -> N.set_field cc.nstore field n old); + N.set_field cc.nstore field n b; ) (* check if [t] is in the congruence closure. @@ -341,19 +371,25 @@ module Make (A: CC_ARG) (* print full state *) let pp_full out (cc:t) : unit = + let nstore = cc.nstore in + let ppn = N.pp nstore in let pp_next out n = - Fmt.fprintf out "@ :next %a" N.pp n.n_next in + Fmt.fprintf out "@ :next %a" ppn (N.next nstore n) in let pp_root out n = - if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in - let pp_expl out n = match n.n_expl with + if N.is_root nstore n + then Fmt.string out " :is-root" + else Fmt.fprintf out "@ :root %a" ppn (N.root nstore n) in + let pp_expl out n = match N.expl nstore n with | FL_none -> () | FL_some e -> - Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl + Fmt.fprintf out " (@[:forest %a :expl %a@])" + ppn e.next (Expl.pp nstore) e.expl in let pp_n out n = - Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n + Fmt.fprintf out "(@[%a%a%a%a@])" + Term.pp (N.term nstore n) pp_root n pp_next n pp_expl n and pp_sig_e out (s,n) = - Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n + Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" (Signature.pp nstore) s ppn n pp_root n in Fmt.fprintf out "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ %a@])@])" @@ -361,11 +397,12 @@ module Make (A: CC_ARG) (Util.pp_iter ~sep:" " pp_sig_e) (Sig_tbl.to_iter cc.signatures_tbl) (* compute up-to-date signature *) - let update_sig (s:signature) : Signature.t = + let update_sig (self:t) (s:signature) : Signature.t = + let find = N.find self.nstore in View.map_view s ~f_f:(fun x->x) - ~f_t:find_ - ~f_ts:(List.map find_) + ~f_t:find + ~f_ts:(List.map find) (* find whether the given (parent) term corresponds to some signature in [signatures_] *) @@ -375,20 +412,23 @@ module Make (A: CC_ARG) (* add to signature table. Assume it's not present already *) let add_signature cc (s:signature) (n:node) : unit = assert (not @@ Sig_tbl.mem cc.signatures_tbl s); + let nstore = cc.nstore in Log.debugf 50 - (fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n); + (fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" (Signature.pp nstore) s (N.pp nstore) n); on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); Sig_tbl.add cc.signatures_tbl s n let push_pending cc t : unit = - Log.debugf 50 (fun k->k "(@[cc.push-pending@ %a@])" N.pp t); + let nstore = cc.nstore in + Log.debugf 50 (fun k->k "(@[cc.push-pending@ %a@])" (N.pp nstore) t); Vec.push cc.pending t let merge_classes cc t u e : unit = - if t != u && not (same_class t u) then ( + if t != u && not (N.same_class cc.nstore t u) then ( Log.debugf 50 - (fun k->k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" - N.pp t N.pp u Expl.pp e); + (fun k->let nstore=cc.nstore in + k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" + (N.pp nstore) t (N.pp nstore) u (Expl.pp nstore) e); Vec.push cc.combine @@ CT_merge (t,u,e) ) @@ -396,13 +436,13 @@ module Make (A: CC_ARG) so that it points to [n]. postcondition: [n.n_expl = None] *) let[@unroll 2] rec reroot_expl (cc:t) (n:node): unit = - begin match n.n_expl with + begin match N.expl cc.nstore n with | FL_none -> () (* already root *) | FL_some {next=u; expl=e_n_u} -> (* reroot to [u], then invert link between [u] and [n] *) reroot_expl cc u; - u.n_expl <- FL_some {next=n; expl=e_n_u}; - n.n_expl <- FL_none; + N.set_expl cc.nstore u (FL_some {next=n; expl=e_n_u}); + N.set_expl cc.nstore n FL_none; end let raise_conflict_ (cc:t) ~th (acts:actions) (e:lit list) p : _ = @@ -416,7 +456,7 @@ module Make (A: CC_ARG) let[@inline] all_classes cc : repr Iter.t = T_tbl.values cc.tbl - |> Iter.filter N.is_root + |> Iter.filter (N.is_root cc.nstore) (* find the closest common ancestor of [a] and [b] in the proof forest. @@ -430,21 +470,21 @@ module Make (A: CC_ARG) let find_common_ancestor cc (a:node) (b:node) : node = (* catch up to the other node *) let rec find1 a = - if N.get_field cc.field_marked_explain a then a + if N.get_field cc.nstore cc.field_marked_explain a then a else ( - match a.n_expl with + match N.expl cc.nstore a with | FL_none -> assert false | FL_some r -> find1 r.next ) in let rec find2 a b = if N.equal a b then a - else if N.get_field cc.field_marked_explain a then a - else if N.get_field cc.field_marked_explain b then b + else if N.get_field cc.nstore cc.field_marked_explain a then a + else if N.get_field cc.nstore cc.field_marked_explain b then b else ( - N.set_field cc.field_marked_explain true a; - N.set_field cc.field_marked_explain true b; - match a.n_expl, b.n_expl with + N.set_field cc.nstore cc.field_marked_explain a true; + N.set_field cc.nstore cc.field_marked_explain b true; + match N.expl cc.nstore a, N.expl cc.nstore b with | FL_some r1, FL_some r2 -> find2 r1.next r2.next | FL_some r, FL_none -> find1 r.next | FL_none, FL_some r -> find1 r.next @@ -454,9 +494,9 @@ module Make (A: CC_ARG) in (* cleanup tags on nodes traversed in [find2] *) let rec cleanup_ n = - if N.get_field cc.field_marked_explain n then ( - N.set_field cc.field_marked_explain false n; - match n.n_expl with + if N.get_field cc.nstore cc.field_marked_explain n then ( + N.set_field cc.nstore cc.field_marked_explain n false; + match N.expl cc.nstore n with | FL_none -> () | FL_some {next;_} -> cleanup_ next; ) @@ -468,19 +508,19 @@ module Make (A: CC_ARG) (* decompose explanation [e] into a list of literals added to [acc] *) let rec explain_decompose_expl cc ~th (acc:lit list) (e:explanation) : _ list = - Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); + Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" (Expl.pp cc.nstore) e); match e with | E_reduction -> acc | E_congruence (n1, n2) -> - begin match n1.n_sig0, n2.n_sig0 with - | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + begin match N.sig0 cc.nstore n1, N.sig0 cc.nstore n2 with + | App_fun (f1, a1), App_fun (f2, a2) -> assert (Fun.equal f1 f2); assert (List.length a1 = List.length a2); List.fold_left2 (explain_equal cc ~th) acc a1 a2 - | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + | App_ho (f1, a1), App_ho (f2, a2) -> let acc = explain_equal cc ~th acc f1 f2 in explain_equal cc ~th acc a1 a2 - | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> + | If (a1,b1,c1), If (a2,b2,c2) -> let acc = explain_equal cc ~th acc a1 a2 in let acc = explain_equal cc ~th acc b1 b2 in explain_equal cc ~th acc c1 c2 @@ -503,9 +543,11 @@ module Make (A: CC_ARG) explain_decompose_expl cc ~th acc b and explain_equal (cc:t) ~th (acc:lit list) (a:node) (b:node) : _ list = + let nstore = cc.nstore in Log.debugf 5 - (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); - assert (N.equal (find_ a) (find_ b)); + (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" + (N.pp nstore) a (N.pp nstore) b); + assert (N.same_class nstore a b); let ancestor = find_common_ancestor cc a b in let acc = explain_along_path cc ~th acc a ancestor in explain_along_path cc ~th acc b ancestor @@ -516,7 +558,7 @@ module Make (A: CC_ARG) let rec aux acc n = if n == target then acc else ( - match n.n_expl with + match N.expl cc.nstore n with | FL_none -> assert false | FL_some {next=next_n; expl=expl} -> let acc = explain_decompose_expl cc ~th acc expl in @@ -534,27 +576,35 @@ module Make (A: CC_ARG) and add_new_term_ cc (t:term) : node = assert (not @@ mem cc t); Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t); - let n = N.make t in + + let n = N.alloc cc.nstore t in + (* register sub-terms, add [t] to their parent list, and return the corresponding initial signature *) let sig0 = compute_sig0 cc n in - n.n_sig0 <- sig0; + (* remove term when we backtrack *) on_backtrack cc (fun () -> Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t); + N.dealloc cc.nstore n; T_tbl.remove cc.tbl t); + (* add term to the table *) T_tbl.add cc.tbl t n; - if CCOpt.is_some sig0 then ( - (* [n] might be merged with other equiv classes *) - push_pending cc n; - ); + + begin match sig0 with + | Opaque _ | Bool _ -> () + | App_ho _ | App_fun _ | If _ | Eq _ | Not _ -> + (* [n] might be merged with other equiv classes *) + push_pending cc n; + end; + List.iter (fun f -> f cc n t) cc.on_new_term; n (* compute the initial signature of the given node *) - and compute_sig0 (self:t) (n:node) : Signature.t option = + and compute_sig0 (self:t) (n:node) : Signature.t = (* add sub-term to [cc], and register [n] to its parents. Note that we return the exact sub-term, to get proper explanations, but we add to the sub-term's root's parent list. *) @@ -562,48 +612,49 @@ module Make (A: CC_ARG) let sub = add_term_rec_ self u in (* add [n] to [sub.root]'s parent list *) begin - let sub_r = find_ sub in - let old_parents = sub_r.n_parents in + let sub_r = N.find self.nstore sub in + let old_parents = N.parents self.nstore sub_r in if Bag.is_empty old_parents then ( (* first time it has parents: tell watchers that this is a subterm *) List.iter (fun f -> f sub u) self.on_is_subterm; ); - on_backtrack self (fun () -> sub_r.n_parents <- old_parents); - sub_r.n_parents <- Bag.cons n sub_r.n_parents; + on_backtrack self (fun () -> N.set_parents self.nstore sub_r old_parents); + N.upd_parents self.nstore sub_r ~f:(fun p -> Bag.cons n p); end; sub in - let[@inline] return x = Some x in - match A.cc_view n.n_term with - | Bool _ | Opaque _ -> None - | Eq (a,b) -> - let a = deref_sub a in - let b = deref_sub b in - return @@ Eq (a,b) - | Not u -> return @@ Not (deref_sub u) - | App_fun (f, args) -> - let args = args |> Iter.map deref_sub |> Iter.to_list in - if args<>[] then ( - return @@ App_fun (f, args) - ) else None - | App_ho (f, a ) -> - let f = deref_sub f in - let a = deref_sub a in - return @@ App_ho (f, a) - | If (a,b,c) -> - return @@ If (deref_sub a, deref_sub b, deref_sub c) + begin match A.cc_view (N.term self.nstore n) with + | Bool _ | Opaque _ -> Opaque n + | Eq (a,b) -> + let a = deref_sub a in + let b = deref_sub b in + Eq (a,b) + | Not u -> Not (deref_sub u) + | App_fun (f, args) -> + let args = args |> Iter.map deref_sub |> Iter.to_list in + if args<>[] then ( + App_fun (f, args) + ) else Opaque n + | App_ho (f, a ) -> + let f = deref_sub f in + let a = deref_sub a in + App_ho (f, a) + | If (a,b,c) -> + If (deref_sub a, deref_sub b, deref_sub c) + end let[@inline] add_term cc t : node = add_term_rec_ cc t let mem_term = mem - let set_as_lit cc (n:node) (lit:lit) : unit = - match n.n_as_lit with + let set_as_lit self (n:node) (lit:lit) : unit = + match N.as_lit self.nstore n with | Some _ -> () | None -> - Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n Lit.pp lit); - on_backtrack cc (fun () -> n.n_as_lit <- None); - n.n_as_lit <- Some lit + Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" + (N.pp self.nstore) n Lit.pp lit); + on_backtrack self (fun () -> N.clear_as_lit self.nstore n); + N.set_as_lit self.nstore n lit (* is [n] true or false? *) let n_is_bool_value (self:t) n : bool = @@ -611,52 +662,55 @@ module Make (A: CC_ARG) (* main CC algo: add terms from [pending] to the signature table, check for collisions *) - let rec update_tasks (cc:t) (acts:actions) : unit = - while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do - while not @@ Vec.is_empty cc.pending do - task_pending_ cc (Vec.pop_exn cc.pending); + let rec update_tasks (self:t) (acts:actions) : unit = + while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do + while not @@ Vec.is_empty self.pending do + task_pending_ self (Vec.pop_exn self.pending); done; - while not @@ Vec.is_empty cc.combine do - task_combine_ cc acts (Vec.pop_exn cc.combine); + while not @@ Vec.is_empty self.combine do + task_combine_ self acts (Vec.pop_exn self.combine); done; done - and task_pending_ cc (n:node) : unit = + and task_pending_ (self:t) (n:node) : unit = (* check if some parent collided *) - begin match n.n_sig0 with - | None -> () (* no-op *) - | Some (Eq (a,b)) -> + begin match N.sig0 self.nstore n with + | Opaque _ -> () (* no-op *) + | Eq (a,b) -> (* if [a=b] is now true, merge [(a=b)] and [true] *) - if same_class a b then ( + if N.same_class self.nstore a b then ( let expl = Expl.mk_merge a b in Log.debugf 5 - (fun k->k "(@[pending.eq@ %a@ :r1 %a@ :r2 %a@])" N.pp n N.pp a N.pp b); - merge_classes cc n (n_true cc) expl + (fun k->k "(@[pending.eq@ %a@ :r1 %a@ :r2 %a@])" + (N.pp self.nstore) n (N.pp self.nstore) a + (N.pp self.nstore) b); + merge_classes self n (n_true self) expl ) - | Some (Not u) -> + | Not u -> (* [u = bool ==> not u = not bool] *) - let r_u = find_ u in - if N.equal r_u (n_true cc) then ( - let expl = Expl.mk_merge u (n_true cc) in - merge_classes cc n (n_false cc) expl - ) else if N.equal r_u (n_false cc) then ( - let expl = Expl.mk_merge u (n_false cc) in - merge_classes cc n (n_true cc) expl + let r_u = N.find self.nstore u in + if N.equal r_u (n_true self) then ( + let expl = Expl.mk_merge u (n_true self) in + merge_classes self n (n_false self) expl + ) else if N.equal r_u (n_false self) then ( + let expl = Expl.mk_merge u (n_false self) in + merge_classes self n (n_true self) expl ) - | Some s0 -> + | s0 -> (* update the signature by using [find] on each sub-node *) - let s = update_sig s0 in - match find_signature cc s with - | None -> - (* add to the signature table [sig(n) --> n] *) - add_signature cc s n - | Some u when N.equal n u -> () - | Some u -> - (* [t1] and [t2] must be applications of the same symbol to - arguments that are pairwise equal *) - assert (n != u); - let expl = Expl.mk_congruence n u in - merge_classes cc n u expl + let s = update_sig self s0 in + begin match find_signature self s with + | None -> + (* add to the signature table [sig(n) --> n] *) + add_signature self s n + | Some u when N.equal n u -> () + | Some u -> + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) + assert (n != u); + let expl = Expl.mk_congruence n u in + merge_classes self n u expl + end end and[@inline] task_combine_ cc acts = function @@ -666,115 +720,130 @@ module Make (A: CC_ARG) (* main CC algo: merge equivalence classes in [st.combine]. @raise Exn_unsat if merge fails *) - and task_merge_ cc acts a b e_ab : unit = - let ra = find_ a in - let rb = find_ b in + and task_merge_ self acts a b e_ab : unit = + let nstore = self.nstore in + let ra = N.find nstore a in + let rb = N.find nstore b in if not @@ N.equal ra rb then ( - assert (N.is_root ra); - assert (N.is_root rb); - Stat.incr cc.count_merge; + assert (N.is_root nstore ra); + assert (N.is_root nstore rb); + Stat.incr self.count_merge; (* check we're not merging [true] and [false] *) - if (N.equal ra (n_true cc) && N.equal rb (n_false cc)) || - (N.equal rb (n_true cc) && N.equal ra (n_false cc)) then ( + if (N.equal ra (n_true self) && N.equal rb (n_false self)) || + (N.equal rb (n_true self) && N.equal ra (n_false self)) then ( Log.debugf 5 - (fun k->k "(@[cc.merge.true_false_conflict@ \ + (fun k-> + let ppn = N.pp nstore in + k "(@[cc.merge.true_false_conflict@ \ @[:r1 %a@ :t1 %a@]@ @[:r2 %a@ :t2 %a@]@ :e_ab %a@])" - N.pp ra N.pp a N.pp rb N.pp b Expl.pp e_ab); + ppn ra ppn a ppn rb ppn b (Expl.pp nstore) e_ab); let th = ref false in (* TODO: C1: P.true_neq_false C2: lemma [lits |- true=false] (and resolve on theory proofs) C3: r1 C1 C2 *) - let lits = explain_decompose_expl cc ~th [] e_ab in - let lits = explain_equal cc ~th lits a ra in - let lits = explain_equal cc ~th lits b rb in + let lits = explain_decompose_expl self ~th [] e_ab in + let lits = explain_equal self ~th lits a ra in + let lits = explain_equal self ~th lits b rb in let emit_proof p = let p_lits = Iter.of_list lits |> Iter.map Lit.neg in P.lemma_cc p_lits p in - raise_conflict_ cc ~th:!th acts (List.rev_map Lit.neg lits) emit_proof + raise_conflict_ self ~th:!th acts (List.rev_map Lit.neg lits) emit_proof ); (* We will merge [r_from] into [r_into]. we try to ensure that [size ra <= size rb] in general, but always keep values as representative *) let r_from, r_into = - if n_is_bool_value cc ra then rb, ra - else if n_is_bool_value cc rb then ra, rb - else if size_ ra > size_ rb then rb, ra + if n_is_bool_value self ra then rb, ra + else if n_is_bool_value self rb then ra, rb + else if N.size nstore ra > N.size nstore rb then rb, ra else ra, rb in (* when merging terms with [true] or [false], possibly propagate them to SAT *) let merge_bool r1 t1 r2 t2 = - if N.equal r1 (n_true cc) then ( - propagate_bools cc acts r2 t2 r1 t1 e_ab true - ) else if N.equal r1 (n_false cc) then ( - propagate_bools cc acts r2 t2 r1 t1 e_ab false + if N.equal r1 (n_true self) then ( + propagate_bools self acts r2 t2 r1 t1 e_ab true + ) else if N.equal r1 (n_false self) then ( + propagate_bools self acts r2 t2 r1 t1 e_ab false ) in merge_bool ra a rb b; merge_bool rb b ra a; (* perform [union r_from r_into] *) - Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); + Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" + (N.pp nstore) r_from (N.pp nstore) r_into); (* call [on_pre_merge] functions, and merge theory data items *) begin (* explanation is [a=ra & e_ab & b=rb] *) let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in - List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge; + List.iter (fun f -> f self acts r_into r_from expl) self.on_pre_merge; end; begin (* parents might have a different signature, check for collisions *) - N.iter_parents r_from - (fun parent -> push_pending cc parent); + N.iter_parents nstore r_from + (fun parent -> push_pending self parent); (* for each node in [r_from]'s class, make it point to [r_into] *) - N.iter_class r_from + N.iter_class nstore r_from (fun u -> - assert (u.n_root == r_from); - u.n_root <- r_into); + assert (N.root nstore u == r_from); + N.set_root nstore u r_into); (* capture current state *) - let r_into_old_next = r_into.n_next in - let r_from_old_next = r_from.n_next in - let r_into_old_parents = r_into.n_parents in - let r_into_old_bits = r_into.n_bits in + let r_into_old_next = N.next nstore r_into in + let r_from_old_next = N.next nstore r_from in + let r_into_old_parents = N.parents nstore r_into in (* swap [into.next] and [from.next], merging the classes *) - r_into.n_next <- r_from_old_next; - r_from.n_next <- r_into_old_next; - r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; - r_into.n_size <- r_into.n_size + r_from.n_size; - r_into.n_bits <- Bits.merge r_into.n_bits r_from.n_bits; + N.set_next nstore r_into r_from_old_next; + N.set_next nstore r_from r_into_old_next; + N.upd_parents nstore r_into ~f:(fun p -> Bag.append p (N.parents nstore r_from)); + N.set_size nstore r_into (N.size nstore r_into + N.size nstore r_from); + + (* merge bitfields, and backtrack changes *) + + iter_bitfields self ~f:(fun field -> + let b_into = N.get_field nstore field r_into in + if not b_into then ( + let b_from = N.get_field nstore field r_from in + if b_from then ( + (* we modify the field of [r_into], remember to undo that *) + on_backtrack self (fun () -> N.set_field nstore field r_into false); + N.set_field nstore field r_into true; + ); + )); + (* on backtrack, unmerge classes and restore the pointers to [r_from] *) - on_backtrack cc + on_backtrack self (fun () -> Log.debugf 15 (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" - N.pp r_from N.pp r_into); - r_into.n_bits <- r_into_old_bits; - r_into.n_next <- r_into_old_next; - r_from.n_next <- r_from_old_next; - r_into.n_parents <- r_into_old_parents; + (N.pp nstore) r_from (N.pp nstore) r_into); + N.set_next nstore r_into r_into_old_next; + N.set_next nstore r_from r_from_old_next; + N.set_parents nstore r_into r_into_old_parents; (* NOTE: this must come after the restoration of [next] pointers, otherwise we'd iterate on too big a class *) - N.iter_class_ r_from (fun u -> u.n_root <- r_from); - r_into.n_size <- r_into.n_size - r_from.n_size; + N.iter_class nstore r_from (fun u -> N.set_root nstore u r_from); + N.set_size nstore r_into (N.size nstore r_into - N.size nstore r_from); ); end; (* update explanations (a -> b), arbitrarily. Note that here we merge the classes by adding a bridge between [a] and [b], not their roots. *) begin - reroot_expl cc a; - assert (a.n_expl = FL_none); + reroot_expl self a; + assert (N.expl nstore a == FL_none); (* on backtracking, link may be inverted, but we delete the one that bridges between [a] and [b] *) - on_backtrack cc - (fun () -> match a.n_expl, b.n_expl with - | FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none - | _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none + on_backtrack self + (fun () -> match N.expl nstore a, N.expl nstore b with + | FL_some e, _ when N.equal e.next b -> N.set_expl nstore a FL_none + | _, FL_some e when N.equal e.next a -> N.set_expl nstore b FL_none | _ -> assert false); - a.n_expl <- FL_some {next=b; expl=e_ab}; + N.set_expl nstore a (FL_some {next=b; expl=e_ab}); end; (* call [on_post_merge] *) begin - List.iter (fun f -> f cc acts r_into r_from) cc.on_post_merge; + List.iter (fun f -> f self acts r_into r_from) self.on_post_merge; end; ) @@ -782,23 +851,23 @@ module Make (A: CC_ARG) in the equiv class of [r1] that is a known literal back to the SAT solver and which is not the one initially merged. We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) - and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = + and propagate_bools (self:t) acts r1 t1 r2 t2 (e_12:explanation) sign : unit = (* explanation for [t1 =e= t2 = r2] *) let half_expl = lazy ( let th = ref false in - let lits = explain_decompose_expl cc ~th [] e_12 in - th, explain_equal cc ~th lits r2 t2 + let lits = explain_decompose_expl self ~th [] e_12 in + th, explain_equal self ~th lits r2 t2 ) in (* TODO: flag per class, `or`-ed on merge, to indicate if the class contains at least one lit *) - N.iter_class r1 + N.iter_class self.nstore r1 (fun u1 -> (* propagate if: - [u1] is a proper literal - [t2 != r2], because that can only happen after an explicit merge (no way to obtain that by propagation) *) - match N.as_lit u1 with + match N.as_lit self.nstore u1 with | Some lit when not (N.equal r2 t2) -> let lit = if sign then lit else Lit.neg lit in (* apply sign *) Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); @@ -806,7 +875,7 @@ module Make (A: CC_ARG) let reason = let e = lazy ( let lazy (th, acc) = half_expl in - let lits = explain_equal cc ~th acc u1 t1 in + let lits = explain_equal self ~th acc u1 t1 in let emit_proof p = (* make a tautology, not a true guard *) let p_lits = Iter.cons lit (Iter.of_list lits |> Iter.map Lit.neg) in @@ -816,8 +885,8 @@ module Make (A: CC_ARG) ) in fun () -> Lazy.force e in - List.iter (fun f -> f cc lit reason) cc.on_propagate; - Stat.incr cc.count_props; + List.iter (fun f -> f self lit reason) self.on_propagate; + Stat.incr self.count_props; Actions.propagate acts lit ~reason | _ -> ()) @@ -873,23 +942,24 @@ module Make (A: CC_ARG) Iter.iter (assert_lit cc) lits (* raise a conflict *) - let raise_conflict_from_expl cc (acts:actions) expl = + let raise_conflict_from_expl self (acts:actions) expl = Log.debugf 5 - (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); + (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" (Expl.pp self.nstore) expl); let th = ref true in - let lits = explain_decompose_expl cc ~th [] expl in + let lits = explain_decompose_expl self ~th [] expl in let lits = List.rev_map Lit.neg lits in let emit_proof p = let p_lits = Iter.of_list lits in P.lemma_cc p_lits p in - raise_conflict_ cc ~th:!th acts lits emit_proof + raise_conflict_ self ~th:!th acts lits emit_proof - let merge cc n1 n2 expl = + let merge (self:t) n1 n2 expl = Log.debugf 5 - (fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" N.pp n1 N.pp n2 Expl.pp expl); - assert (T.Ty.equal (T.Term.ty n1.n_term) (T.Term.ty n2.n_term)); - merge_classes cc n1 n2 expl + (fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" + (N.pp self.nstore) n1 (N.pp self.nstore) n2 (Expl.pp self.nstore) expl); + assert (T.Ty.equal (T.Term.ty (N.term self.nstore n1)) (T.Term.ty (N.term self.nstore n2))); + merge_classes self n1 n2 expl let[@inline] merge_t cc t1 t2 expl = merge cc (add_term cc t1) (add_term cc t2) expl @@ -911,13 +981,14 @@ module Make (A: CC_ARG) ?(size=`Big) (tst:term_store) : t = let size = match size with `Small -> 128 | `Big -> 2048 in - let bitgen = Bits.mk_gen () in - let field_marked_explain = Bits.mk_field bitgen in + let nstore = N.create() in + let field_marked_explain = N.alloc_bitfield ~descr:"mark-explain" nstore in + assert ((field_marked_explain :> int) = 0); let rec cc = { tst; + nstore; tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; - bitgen; on_pre_merge; on_post_merge; on_new_term; @@ -945,18 +1016,19 @@ module Make (A: CC_ARG) ignore (Lazy.force false_ : node); cc - let[@inline] find_t cc t : repr = - let n = T_tbl.find cc.tbl t in - find_ n + let[@inline] find self n = N.find self.nstore n + let[@inline] find_t self t : repr = + let n = T_tbl.find self.tbl t in + N.find self.nstore n - let[@inline] check cc acts : unit = + let[@inline] check self acts : unit = Log.debug 5 "(cc.check)"; - cc.new_merges <- false; - update_tasks cc acts + self.new_merges <- false; + update_tasks self acts let new_merges cc = cc.new_merges (* model: return all the classes *) - let get_model (cc:t) : repr Iter.t Iter.t = - all_classes cc |> Iter.map N.iter_class + let get_model (self:t) : repr Iter.t Iter.t = + all_classes self |> Iter.map (N.iter_class self.nstore) end diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 4442b0c5..eb46e2b7 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -437,7 +437,7 @@ module type CC_S = sig when asked to justify why 2 terms are equal. *) module Expl : sig type t - val pp : t Fmt.printer + val pp : N.store -> t Fmt.printer val mk_merge : N.t -> N.t -> t val mk_merge_t : term -> term -> t @@ -536,6 +536,8 @@ module type CC_S = sig There may be restrictions on how many distinct fields are allocated for a given congruence closure (e.g. at most {!Sys.int_size} fields). + + @param descr description for the field. *) val get_bitfield : t -> N.bitfield -> N.t -> bool @@ -1247,7 +1249,7 @@ end = struct Error.errorf "when merging@ @[for node %a@],@ \ values %a and %a:@ conflict %a" - (N.pp nstore) n_u M.pp m_u M.pp m_u' CC.Expl.pp expl + (N.pp nstore) n_u M.pp m_u M.pp m_u' (CC.Expl.pp nstore) expl | Ok m_u_merged -> Log.debugf 20 (fun k->k "(@[monoid[%s].on-new-term.sub.merged@ \