mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-07 11:45:41 -05:00
feat(cc): make bitfields non-global; remove dead code
This commit is contained in:
parent
ed4ba4057f
commit
2430eb754d
2 changed files with 31 additions and 64 deletions
|
|
@ -29,23 +29,26 @@ module Make(CC_A: ARG) = struct
|
|||
module Bits : sig
|
||||
type t = private int
|
||||
type field
|
||||
type bitfield_gen
|
||||
val empty : t
|
||||
val equal : t -> t -> bool
|
||||
val mk_field : unit -> field
|
||||
val mk_field : bitfield_gen -> field
|
||||
val mk_gen : unit -> bitfield_gen
|
||||
val get : field -> t -> bool
|
||||
val set : field -> bool -> t -> t
|
||||
val merge : t -> t -> t
|
||||
end = struct
|
||||
type bitfield_gen = int ref
|
||||
let max_width = Sys.word_size - 2
|
||||
let width = ref 0
|
||||
let mk_gen() = ref 0
|
||||
type t = int
|
||||
type field = int
|
||||
let empty : t = 0
|
||||
let mk_field () : field =
|
||||
let n = !width in
|
||||
let mk_field (gen:bitfield_gen) : field =
|
||||
let n = !gen in
|
||||
if n > max_width then Error.errorf "maximum number of CC bitfields reached";
|
||||
incr width;
|
||||
n
|
||||
incr gen;
|
||||
1 lsl n
|
||||
let[@inline] get field x = (x land field) <> 0
|
||||
let[@inline] set field b x =
|
||||
if b then x lor field else x land (lnot field)
|
||||
|
|
@ -53,12 +56,6 @@ module Make(CC_A: ARG) = struct
|
|||
let equal : t -> t -> bool = Pervasives.(=)
|
||||
end
|
||||
|
||||
let field_is_pending = Bits.mk_field()
|
||||
(** true iff the node is in the [cc.pending] queue *)
|
||||
|
||||
let field_marked_explain = Bits.mk_field()
|
||||
(** used to mark traversed nodes when looking for a common ancestor *)
|
||||
|
||||
(** A node of the congruence closure.
|
||||
An equivalence class is represented by its "root" element,
|
||||
the representative. *)
|
||||
|
|
@ -137,7 +134,6 @@ module Make(CC_A: ARG) = struct
|
|||
Bag.to_seq n.n_parents
|
||||
|
||||
type bitfield = Bits.field
|
||||
let allocate_bitfield = Bits.mk_field
|
||||
let[@inline] get_field f t = Bits.get f t.n_bits
|
||||
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
|
||||
end
|
||||
|
|
@ -243,7 +239,8 @@ module Make(CC_A: ARG) = struct
|
|||
mutable on_new_term: ev_on_new_term list;
|
||||
mutable on_conflict: ev_on_conflict list;
|
||||
mutable on_propagate: ev_on_propagate list;
|
||||
(* pairs to explain *)
|
||||
bitgen: Bits.bitfield_gen;
|
||||
field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *)
|
||||
true_ : node lazy_t;
|
||||
false_ : node lazy_t;
|
||||
stat: Stat.t;
|
||||
|
|
@ -266,6 +263,7 @@ module Make(CC_A: ARG) = struct
|
|||
let[@inline] true_ cc = Lazy.force cc.true_
|
||||
let[@inline] false_ cc = Lazy.force cc.false_
|
||||
let[@inline] term_state cc = cc.tst
|
||||
let[@inline] allocate_bitfield cc = Bits.mk_field cc.bitgen
|
||||
|
||||
let[@inline] on_backtrack cc f : unit =
|
||||
Backtrack_stack.push_if_nonzero_level cc.undo f
|
||||
|
|
@ -327,11 +325,8 @@ module Make(CC_A: ARG) = struct
|
|||
Sig_tbl.add cc.signatures_tbl s n
|
||||
|
||||
let push_pending cc t : unit =
|
||||
if not @@ N.get_field field_is_pending t then (
|
||||
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
|
||||
N.set_field field_is_pending true t;
|
||||
Vec.push cc.pending t
|
||||
)
|
||||
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
|
||||
Vec.push cc.pending t
|
||||
|
||||
let merge_classes cc t u e : unit =
|
||||
Log.debugf 5
|
||||
|
|
@ -354,7 +349,6 @@ module Make(CC_A: ARG) = struct
|
|||
|
||||
let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ =
|
||||
(* clear tasks queue *)
|
||||
Vec.iter (N.set_field field_is_pending false) cc.pending;
|
||||
Vec.clear cc.pending;
|
||||
Vec.clear cc.combine;
|
||||
List.iter (fun f -> f cc e) cc.on_conflict;
|
||||
|
|
@ -374,10 +368,10 @@ module Make(CC_A: ARG) = struct
|
|||
- if [n] is marked, then all the predecessors of [n]
|
||||
from [a] or [b] are marked too.
|
||||
*)
|
||||
let find_common_ancestor (a:node) (b:node) : node =
|
||||
let find_common_ancestor cc (a:node) (b:node) : node =
|
||||
(* catch up to the other node *)
|
||||
let rec find1 a =
|
||||
if N.get_field field_marked_explain a then a
|
||||
if N.get_field cc.field_marked_explain a then a
|
||||
else (
|
||||
match a.n_expl with
|
||||
| FL_none -> assert false
|
||||
|
|
@ -386,11 +380,11 @@ module Make(CC_A: ARG) = struct
|
|||
in
|
||||
let rec find2 a b =
|
||||
if N.equal a b then a
|
||||
else if N.get_field field_marked_explain a then a
|
||||
else if N.get_field field_marked_explain b then b
|
||||
else if N.get_field cc.field_marked_explain a then a
|
||||
else if N.get_field cc.field_marked_explain b then b
|
||||
else (
|
||||
N.set_field field_marked_explain true a;
|
||||
N.set_field field_marked_explain true b;
|
||||
N.set_field cc.field_marked_explain true a;
|
||||
N.set_field cc.field_marked_explain true b;
|
||||
match a.n_expl, b.n_expl with
|
||||
| FL_some r1, FL_some r2 -> find2 r1.next r2.next
|
||||
| FL_some r, FL_none -> find1 r.next
|
||||
|
|
@ -401,8 +395,8 @@ module Make(CC_A: ARG) = struct
|
|||
in
|
||||
(* cleanup tags on nodes traversed in [find2] *)
|
||||
let rec cleanup_ n =
|
||||
if N.get_field field_marked_explain n then (
|
||||
N.set_field field_marked_explain false n;
|
||||
if N.get_field cc.field_marked_explain n then (
|
||||
N.set_field cc.field_marked_explain false n;
|
||||
match n.n_expl with
|
||||
| FL_none -> ()
|
||||
| FL_some {next;_} -> cleanup_ next;
|
||||
|
|
@ -452,7 +446,7 @@ module Make(CC_A: ARG) = struct
|
|||
Log.debugf 5
|
||||
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
|
||||
assert (N.equal (find_ a) (find_ b));
|
||||
let ancestor = find_common_ancestor a b in
|
||||
let ancestor = find_common_ancestor cc a b in
|
||||
let acc = explain_along_path cc acc a ancestor in
|
||||
explain_along_path cc acc b ancestor
|
||||
|
||||
|
|
@ -560,7 +554,6 @@ module Make(CC_A: ARG) = struct
|
|||
done
|
||||
|
||||
and task_pending_ cc (n:node) : unit =
|
||||
N.set_field field_is_pending false n;
|
||||
(* check if some parent collided *)
|
||||
begin match n.n_sig0 with
|
||||
| None -> () (* no-op *)
|
||||
|
|
@ -735,38 +728,6 @@ module Make(CC_A: ARG) = struct
|
|||
CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default
|
||||
| _ -> ())
|
||||
|
||||
let check_invariants_ (cc:t) =
|
||||
Log.debug 5 "(cc.check-invariants)";
|
||||
Log.debugf 15 (fun k-> k "%a" pp_full cc);
|
||||
assert (T.equal (T.bool cc.tst true) (true_ cc).n_term);
|
||||
assert (T.equal (T.bool cc.tst false) (false_ cc).n_term);
|
||||
assert (not @@ same_class (true_ cc) (false_ cc));
|
||||
assert (Vec.is_empty cc.combine);
|
||||
assert (Vec.is_empty cc.pending);
|
||||
(* check that subterms are internalized *)
|
||||
T_tbl.iter
|
||||
(fun t n ->
|
||||
assert (T.equal t n.n_term);
|
||||
assert (not @@ N.get_field field_is_pending n);
|
||||
assert (N.equal n.n_root n.n_next.n_root);
|
||||
(* check proper signature.
|
||||
note that some signatures in the sig table can be obsolete (they
|
||||
were not removed) but there must be a valid, up-to-date signature for
|
||||
each term *)
|
||||
begin match CCOpt.map update_sig n.n_sig0 with
|
||||
| None -> ()
|
||||
| Some s ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s);
|
||||
(* add, but only if not present already *)
|
||||
begin match Sig_tbl.find cc.signatures_tbl s with
|
||||
| exception Not_found -> assert false
|
||||
| repr_s -> assert (same_class n repr_s)
|
||||
end
|
||||
end;
|
||||
)
|
||||
cc.tbl;
|
||||
()
|
||||
|
||||
module Debug_ = struct
|
||||
let pp out _ = Fmt.string out "cc"
|
||||
end
|
||||
|
|
@ -779,7 +740,6 @@ module Make(CC_A: ARG) = struct
|
|||
Backtrack_stack.push_level self.undo
|
||||
|
||||
let pop_levels (self:t) n : unit =
|
||||
Vec.iter (N.set_field field_is_pending false) self.pending;
|
||||
Vec.clear self.pending;
|
||||
Vec.clear self.combine;
|
||||
Log.debugf 15
|
||||
|
|
@ -845,10 +805,13 @@ module Make(CC_A: ARG) = struct
|
|||
?(size=`Big)
|
||||
(tst:term_state) : t =
|
||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||
let bitgen = Bits.mk_gen () in
|
||||
let field_marked_explain = Bits.mk_field bitgen in
|
||||
let rec cc = {
|
||||
tst;
|
||||
tbl = T_tbl.create size;
|
||||
signatures_tbl = Sig_tbl.create size;
|
||||
bitgen;
|
||||
on_merge;
|
||||
on_new_term;
|
||||
on_conflict;
|
||||
|
|
@ -859,6 +822,7 @@ module Make(CC_A: ARG) = struct
|
|||
true_;
|
||||
false_;
|
||||
stat;
|
||||
field_marked_explain;
|
||||
count_conflict=Stat.mk_int stat "cc.conflicts";
|
||||
count_props=Stat.mk_int stat "cc.propagations";
|
||||
count_merge=Stat.mk_int stat "cc.merges";
|
||||
|
|
|
|||
|
|
@ -178,7 +178,6 @@ module type CC_S = sig
|
|||
All fields are initially 0, are backtracked automatically,
|
||||
and are merged automatically when classes are merged. *)
|
||||
|
||||
val allocate_bitfield : unit -> bitfield
|
||||
val get_field : bitfield -> t -> bool
|
||||
val set_field : bitfield -> bool -> t -> unit
|
||||
end
|
||||
|
|
@ -228,6 +227,10 @@ module type CC_S = sig
|
|||
t
|
||||
(** Create a new congruence closure. *)
|
||||
|
||||
val allocate_bitfield : t -> N.bitfield
|
||||
(** Allocate a new bitfield for the nodes.
|
||||
See {!N.bitfield}. *)
|
||||
|
||||
(* TODO: remove? this is managed by the solver anyway? *)
|
||||
val on_merge : t -> ev_on_merge -> unit
|
||||
(** Add a function to be called when two classes are merged *)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue