This commit is contained in:
Simon Cruanes 2019-06-05 20:47:58 -05:00
parent fc80d8c6e9
commit b3f328027c
5 changed files with 113 additions and 118 deletions

View file

@ -13,19 +13,6 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) View.t =
module type ARG = Sidekick_core.CC_ARG
module type S = Sidekick_core.CC_S
module Bits = CCBitField.Make()
let field_is_pending = Bits.mk_field()
(** true iff the node is in the [cc.pending] queue *)
let field_usr1 = Bits.mk_field()
(** General purpose *)
let field_usr2 = Bits.mk_field()
(** General purpose *)
let () = Bits.freeze()
module Make(CC_A: ARG) = struct
module A = CC_A.A
module CC_A = CC_A
@ -33,13 +20,42 @@ module Make(CC_A: ARG) = struct
type term_state = A.Term.state
type lit = A.Lit.t
type fun_ = A.Fun.t
type proof = A.Proof.t
type th_data = CC_A.Data.t
type lemma = A.Lemma.t
type actions = CC_A.Actions.t
module T = A.Term
module Fun = A.Fun
module Bits : sig
type t = private int
type field = private int
val empty : t
val mk_field : unit -> field
val inter : t -> t -> t
val union : t -> t -> t
val get : field -> t -> bool
val set : field -> t -> t
val unset : field -> t -> t
end = struct
type t = int
type field = int
let max_width = Sys.word_size - 2
let mk_field =
let n_ = ref 0 in
fun () -> let x = 1 lsl !n_ in incr n_; x
let empty = 0
let[@inline] get f x = x land f <> 0
let[@inline] set f x = x lor f
let[@inline] unset f x = x land (lnot f)
let[@inline] union x y = x lor y
let[@inline] inter x y = x land y
end
let field_is_pending = Bits.mk_field()
(** true iff the node is in the [cc.pending] queue *)
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
the representative. *)
@ -51,9 +67,8 @@ module Make(CC_A: ARG) = struct
mutable n_root: node; (* representative of congruence class (itself if a representative) *)
mutable n_next: node; (* pointer to next element of congruence class *)
mutable n_size: int; (* size of the class *)
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
mutable n_as_lit: lit option;
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_th_data: th_data; (* theory data *)
}
and signature = (fun_, node, node list) view
@ -85,7 +100,6 @@ module Make(CC_A: ARG) = struct
let[@inline] term n = n.n_term
let[@inline] pp out n = T.pp out n.n_term
let[@inline] as_lit n = n.n_as_lit
let[@inline] th_data n = n.n_th_data
let make (t:term) : t =
let rec n = {
@ -98,7 +112,6 @@ module Make(CC_A: ARG) = struct
n_expl=FL_none;
n_next=n;
n_size=1;
n_th_data=CC_A.Data.empty;
} in
n
@ -122,12 +135,8 @@ module Make(CC_A: ARG) = struct
Bag.to_seq n.n_parents
let[@inline] get_field f t = Bits.get f t.n_bits
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
let[@inline] get_field_usr1 t = get_field field_usr1 t
let[@inline] set_field_usr1 t b = set_field field_usr1 b t
let[@inline] get_field_usr2 t = get_field field_usr2 t
let[@inline] set_field_usr2 t b = set_field field_usr2 b t
let[@inline] set_field f t = t.n_bits <- Bits.set f t.n_bits
let[@inline] unset_field f t = t.n_bits <- Bits.unset f t.n_bits
end
module N_tbl = CCHashtbl.Make(N)
@ -227,7 +236,7 @@ module Make(CC_A: ARG) = struct
pending: node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
mutable on_merge: ev_on_merge list;
mutable on_merge: (Bits.field * ev_on_merge) list;
mutable on_new_term: ev_on_new_term list;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
@ -245,8 +254,8 @@ module Make(CC_A: ARG) = struct
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
and ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit
and ev_on_new_term = t -> N.t -> term -> th_data option
and ev_on_merge = t -> N.t -> N.t -> Expl.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
@ -329,7 +338,7 @@ module Make(CC_A: ARG) = struct
let push_pending cc t : unit =
if not @@ N.get_field field_is_pending t then (
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
N.set_field field_is_pending true t;
N.set_field field_is_pending t;
Vec.push cc.pending t
)
@ -354,11 +363,11 @@ module Make(CC_A: ARG) = struct
let raise_conflict (cc:t) (acts:actions) (e:conflict) : _ =
(* clear tasks queue *)
Vec.iter (N.set_field field_is_pending false) cc.pending;
Vec.iter (N.unset_field field_is_pending) cc.pending;
Vec.clear cc.pending;
Vec.clear cc.combine;
Stat.incr cc.count_conflict;
CC_A.Actions.raise_conflict acts e A.Proof.default
CC_A.Actions.raise_conflict acts e A.Lemma.default
let[@inline] all_classes cc : repr Iter.t =
T_tbl.values cc.tbl
@ -502,16 +511,7 @@ module Make(CC_A: ARG) = struct
(* [n] might be merged with other equiv classes *)
push_pending cc n;
);
(* initial theory data *)
let th_data =
List.fold_left
(fun data f ->
match f cc n t with
| None -> data
| Some d -> CC_A.Data.merge data d)
CC_A.Data.empty cc.on_new_term
in
n.n_th_data <- th_data;
List.iter (fun f -> f cc n t) cc.on_new_term; (* notify *)
n
(* compute the initial signature of the given node *)
@ -572,7 +572,7 @@ module Make(CC_A: ARG) = struct
done
and task_pending_ cc (n:node) : unit =
N.set_field field_is_pending false n;
N.unset_field field_is_pending n;
(* check if some parent collided *)
begin match n.n_sig0 with
| None -> () (* no-op *)
@ -654,17 +654,21 @@ module Make(CC_A: ARG) = struct
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call [on_merge] functions, and merge theory data items *)
begin
let th_into = r_into.n_th_data in
let th_from = r_from.n_th_data in
let new_data = CC_A.Data.merge th_into th_from in
let bits_into = r_into.n_bits in
let bits_from = r_from.n_bits in
let inter = Bits.inter bits_into bits_from in
let union = Bits.union bits_into bits_from in
(* restore old data, if it changed *)
if new_data != th_into then (
on_backtrack cc (fun () -> r_into.n_th_data <- th_into);
if union != bits_into then (
on_backtrack cc (fun () -> r_into.n_bits <- bits_into);
);
r_into.n_th_data <- new_data;
(* explanation is [a=ra & e_ab & b=rb] *)
r_into.n_bits <- union;
(* call merge handlers with explanation [a=ra & e_ab & b=rb].
Only do so for handlers whose bit is on for both classes. *)
let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in
List.iter (fun f -> f cc r_into th_into r_from th_from expl) cc.on_merge;
List.iter (fun (f_field,f) ->
if Bits.get f_field inter then f cc r_into r_from expl)
cc.on_merge;
end;
begin
(* parents might have a different signature, check for collisions *)
@ -740,7 +744,7 @@ module Make(CC_A: ARG) = struct
let e = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
List.iter yield e
in
CC_A.Actions.propagate acts lit ~reason A.Proof.default
CC_A.Actions.propagate acts lit ~reason A.Lemma.default
| _ -> ())
module Theory = struct
@ -760,39 +764,39 @@ module Make(CC_A: ARG) = struct
let add_term = add_term
end
let check_invariants_ (cc:t) =
Log.debug 5 "(cc.check-invariants)";
Log.debugf 15 (fun k-> k "%a" pp_full cc);
assert (T.equal (T.bool cc.tst true) (true_ cc).n_term);
assert (T.equal (T.bool cc.tst false) (false_ cc).n_term);
assert (not @@ same_class (true_ cc) (false_ cc));
assert (Vec.is_empty cc.combine);
assert (Vec.is_empty cc.pending);
(* check that subterms are internalized *)
T_tbl.iter
(fun t n ->
assert (T.equal t n.n_term);
assert (not @@ N.get_field field_is_pending n);
assert (N.equal n.n_root n.n_next.n_root);
(* check proper signature.
note that some signatures in the sig table can be obsolete (they
were not removed) but there must be a valid, up-to-date signature for
each term *)
begin match CCOpt.map update_sig n.n_sig0 with
| None -> ()
| Some s ->
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s);
(* add, but only if not present already *)
begin match Sig_tbl.find cc.signatures_tbl s with
| exception Not_found -> assert false
| repr_s -> assert (same_class n repr_s)
end
end;
)
cc.tbl;
()
module Debug_ = struct
let check_invariants_ (cc:t) =
Log.debug 5 "(cc.check-invariants)";
Log.debugf 15 (fun k-> k "%a" pp_full cc);
assert (T.equal (T.bool cc.tst true) (true_ cc).n_term);
assert (T.equal (T.bool cc.tst false) (false_ cc).n_term);
assert (not @@ same_class (true_ cc) (false_ cc));
assert (Vec.is_empty cc.combine);
assert (Vec.is_empty cc.pending);
(* check that subterms are internalized *)
T_tbl.iter
(fun t n ->
assert (T.equal t n.n_term);
assert (not @@ N.get_field field_is_pending n);
assert (N.equal n.n_root n.n_next.n_root);
(* check proper signature.
note that some signatures in the sig table can be obsolete (they
were not removed) but there must be a valid, up-to-date signature for
each term *)
begin match CCOpt.map update_sig n.n_sig0 with
| None -> ()
| Some s ->
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s);
(* add, but only if not present already *)
begin match Sig_tbl.find cc.signatures_tbl s with
| exception Not_found -> assert false
| repr_s -> assert (same_class n repr_s)
end
end;
)
cc.tbl;
()
let[@inline] check_invariants (cc:t) : unit =
if Util._CHECK_INVARIANTS then check_invariants_ cc
let pp out _ = Fmt.string out "cc"
@ -806,7 +810,7 @@ module Make(CC_A: ARG) = struct
Backtrack_stack.push_level self.undo
let pop_levels (self:t) n : unit =
Vec.iter (N.set_field field_is_pending false) self.pending;
Vec.iter (N.unset_field field_is_pending) self.pending;
Vec.clear self.pending;
Vec.clear self.combine;
Log.debugf 15

