mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
feat(cc.plugin): plugins have state, passed at init
This commit is contained in:
parent
e9dae47d0b
commit
94ba945bf3
8 changed files with 106 additions and 26 deletions
|
|
@ -49,6 +49,7 @@ type combine_task =
|
|||
type t = {
|
||||
view_as_cc: view_as_cc;
|
||||
tst: Term.store;
|
||||
stat: Stat.t;
|
||||
proof: Proof_trace.t;
|
||||
tbl: e_node T_tbl.t; (* internalization [term -> e_node] *)
|
||||
signatures_tbl: e_node Sig_tbl.t;
|
||||
|
|
@ -108,6 +109,7 @@ let n_bool self b =
|
|||
|
||||
let[@inline] term_store self = self.tst
|
||||
let[@inline] proof self = self.proof
|
||||
let[@inline] stat self = self.stat
|
||||
|
||||
let allocate_bitfield self ~descr : bitfield =
|
||||
Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr);
|
||||
|
|
@ -851,6 +853,7 @@ let create_ ?(stat = Stat.global) ?(size = `Big) (tst : Term.store)
|
|||
view_as_cc;
|
||||
tst;
|
||||
proof;
|
||||
stat;
|
||||
tbl = T_tbl.create size;
|
||||
signatures_tbl = Sig_tbl.create size;
|
||||
bitgen;
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ type t
|
|||
|
||||
val term_store : t -> Term.store
|
||||
val proof : t -> Proof_trace.t
|
||||
val stat : t -> Stat.t
|
||||
|
||||
val find : t -> e_node -> repr
|
||||
(** Current representative *)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,9 @@ module Make (M : MONOID_PLUGIN_ARG) :
|
|||
module CC = CC
|
||||
open A
|
||||
|
||||
(* plugin's state *)
|
||||
let plugin_st = M.create cc
|
||||
|
||||
(* repr -> value for the class *)
|
||||
let values : M.t Cls_tbl.t = Cls_tbl.create ?size ()
|
||||
|
||||
|
|
@ -62,7 +65,7 @@ module Make (M : MONOID_PLUGIN_ARG) :
|
|||
let on_new_term cc n (t : Term.t) : CC.Handler_action.t list =
|
||||
(*Log.debugf 50 (fun k->k "(@[monoid[%s].on-new-term.try@ %a@])" M.name N.pp n);*)
|
||||
let acts = ref [] in
|
||||
let maybe_m, l = M.of_term cc n t in
|
||||
let maybe_m, l = M.of_term cc plugin_st n t in
|
||||
(match maybe_m with
|
||||
| Some v ->
|
||||
Log.debugf 20 (fun k ->
|
||||
|
|
@ -84,7 +87,7 @@ module Make (M : MONOID_PLUGIN_ARG) :
|
|||
Error.errorf "node %a has bitfield but no value" E_node.pp n_u
|
||||
in
|
||||
|
||||
match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with
|
||||
match M.merge cc plugin_st n_u m_u n_u m_u' (Expl.mk_list []) with
|
||||
| Error (CC.Handler_action.Conflict expl) ->
|
||||
Error.errorf
|
||||
"when merging@ @[for node %a@],@ values %a and %a:@ conflict %a"
|
||||
|
|
@ -118,7 +121,7 @@ module Make (M : MONOID_PLUGIN_ARG) :
|
|||
"(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 \
|
||||
%a@ :val2 %a@])@])"
|
||||
M.name E_node.pp n1 M.pp v1 E_node.pp n2 M.pp v2);
|
||||
(match M.merge cc n1 v1 n2 v2 e_n1_n2 with
|
||||
(match M.merge cc plugin_st n1 v1 n2 v2 e_n1_n2 with
|
||||
| Ok (v', merge_acts) ->
|
||||
acts := merge_acts;
|
||||
Cls_tbl.remove values n2;
|
||||
|
|
@ -140,8 +143,8 @@ module Make (M : MONOID_PLUGIN_ARG) :
|
|||
in
|
||||
Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) iter_all
|
||||
|
||||
(* setup *)
|
||||
let () =
|
||||
(* hook into the CC's events *)
|
||||
Event.on (CC.on_new_term cc) ~f:(fun (_, r, t) -> on_new_term cc r t);
|
||||
Event.on (CC.on_pre_merge2 cc) ~f:(fun (_, ra, rb, expl) ->
|
||||
on_pre_merge cc ra rb expl);
|
||||
|
|
|
|||
|
|
@ -15,12 +15,18 @@ module type MONOID_PLUGIN_ARG = sig
|
|||
|
||||
include Sidekick_sigs.PRINT with type t := t
|
||||
|
||||
type state
|
||||
|
||||
val create : CC.t -> state
|
||||
(** Initialize state from the congruence closure *)
|
||||
|
||||
val name : string
|
||||
(** name of the monoid structure (short) *)
|
||||
|
||||
(* FIXME: for subs, return list of e_nodes, and assume of_term already
|
||||
returned data for them. *)
|
||||
val of_term : CC.t -> E_node.t -> Term.t -> t option * (E_node.t * t) list
|
||||
val of_term :
|
||||
CC.t -> state -> E_node.t -> Term.t -> t option * (E_node.t * t) list
|
||||
(** [of_term n t], where [t] is the Term.t annotating node [n],
|
||||
must return [maybe_m, l], where:
|
||||
|
||||
|
|
@ -34,6 +40,7 @@ module type MONOID_PLUGIN_ARG = sig
|
|||
|
||||
val merge :
|
||||
CC.t ->
|
||||
state ->
|
||||
E_node.t ->
|
||||
t ->
|
||||
E_node.t ->
|
||||
|
|
|
|||
|
|
@ -10,9 +10,21 @@ module type ARG = Intf.ARG
|
|||
module Make (A : ARG) : sig
|
||||
val theory : SMT.theory
|
||||
end = struct
|
||||
type state = { tst: T.store; gensym: Gensym.t }
|
||||
type state = {
|
||||
tst: T.store;
|
||||
gensym: Gensym.t;
|
||||
n_simplify: int Stat.counter;
|
||||
n_clauses: int Stat.counter;
|
||||
}
|
||||
|
||||
let create ~stat tst : state =
|
||||
{
|
||||
tst;
|
||||
gensym = Gensym.create tst;
|
||||
n_simplify = Stat.mk_int stat "th.bool.simplified";
|
||||
n_clauses = Stat.mk_int stat "th.bool.cnf-clauses";
|
||||
}
|
||||
|
||||
let create tst : state = { tst; gensym = Gensym.create tst }
|
||||
let[@inline] not_ tst t = A.mk_bool tst (B_not t)
|
||||
let[@inline] eq tst a b = A.mk_bool tst (B_eq (a, b))
|
||||
|
||||
|
|
@ -42,7 +54,11 @@ end = struct
|
|||
~res:[ Lit.atom (A.mk_bool tst (B_eq (a, b))) ]
|
||||
in
|
||||
|
||||
let[@inline] ret u = Some (u, Iter.of_list !steps) in
|
||||
let[@inline] ret u =
|
||||
Stat.incr self.n_simplify;
|
||||
Some (u, Iter.of_list !steps)
|
||||
in
|
||||
|
||||
(* proof is [t <=> u] *)
|
||||
let ret_bequiv t1 u =
|
||||
(add_step_ @@ mk_step_ @@ fun () -> Proof_rules.lemma_bool_equiv t1 u);
|
||||
|
|
@ -123,7 +139,7 @@ end = struct
|
|||
let[@inline] mk_step_ r = Proof_trace.add_step PA.proof r in
|
||||
|
||||
(* handle boolean equality *)
|
||||
let equiv_ _si ~is_xor ~t t_a t_b : unit =
|
||||
let equiv_ (self : state) _si ~is_xor ~t t_a t_b : unit =
|
||||
let a = PA.mk_lit t_a in
|
||||
let b = PA.mk_lit t_b in
|
||||
let a =
|
||||
|
|
@ -137,23 +153,30 @@ end = struct
|
|||
|
||||
(* proxy => a<=> b,
|
||||
¬proxy => a xor b *)
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ Lit.neg lit; Lit.neg a; b ]
|
||||
(if is_xor then
|
||||
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "xor-e+" [ t ]
|
||||
else
|
||||
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "eq-e" [ t; t_a ]);
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ Lit.neg lit; Lit.neg b; a ]
|
||||
(if is_xor then
|
||||
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "xor-e-" [ t ]
|
||||
else
|
||||
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "eq-e" [ t; t_b ]);
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause [ lit; a; b ]
|
||||
(if is_xor then
|
||||
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "xor-i" [ t; t_a ]
|
||||
else
|
||||
mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "eq-i+" [ t ]);
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ lit; Lit.neg a; Lit.neg b ]
|
||||
(if is_xor then
|
||||
|
|
@ -174,10 +197,13 @@ end = struct
|
|||
List.iter
|
||||
(fun u ->
|
||||
let t_u = Lit.term u in
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ Lit.neg lit; u ]
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "and-e" [ t; t_u ]))
|
||||
subs;
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
(lit :: List.map Lit.neg subs)
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "and-i" [ t ])
|
||||
|
|
@ -189,10 +215,13 @@ end = struct
|
|||
List.iter
|
||||
(fun u ->
|
||||
let t_u = Lit.term u in
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ Lit.neg u; lit ]
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "or-i" [ t; t_u ]))
|
||||
subs;
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause (Lit.neg lit :: subs)
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "or-e" [ t ])
|
||||
| B_imply (a, b) ->
|
||||
|
|
@ -208,29 +237,35 @@ end = struct
|
|||
List.iter
|
||||
(fun u ->
|
||||
let t_u = Lit.term u in
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ Lit.neg u; lit ]
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "imp-i" [ t; t_u ]))
|
||||
subs;
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause (Lit.neg lit :: subs)
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_bool_c "imp-e" [ t ])
|
||||
| B_ite (a, b, c) ->
|
||||
let lit_a = PA.mk_lit a in
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ Lit.neg lit_a; PA.mk_lit (eq self.tst t b) ]
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_ite_true ~ite:t);
|
||||
|
||||
Stat.incr self.n_clauses;
|
||||
PA.add_clause
|
||||
[ lit_a; PA.mk_lit (eq self.tst t c) ]
|
||||
(mk_step_ @@ fun () -> Proof_rules.lemma_ite_false ~ite:t)
|
||||
| B_eq _ | B_neq _ -> ()
|
||||
| B_equiv (a, b) -> equiv_ si ~t ~is_xor:false a b
|
||||
| B_xor (a, b) -> equiv_ si ~t ~is_xor:true a b
|
||||
| B_equiv (a, b) -> equiv_ self si ~t ~is_xor:false a b
|
||||
| B_xor (a, b) -> equiv_ self si ~t ~is_xor:true a b
|
||||
| B_atom _ -> ());
|
||||
()
|
||||
|
||||
let create_and_setup si =
|
||||
Log.debug 2 "(th-bool.setup)";
|
||||
let st = create (SI.tst si) in
|
||||
let st = create ~stat:(SI.stats si) (SI.tst si) in
|
||||
SI.add_simplifier si (simplify st);
|
||||
SI.on_preprocess si (cnf st);
|
||||
st
|
||||
|
|
|
|||
|
|
@ -23,18 +23,26 @@ end = struct
|
|||
|
||||
let name = name
|
||||
|
||||
type state = { n_merges: int Stat.counter; n_conflict: int Stat.counter }
|
||||
|
||||
let create cc : state =
|
||||
{
|
||||
n_merges = Stat.mk_int (CC.stat cc) "th.cstor.merges";
|
||||
n_conflict = Stat.mk_int (CC.stat cc) "th.cstor.conflicts";
|
||||
}
|
||||
|
||||
let pp out (v : t) =
|
||||
Fmt.fprintf out "(@[cstor %a@ :term %a@])" Const.pp v.cstor T.pp_debug v.t
|
||||
|
||||
(* attach data to constructor terms *)
|
||||
let of_term cc n (t : T.t) : _ option * _ =
|
||||
let of_term cc _ n (t : T.t) : _ option * _ =
|
||||
match A.view_as_cstor t with
|
||||
| T_cstor (cstor, args) ->
|
||||
let args = CCArray.map (CC.add_term cc) args in
|
||||
Some { n; t; cstor; args }, []
|
||||
| _ -> None, []
|
||||
|
||||
let merge _cc n1 v1 n2 v2 e_n1_n2 : _ result =
|
||||
let merge _cc state n1 v1 n2 v2 e_n1_n2 : _ result =
|
||||
Log.debugf 5 (fun k ->
|
||||
k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])" name
|
||||
E_node.pp n1 T.pp_debug v1.t E_node.pp n2 T.pp_debug v2.t);
|
||||
|
|
@ -50,14 +58,18 @@ end = struct
|
|||
assert (CCArray.length v1.args = CCArray.length v2.args);
|
||||
let acts =
|
||||
CCArray.map2
|
||||
(fun u1 u2 -> CC.Handler_action.Act_merge (u1, u2, expl))
|
||||
(fun u1 u2 ->
|
||||
Stat.incr state.n_merges;
|
||||
CC.Handler_action.Act_merge (u1, u2, expl))
|
||||
v1.args v2.args
|
||||
|> Array.to_list
|
||||
in
|
||||
Ok (v1, acts)
|
||||
) else
|
||||
) else (
|
||||
(* different function: disjointness *)
|
||||
Stat.incr state.n_conflict;
|
||||
Error (CC.Handler_action.Conflict expl)
|
||||
)
|
||||
end
|
||||
|
||||
module ST = Sidekick_cc.Plugin.Make (Monoid)
|
||||
|
|
|
|||
|
|
@ -150,6 +150,14 @@ end = struct
|
|||
module Monoid_cstor = struct
|
||||
let name = "th-data.cstor"
|
||||
|
||||
type state = { n_merges: int Stat.counter; n_conflict: int Stat.counter }
|
||||
|
||||
let create cc : state =
|
||||
{
|
||||
n_merges = Stat.mk_int (CC.stat cc) "th.data.cstor-merges";
|
||||
n_conflict = Stat.mk_int (CC.stat cc) "th.data.cstor-conflicts";
|
||||
}
|
||||
|
||||
(* associate to each class a unique constructor term in the class (if any) *)
|
||||
type t = { c_n: E_node.t; c_cstor: A.Cstor.t; c_args: E_node.t list }
|
||||
|
||||
|
|
@ -158,14 +166,14 @@ end = struct
|
|||
A.Cstor.pp v.c_cstor E_node.pp v.c_n (Util.pp_list E_node.pp) v.c_args
|
||||
|
||||
(* attach data to constructor terms *)
|
||||
let of_term cc n (t : Term.t) : _ option * _ list =
|
||||
let of_term cc _ n (t : Term.t) : _ option * _ list =
|
||||
match A.view_as_data t with
|
||||
| T_cstor (cstor, args) ->
|
||||
let args = List.map (CC.add_term cc) args in
|
||||
Some { c_n = n; c_cstor = cstor; c_args = args }, []
|
||||
| _ -> None, []
|
||||
|
||||
let merge cc n1 c1 n2 c2 e_n1_n2 : _ result =
|
||||
let merge cc state n1 c1 n2 c2 e_n1_n2 : _ result =
|
||||
Log.debugf 5 (fun k ->
|
||||
k "(@[%s.merge@ (@[:c1 %a@ %a@])@ (@[:c2 %a@ %a@])@])" name E_node.pp
|
||||
n1 pp c1 E_node.pp n2 pp c2);
|
||||
|
|
@ -194,8 +202,10 @@ end = struct
|
|||
let acts = ref [] in
|
||||
CCList.iteri2
|
||||
(fun i u1 u2 ->
|
||||
Stat.incr state.n_merges;
|
||||
acts := CC.Handler_action.Act_merge (u1, u2, expl_merge i) :: !acts)
|
||||
c1.c_args c2.c_args;
|
||||
|
||||
Ok (c1, !acts)
|
||||
) else (
|
||||
(* different function: disjointness *)
|
||||
|
|
@ -205,6 +215,7 @@ end = struct
|
|||
@@ fun () -> Proof_rules.lemma_cstor_distinct t1 t2
|
||||
in
|
||||
|
||||
Stat.incr state.n_conflict;
|
||||
Error (CC.Handler_action.Conflict expl)
|
||||
)
|
||||
end
|
||||
|
|
@ -214,6 +225,10 @@ end = struct
|
|||
module Monoid_parents = struct
|
||||
let name = "th-data.parents"
|
||||
|
||||
type state = unit
|
||||
|
||||
let create _ = ()
|
||||
|
||||
type select = {
|
||||
sel_n: E_node.t;
|
||||
sel_cstor: A.Cstor.t;
|
||||
|
|
@ -243,7 +258,7 @@ end = struct
|
|||
v.parent_is_a
|
||||
|
||||
(* attach data to constructor terms *)
|
||||
let of_term cc n (t : Term.t) : _ option * _ list =
|
||||
let of_term cc () n (t : Term.t) : _ option * _ list =
|
||||
match A.view_as_data t with
|
||||
| T_select (c, i, u) ->
|
||||
let u = CC.add_term cc u in
|
||||
|
|
@ -266,7 +281,7 @@ end = struct
|
|||
None, [ u, m_sel ]
|
||||
| T_cstor _ | T_other _ -> None, []
|
||||
|
||||
let merge _cc n1 v1 n2 v2 _e : _ result =
|
||||
let merge _cc () n1 v1 n2 v2 _e : _ result =
|
||||
Log.debugf 5 (fun k ->
|
||||
k "(@[%s.merge@ @[:c1 %a@ :v %a@]@ @[:c2 %a@ :v %a@]@])" name
|
||||
E_node.pp n1 pp v1 E_node.pp n2 pp v2);
|
||||
|
|
@ -795,7 +810,7 @@ end = struct
|
|||
case_split_done = Term.Tbl.create 16;
|
||||
cards = Card.create ();
|
||||
stat_acycl_conflict =
|
||||
Stat.mk_int (SI.stats solver) "data.acycl.conflict";
|
||||
Stat.mk_int (SI.stats solver) "th.data.acycl.conflict";
|
||||
}
|
||||
in
|
||||
Log.debugf 1 (fun k -> k "(setup :%s)" name);
|
||||
|
|
|
|||
|
|
@ -78,6 +78,10 @@ module Make (A : ARG) = (* : S with module A = A *) struct
|
|||
module Monoid_exprs = struct
|
||||
let name = "lra.const"
|
||||
|
||||
type state = unit
|
||||
|
||||
let create _ = ()
|
||||
|
||||
type single = { le: LE.t; n: E_node.t }
|
||||
type t = single list
|
||||
|
||||
|
|
@ -89,7 +93,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
|
|||
| [ x ] -> pp_single out x
|
||||
| _ -> Fmt.fprintf out "(@[exprs@ %a@])" (Util.pp_list pp_single) self
|
||||
|
||||
let of_term _cc n t =
|
||||
let of_term _cc () n t =
|
||||
match A.view_as_lra t with
|
||||
| LRA_const _ | LRA_op _ | LRA_mult _ ->
|
||||
let le = as_linexp t in
|
||||
|
|
@ -100,7 +104,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
|
|||
|
||||
(* merge lists. If two linear expressions equal up to a constant are
|
||||
merged, conflict. *)
|
||||
let merge _cc n1 l1 n2 l2 expl_12 : _ result =
|
||||
let merge _cc () n1 l1 n2 l2 expl_12 : _ result =
|
||||
try
|
||||
let i = Iter.(product (of_list l1) (of_list l2)) in
|
||||
i (fun (s1, s2) ->
|
||||
|
|
@ -138,7 +142,8 @@ module Make (A : ARG) = (* : S with module A = A *) struct
|
|||
mutable last_res: SimpSolver.result option;
|
||||
}
|
||||
|
||||
let create ?(stat = Stat.create ()) (si : SI.t) : state =
|
||||
let create (si : SI.t) : state =
|
||||
let stat = SI.stats si in
|
||||
let proof = SI.proof si in
|
||||
let tst = SI.tst si in
|
||||
{
|
||||
|
|
@ -692,8 +697,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
|
|||
|
||||
let create_and_setup si =
|
||||
Log.debug 2 "(th-lra.setup)";
|
||||
let stat = SI.stats si in
|
||||
let st = create ~stat si in
|
||||
let st = create si in
|
||||
SMT.Registry.set (SI.registry si) k_state st;
|
||||
SI.add_simplifier si (simplify st);
|
||||
SI.on_preprocess si (preproc_lra st);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue