feat(cc.plugin): plugins have state, passed at init

This commit is contained in:
Simon Cruanes 2022-08-14 23:21:49 -04:00
parent e9dae47d0b
commit 94ba945bf3
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
8 changed files with 106 additions and 26 deletions

View file

@ -49,6 +49,7 @@ type combine_task =
type t = {
view_as_cc: view_as_cc;
tst: Term.store;
stat: Stat.t;
proof: Proof_trace.t;
tbl: e_node T_tbl.t; (* internalization [term -> e_node] *)
signatures_tbl: e_node Sig_tbl.t;
@ -108,6 +109,7 @@ let n_bool self b =
let[@inline] term_store self = self.tst
let[@inline] proof self = self.proof
let[@inline] stat self = self.stat
let allocate_bitfield self ~descr : bitfield =
Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr);
@ -851,6 +853,7 @@ let create_ ?(stat = Stat.global) ?(size = `Big) (tst : Term.store)
view_as_cc;
tst;
proof;
stat;
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
bitgen;

View file

@ -45,6 +45,7 @@ type t
val term_store : t -> Term.store
val proof : t -> Proof_trace.t
val stat : t -> Stat.t
val find : t -> e_node -> repr
(** Current representative *)

View file

@ -33,6 +33,9 @@ module Make (M : MONOID_PLUGIN_ARG) :
module CC = CC
open A
(* plugin's state *)
let plugin_st = M.create cc
(* repr -> value for the class *)
let values : M.t Cls_tbl.t = Cls_tbl.create ?size ()
@ -62,7 +65,7 @@ module Make (M : MONOID_PLUGIN_ARG) :
let on_new_term cc n (t : Term.t) : CC.Handler_action.t list =
(*Log.debugf 50 (fun k->k "(@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);*)
let acts = ref [] in
let maybe_m, l = M.of_term cc n t in
let maybe_m, l = M.of_term cc plugin_st n t in
(match maybe_m with
| Some v ->
Log.debugf 20 (fun k ->
@ -84,7 +87,7 @@ module Make (M : MONOID_PLUGIN_ARG) :
Error.errorf "node %a has bitfield but no value" E_node.pp n_u
in
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
match M.merge cc plugin_st n_u m_u n_u m_u' (Expl.mk_list []) with
| Error (CC.Handler_action.Conflict expl) ->
Error.errorf
"when merging@ @[for node %a@],@ values %a and %a:@ conflict %a"
@ -118,7 +121,7 @@ module Make (M : MONOID_PLUGIN_ARG) :
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 \
%a@ :val2 %a@])@])"
M.name E_node.pp n1 M.pp v1 E_node.pp n2 M.pp v2);
(match M.merge cc n1 v1 n2 v2 e_n1_n2 with
(match M.merge cc plugin_st n1 v1 n2 v2 e_n1_n2 with
| Ok (v', merge_acts) ->
acts := merge_acts;
Cls_tbl.remove values n2;
@ -140,8 +143,8 @@ module Make (M : MONOID_PLUGIN_ARG) :
in
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) iter_all
(* setup *)
let () =
(* hook into the CC's events *)
Event.on (CC.on_new_term cc) ~f:(fun (_, r, t) -> on_new_term cc r t);
Event.on (CC.on_pre_merge2 cc) ~f:(fun (_, ra, rb, expl) ->
on_pre_merge cc ra rb expl);

View file

@ -15,12 +15,18 @@ module type MONOID_PLUGIN_ARG = sig
include Sidekick_sigs.PRINT with type t := t
type state
val create : CC.t -> state
(** Initialize state from the congruence closure *)
val name : string
(** name of the monoid structure (short) *)
(* FIXME: for subs, return list of e_nodes, and assume of_term already
returned data for them. *)
val of_term : CC.t -> E_node.t -> Term.t -> t option * (E_node.t * t) list
val of_term :
CC.t -> state -> E_node.t -> Term.t -> t option * (E_node.t * t) list
(** [of_term n t], where [t] is the Term.t annotating node [n],
must return [maybe_m, l], where:
@ -34,6 +40,7 @@ module type MONOID_PLUGIN_ARG = sig
val merge :
CC.t ->
state ->
E_node.t ->
t ->
E_node.t ->

View file

@ -10,9 +10,21 @@ module type ARG = Intf.ARG
module Make (A : ARG) : sig
val theory : SMT.theory
end = struct
type state = { tst: T.store; gensym: Gensym.t }
type state = {
tst: T.store;
gensym: Gensym.t;
n_simplify: int Stat.counter;
n_clauses: int Stat.counter;
}
let create ~stat tst : state =
{
tst;
gensym = Gensym.create tst;
n_simplify = Stat.mk_int stat "th.bool.simplified";
n_clauses = Stat.mk_int stat "th.bool.cnf-clauses";
}
let create tst : state = { tst; gensym = Gensym.create tst }
let[@inline] not_ tst t = A.mk_bool tst (B_not t)
let[@inline] eq tst a b = A.mk_bool tst (B_eq (a, b))
@ -42,7 +54,11 @@ end = struct
~res:[ Lit.atom (A.mk_bool tst (B_eq (a, b))) ]
in
let[@inline] ret u = Some (u, Iter.of_list !steps) in
let[@inline] ret u =
Stat.incr self.n_simplify;
Some (u, Iter.of_list !steps)
in
(* proof is [t <=> u] *)
let ret_bequiv t1 u =
(add_step_ @@ mk_step_ @@ fun () -> Proof_rules.lemma_bool_equiv t1 u);
@ -123,7 +139,7 @@ end = struct
let[@inline] mk_step_ r = Proof_trace.add_step PA.proof r in
(* handle boolean equality *)
let equiv_ _si ~is_xor ~t t_a t_b : unit =
let equiv_ (self : state) _si ~is_xor ~t t_a t_b : unit =
let a = PA.mk_lit t_a in
let b = PA.mk_lit t_b in
let a =
@ -137,23 +153,30 @@ end = struct
(* proxy => a<=> b,
¬proxy => a xor b *)
Stat.incr self.n_clauses;
PA.add_clause
[ Lit.neg lit; Lit.neg a; b ]
(if is_xor then
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "xor-e+" [ t ]
else
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "eq-e" [ t; t_a ]);
Stat.incr self.n_clauses;
PA.add_clause
[ Lit.neg lit; Lit.neg b; a ]
(if is_xor then
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "xor-e-" [ t ]
else
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "eq-e" [ t; t_b ]);
Stat.incr self.n_clauses;
PA.add_clause [ lit; a; b ]
(if is_xor then
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "xor-i" [ t; t_a ]
else
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "eq-i+" [ t ]);
Stat.incr self.n_clauses;
PA.add_clause
[ lit; Lit.neg a; Lit.neg b ]
(if is_xor then
@ -174,10 +197,13 @@ end = struct
List.iter
(fun u ->
let t_u = Lit.term u in
Stat.incr self.n_clauses;
PA.add_clause
[ Lit.neg lit; u ]
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "and-e" [ t; t_u ]))
subs;
Stat.incr self.n_clauses;
PA.add_clause
(lit :: List.map Lit.neg subs)
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "and-i" [ t ])
@ -189,10 +215,13 @@ end = struct
List.iter
(fun u ->
let t_u = Lit.term u in
Stat.incr self.n_clauses;
PA.add_clause
[ Lit.neg u; lit ]
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "or-i" [ t; t_u ]))
subs;
Stat.incr self.n_clauses;
PA.add_clause (Lit.neg lit :: subs)
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "or-e" [ t ])
| B_imply (a, b) ->
@ -208,29 +237,35 @@ end = struct
List.iter
(fun u ->
let t_u = Lit.term u in
Stat.incr self.n_clauses;
PA.add_clause
[ Lit.neg u; lit ]
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "imp-i" [ t; t_u ]))
subs;
Stat.incr self.n_clauses;
PA.add_clause (Lit.neg lit :: subs)
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "imp-e" [ t ])
| B_ite (a, b, c) ->
let lit_a = PA.mk_lit a in
Stat.incr self.n_clauses;
PA.add_clause
[ Lit.neg lit_a; PA.mk_lit (eq self.tst t b) ]
(mk_step_ @@ fun () -> Proof_rules.lemma_ite_true ~ite:t);
Stat.incr self.n_clauses;
PA.add_clause
[ lit_a; PA.mk_lit (eq self.tst t c) ]
(mk_step_ @@ fun () -> Proof_rules.lemma_ite_false ~ite:t)
| B_eq _ | B_neq _ -> ()
| B_equiv (a, b) -> equiv_ si ~t ~is_xor:false a b
| B_xor (a, b) -> equiv_ si ~t ~is_xor:true a b
| B_equiv (a, b) -> equiv_ self si ~t ~is_xor:false a b
| B_xor (a, b) -> equiv_ self si ~t ~is_xor:true a b
| B_atom _ -> ());
()
let create_and_setup si =
Log.debug 2 "(th-bool.setup)";
let st = create (SI.tst si) in
let st = create ~stat:(SI.stats si) (SI.tst si) in
SI.add_simplifier si (simplify st);
SI.on_preprocess si (cnf st);
st

View file

@ -23,18 +23,26 @@ end = struct
let name = name
type state = { n_merges: int Stat.counter; n_conflict: int Stat.counter }
let create cc : state =
{
n_merges = Stat.mk_int (CC.stat cc) "th.cstor.merges";
n_conflict = Stat.mk_int (CC.stat cc) "th.cstor.conflicts";
}
let pp out (v : t) =
Fmt.fprintf out "(@[cstor %a@ :term %a@])" Const.pp v.cstor T.pp_debug v.t
(* attach data to constructor terms *)
let of_term cc n (t : T.t) : _ option * _ =
let of_term cc _ n (t : T.t) : _ option * _ =
match A.view_as_cstor t with
| T_cstor (cstor, args) ->
let args = CCArray.map (CC.add_term cc) args in
Some { n; t; cstor; args }, []
| _ -> None, []
let merge _cc n1 v1 n2 v2 e_n1_n2 : _ result =
let merge _cc state n1 v1 n2 v2 e_n1_n2 : _ result =
Log.debugf 5 (fun k ->
k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])" name
E_node.pp n1 T.pp_debug v1.t E_node.pp n2 T.pp_debug v2.t);
@ -50,14 +58,18 @@ end = struct
assert (CCArray.length v1.args = CCArray.length v2.args);
let acts =
CCArray.map2
(fun u1 u2 -> CC.Handler_action.Act_merge (u1, u2, expl))
(fun u1 u2 ->
Stat.incr state.n_merges;
CC.Handler_action.Act_merge (u1, u2, expl))
v1.args v2.args
|> Array.to_list
in
Ok (v1, acts)
) else
) else (
(* different function: disjointness *)
Stat.incr state.n_conflict;
Error (CC.Handler_action.Conflict expl)
)
end
module ST = Sidekick_cc.Plugin.Make (Monoid)

View file

@ -150,6 +150,14 @@ end = struct
module Monoid_cstor = struct
let name = "th-data.cstor"
type state = { n_merges: int Stat.counter; n_conflict: int Stat.counter }
let create cc : state =
{
n_merges = Stat.mk_int (CC.stat cc) "th.data.cstor-merges";
n_conflict = Stat.mk_int (CC.stat cc) "th.data.cstor-conflicts";
}
(* associate to each class a unique constructor term in the class (if any) *)
type t = { c_n: E_node.t; c_cstor: A.Cstor.t; c_args: E_node.t list }
@ -158,14 +166,14 @@ end = struct
A.Cstor.pp v.c_cstor E_node.pp v.c_n (Util.pp_list E_node.pp) v.c_args
(* attach data to constructor terms *)
let of_term cc n (t : Term.t) : _ option * _ list =
let of_term cc _ n (t : Term.t) : _ option * _ list =
match A.view_as_data t with
| T_cstor (cstor, args) ->
let args = List.map (CC.add_term cc) args in
Some { c_n = n; c_cstor = cstor; c_args = args }, []
| _ -> None, []
let merge cc n1 c1 n2 c2 e_n1_n2 : _ result =
let merge cc state n1 c1 n2 c2 e_n1_n2 : _ result =
Log.debugf 5 (fun k ->
k "(@[%s.merge@ (@[:c1 %a@ %a@])@ (@[:c2 %a@ %a@])@])" name E_node.pp
n1 pp c1 E_node.pp n2 pp c2);
@ -194,8 +202,10 @@ end = struct
let acts = ref [] in
CCList.iteri2
(fun i u1 u2 ->
Stat.incr state.n_merges;
acts := CC.Handler_action.Act_merge (u1, u2, expl_merge i) :: !acts)
c1.c_args c2.c_args;
Ok (c1, !acts)
) else (
(* different function: disjointness *)
@ -205,6 +215,7 @@ end = struct
@@ fun () -> Proof_rules.lemma_cstor_distinct t1 t2
in
Stat.incr state.n_conflict;
Error (CC.Handler_action.Conflict expl)
)
end
@ -214,6 +225,10 @@ end = struct
module Monoid_parents = struct
let name = "th-data.parents"
type state = unit
let create _ = ()
type select = {
sel_n: E_node.t;
sel_cstor: A.Cstor.t;
@ -243,7 +258,7 @@ end = struct
v.parent_is_a
(* attach data to constructor terms *)
let of_term cc n (t : Term.t) : _ option * _ list =
let of_term cc () n (t : Term.t) : _ option * _ list =
match A.view_as_data t with
| T_select (c, i, u) ->
let u = CC.add_term cc u in
@ -266,7 +281,7 @@ end = struct
None, [ u, m_sel ]
| T_cstor _ | T_other _ -> None, []
let merge _cc n1 v1 n2 v2 _e : _ result =
let merge _cc () n1 v1 n2 v2 _e : _ result =
Log.debugf 5 (fun k ->
k "(@[%s.merge@ @[:c1 %a@ :v %a@]@ @[:c2 %a@ :v %a@]@])" name
E_node.pp n1 pp v1 E_node.pp n2 pp v2);
@ -795,7 +810,7 @@ end = struct
case_split_done = Term.Tbl.create 16;
cards = Card.create ();
stat_acycl_conflict =
Stat.mk_int (SI.stats solver) "data.acycl.conflict";
Stat.mk_int (SI.stats solver) "th.data.acycl.conflict";
}
in
Log.debugf 1 (fun k -> k "(setup :%s)" name);

View file

@ -78,6 +78,10 @@ module Make (A : ARG) = (* : S with module A = A *) struct
module Monoid_exprs = struct
let name = "lra.const"
type state = unit
let create _ = ()
type single = { le: LE.t; n: E_node.t }
type t = single list
@ -89,7 +93,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
| [ x ] -> pp_single out x
| _ -> Fmt.fprintf out "(@[exprs@ %a@])" (Util.pp_list pp_single) self
let of_term _cc n t =
let of_term _cc () n t =
match A.view_as_lra t with
| LRA_const _ | LRA_op _ | LRA_mult _ ->
let le = as_linexp t in
@ -100,7 +104,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* merge lists. If two linear expressions equal up to a constant are
merged, conflict. *)
let merge _cc n1 l1 n2 l2 expl_12 : _ result =
let merge _cc () n1 l1 n2 l2 expl_12 : _ result =
try
let i = Iter.(product (of_list l1) (of_list l2)) in
i (fun (s1, s2) ->
@ -138,7 +142,8 @@ module Make (A : ARG) = (* : S with module A = A *) struct
mutable last_res: SimpSolver.result option;
}
let create ?(stat = Stat.create ()) (si : SI.t) : state =
let create (si : SI.t) : state =
let stat = SI.stats si in
let proof = SI.proof si in
let tst = SI.tst si in
{
@ -692,8 +697,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
let create_and_setup si =
Log.debug 2 "(th-lra.setup)";
let stat = SI.stats si in
let st = create ~stat si in
let st = create si in
SMT.Registry.set (SI.registry si) k_state st;
SI.add_simplifier si (simplify st);
SI.on_preprocess si (preproc_lra st);