diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 3d1c5443..bd48ca08 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -29,23 +29,26 @@ module Make(CC_A: ARG) = struct module Bits : sig type t = private int type field + type bitfield_gen val empty : t val equal : t -> t -> bool - val mk_field : unit -> field + val mk_field : bitfield_gen -> field + val mk_gen : unit -> bitfield_gen val get : field -> t -> bool val set : field -> bool -> t -> t val merge : t -> t -> t end = struct + type bitfield_gen = int ref let max_width = Sys.word_size - 2 - let width = ref 0 + let mk_gen() = ref 0 type t = int type field = int let empty : t = 0 - let mk_field () : field = - let n = !width in + let mk_field (gen:bitfield_gen) : field = + let n = !gen in if n > max_width then Error.errorf "maximum number of CC bitfields reached"; - incr width; - n + incr gen; + 1 lsl n let[@inline] get field x = (x land field) <> 0 let[@inline] set field b x = if b then x lor field else x land (lnot field) @@ -53,12 +56,6 @@ module Make(CC_A: ARG) = struct let equal : t -> t -> bool = Pervasives.(=) end - let field_is_pending = Bits.mk_field() - (** true iff the node is in the [cc.pending] queue *) - - let field_marked_explain = Bits.mk_field() - (** used to mark traversed nodes when looking for a common ancestor *) - (** A node of the congruence closure. An equivalence class is represented by its "root" element, the representative. *) @@ -137,7 +134,6 @@ module Make(CC_A: ARG) = struct Bag.to_seq n.n_parents type bitfield = Bits.field - let allocate_bitfield = Bits.mk_field let[@inline] get_field f t = Bits.get f t.n_bits let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits end @@ -243,7 +239,8 @@ module Make(CC_A: ARG) = struct mutable on_new_term: ev_on_new_term list; mutable on_conflict: ev_on_conflict list; mutable on_propagate: ev_on_propagate list; - (* pairs to explain *) + bitgen: Bits.bitfield_gen; + field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) true_ : node lazy_t; false_ : node lazy_t; stat: Stat.t; @@ -266,6 +263,7 @@ module Make(CC_A: ARG) = struct let[@inline] true_ cc = Lazy.force cc.true_ let[@inline] false_ cc = Lazy.force cc.false_ let[@inline] term_state cc = cc.tst + let[@inline] allocate_bitfield cc = Bits.mk_field cc.bitgen let[@inline] on_backtrack cc f : unit = Backtrack_stack.push_if_nonzero_level cc.undo f @@ -327,11 +325,8 @@ module Make(CC_A: ARG) = struct Sig_tbl.add cc.signatures_tbl s n let push_pending cc t : unit = - if not @@ N.get_field field_is_pending t then ( - Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); - N.set_field field_is_pending true t; - Vec.push cc.pending t - ) + Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); + Vec.push cc.pending t let merge_classes cc t u e : unit = Log.debugf 5 @@ -354,7 +349,6 @@ module Make(CC_A: ARG) = struct let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ = (* clear tasks queue *) - Vec.iter (N.set_field field_is_pending false) cc.pending; Vec.clear cc.pending; Vec.clear cc.combine; List.iter (fun f -> f cc e) cc.on_conflict; @@ -374,10 +368,10 @@ module Make(CC_A: ARG) = struct - if [n] is marked, then all the predecessors of [n] from [a] or [b] are marked too. *) - let find_common_ancestor (a:node) (b:node) : node = + let find_common_ancestor cc (a:node) (b:node) : node = (* catch up to the other node *) let rec find1 a = - if N.get_field field_marked_explain a then a + if N.get_field cc.field_marked_explain a then a else ( match a.n_expl with | FL_none -> assert false @@ -386,11 +380,11 @@ module Make(CC_A: ARG) = struct in let rec find2 a b = if N.equal a b then a - else if N.get_field field_marked_explain a then a - else if N.get_field field_marked_explain b then b + else if N.get_field cc.field_marked_explain a then a + else if N.get_field cc.field_marked_explain b then b else ( - N.set_field field_marked_explain true a; - N.set_field field_marked_explain true b; + N.set_field cc.field_marked_explain true a; + N.set_field cc.field_marked_explain true b; match a.n_expl, b.n_expl with | FL_some r1, FL_some r2 -> find2 r1.next r2.next | FL_some r, FL_none -> find1 r.next @@ -401,8 +395,8 @@ module Make(CC_A: ARG) = struct in (* cleanup tags on nodes traversed in [find2] *) let rec cleanup_ n = - if N.get_field field_marked_explain n then ( - N.set_field field_marked_explain false n; + if N.get_field cc.field_marked_explain n then ( + N.set_field cc.field_marked_explain false n; match n.n_expl with | FL_none -> () | FL_some {next;_} -> cleanup_ next; @@ -452,7 +446,7 @@ module Make(CC_A: ARG) = struct Log.debugf 5 (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); assert (N.equal (find_ a) (find_ b)); - let ancestor = find_common_ancestor a b in + let ancestor = find_common_ancestor cc a b in let acc = explain_along_path cc acc a ancestor in explain_along_path cc acc b ancestor @@ -560,7 +554,6 @@ module Make(CC_A: ARG) = struct done and task_pending_ cc (n:node) : unit = - N.set_field field_is_pending false n; (* check if some parent collided *) begin match n.n_sig0 with | None -> () (* no-op *) @@ -735,38 +728,6 @@ module Make(CC_A: ARG) = struct CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default | _ -> ()) - let check_invariants_ (cc:t) = - Log.debug 5 "(cc.check-invariants)"; - Log.debugf 15 (fun k-> k "%a" pp_full cc); - assert (T.equal (T.bool cc.tst true) (true_ cc).n_term); - assert (T.equal (T.bool cc.tst false) (false_ cc).n_term); - assert (not @@ same_class (true_ cc) (false_ cc)); - assert (Vec.is_empty cc.combine); - assert (Vec.is_empty cc.pending); - (* check that subterms are internalized *) - T_tbl.iter - (fun t n -> - assert (T.equal t n.n_term); - assert (not @@ N.get_field field_is_pending n); - assert (N.equal n.n_root n.n_next.n_root); - (* check proper signature. - note that some signatures in the sig table can be obsolete (they - were not removed) but there must be a valid, up-to-date signature for - each term *) - begin match CCOpt.map update_sig n.n_sig0 with - | None -> () - | Some s -> - Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s); - (* add, but only if not present already *) - begin match Sig_tbl.find cc.signatures_tbl s with - | exception Not_found -> assert false - | repr_s -> assert (same_class n repr_s) - end - end; - ) - cc.tbl; - () - module Debug_ = struct let pp out _ = Fmt.string out "cc" end @@ -779,7 +740,6 @@ module Make(CC_A: ARG) = struct Backtrack_stack.push_level self.undo let pop_levels (self:t) n : unit = - Vec.iter (N.set_field field_is_pending false) self.pending; Vec.clear self.pending; Vec.clear self.combine; Log.debugf 15 @@ -845,10 +805,13 @@ module Make(CC_A: ARG) = struct ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in + let bitgen = Bits.mk_gen () in + let field_marked_explain = Bits.mk_field bitgen in let rec cc = { tst; tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; + bitgen; on_merge; on_new_term; on_conflict; @@ -859,6 +822,7 @@ module Make(CC_A: ARG) = struct true_; false_; stat; + field_marked_explain; count_conflict=Stat.mk_int stat "cc.conflicts"; count_props=Stat.mk_int stat "cc.propagations"; count_merge=Stat.mk_int stat "cc.merges"; diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 36ef8cea..788d2b15 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -178,7 +178,6 @@ module type CC_S = sig All fields are initially 0, are backtracked automatically, and are merged automatically when classes are merged. *) - val allocate_bitfield : unit -> bitfield val get_field : bitfield -> t -> bool val set_field : bitfield -> bool -> t -> unit end @@ -228,6 +227,10 @@ module type CC_S = sig t (** Create a new congruence closure. *) + val allocate_bitfield : t -> N.bitfield + (** Allocate a new bitfield for the nodes. + See {!N.bitfield}. *) + (* TODO: remove? this is managed by the solver anyway? *) val on_merge : t -> ev_on_merge -> unit (** Add a function to be called when two classes are merged *)