View file

@ -117,7 +117,7 @@ module type CORE_TYPES = sig
val pp : t Fmt.printer
end
module Proof : sig
module Lemma : sig
type t
val pp : t Fmt.printer
@ -136,20 +136,12 @@ module type CC_ARG = sig
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t
(** View the term through the lens of the congruence closure *)
(** Monoid embedded in every node *)
module Data : sig
type t
val merge : t -> t -> t
val pp : t Fmt.printer
val empty : t
end
module Actions : sig
type t
val raise_conflict : t -> Lit.t list -> Proof.t -> 'a
val raise_conflict : t -> Lit.t list -> Lemma.t -> 'a
val propagate : t -> Lit.t -> reason:Lit.t Iter.t -> Proof.t -> unit
val propagate : t -> Lit.t -> reason:Lit.t Iter.t -> Lemma.t -> unit
end
end
@ -160,8 +152,7 @@ module type CC_S = sig
type term = A.Term.t
type fun_ = A.Fun.t
type lit = A.Lit.t
type proof = A.Proof.t
type th_data = CC_A.Data.t
type lemma = A.Lemma.t
type actions = CC_A.Actions.t
type t
@ -197,19 +188,6 @@ module type CC_S = sig
val iter_class : t -> t Iter.t
(** Traverse the congruence class.
Precondition: [is_root n] (see {!find} below) *)
val iter_parents : t -> t Iter.t
(** Traverse the parents of the class.
Precondition: [is_root n] (see {!find} below) *)
val th_data : t -> th_data
(** Access theory data for this node *)
val get_field_usr1 : t -> bool
val set_field_usr1 : t -> bool -> unit
val get_field_usr2 : t -> bool
val set_field_usr2 : t -> bool -> unit
end
module Expl : sig
@ -261,8 +239,8 @@ module type CC_S = sig
To be used in theories *)
end
type ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit
type ev_on_new_term = t -> N.t -> term -> th_data option
type ev_on_merge = t -> N.t -> N.t -> Expl.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit
val create :
?stat:Stat.t ->
@ -351,7 +329,7 @@ module type SOLVER_INTERNAL = sig
type lit = A.Lit.t
type term = A.Term.t
type term_state = A.Term.state
type proof = A.Proof.t
type lemma = A.Lemma.t
(** {3 Main type for a solver} *)
type t
@ -500,7 +478,7 @@ module type SOLVER = sig
type term = A.Term.t
type ty = A.Ty.t
type lit = A.Lit.t
type proof = A.Proof.t
type lemma = A.Lemma.t
type value = A.Value.t
(** {3 A theory}
@ -582,6 +560,14 @@ module type SOLVER = sig
*)
end
module Proof : sig
type t
(* TODO: expose more? *)
end
type proof = Proof.t
(** {3 Main API} *)
val stats : t -> Stat.t

View file

@ -1,7 +1,12 @@
(** {1 Process Statements} *)
open Sidekick_smt
open Sidekick_base_term
module Solver : Sidekick_msat_solver.S
with type A.Term.t = Term.t
and type A.Ty.t = Ty.t
and type A.Fun.t = Cst.t
type 'a or_error = ('a, string) CCResult.t

View file

@ -3,12 +3,12 @@
(** {1 Preprocessing AST} *)
module ID = Sidekick_smt.ID
module ID = Sidekick_base_term.ID
module Loc = Locations
module Fmt = CCFormat
module Log = Msat.Log
module A = Sidekick_smt.Ast
module A = Sidekick_base_term.Ast
module PA = Parse_ast
type 'a or_error = ('a, string) CCResult.t

View file

@ -16,7 +16,7 @@ module Ctx : sig
end
module PA = Parse_ast
module A = Sidekick_smt.Ast
module A = Sidekick_base_term.Ast
val conv_term : Ctx.t -> PA.term -> A.term