mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-09 04:35:35 -05:00
lra: refactor theory combination (have CC tell us what terms are subterms)
This commit is contained in:
parent
a7afce3af4
commit
a8e2764e92
4 changed files with 27 additions and 10 deletions
|
|
@ -239,8 +239,6 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
| None, _ ->
|
||||
(* non trivial linexp, give it a fresh name in the simplex *)
|
||||
let proxy = var_encoding_comb self ~pre:"_le" le_comb in
|
||||
T.Tbl.replace self.needs_th_combination proxy ();
|
||||
|
||||
let new_t =
|
||||
match pred with
|
||||
| Eq | Neq -> assert false (* unreachable *)
|
||||
|
|
@ -251,8 +249,6 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
in
|
||||
|
||||
Log.debugf 10 (fun k->k "lra.preprocess:@ %a@ :into %a" T.pp t T.pp new_t);
|
||||
|
||||
T.Tbl.add self.needs_th_combination new_t ();
|
||||
Some new_t
|
||||
|
||||
| Some (coeff, v), pred ->
|
||||
|
|
@ -271,8 +267,6 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
|
||||
let new_t = A.mk_lra tst (LRA_simplex_pred (v, op, q)) in
|
||||
Log.debugf 10 (fun k->k "lra.preprocess@ :%a@ :into %a" T.pp t T.pp new_t);
|
||||
|
||||
T.Tbl.add self.needs_th_combination new_t ();
|
||||
Some new_t
|
||||
end
|
||||
|
||||
|
|
@ -314,9 +308,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
Some proxy
|
||||
)
|
||||
|
||||
| LRA_other t when A.has_ty_real t ->
|
||||
T.Tbl.replace self.needs_th_combination t ();
|
||||
None
|
||||
| LRA_other t when A.has_ty_real t -> None
|
||||
| LRA_const _ | LRA_simplex_pred _ | LRA_simplex_var _ | LRA_other _ -> None
|
||||
|
||||
module Q_map = CCMap.Make(Q)
|
||||
|
|
@ -478,6 +470,14 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
do_th_combination self si acts model;
|
||||
()
|
||||
|
||||
(* look for subterms of type Real, for they will need theory combination *)
|
||||
let on_subterm (self:state) _ (t:T.t) : unit =
|
||||
if A.has_ty_real t &&
|
||||
not (T.Tbl.mem self.needs_th_combination t) then (
|
||||
Log.debugf 5 (fun k->k "(@[lra.needs-th-combination@ %a@])" T.pp t);
|
||||
T.Tbl.add self.needs_th_combination t ()
|
||||
)
|
||||
|
||||
let create_and_setup si =
|
||||
Log.debug 2 "(th-lra.setup)";
|
||||
let stat = SI.stats si in
|
||||
|
|
@ -486,6 +486,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
SI.add_preprocess si (preproc_lra st);
|
||||
SI.on_final_check si (final_check_ st);
|
||||
SI.on_partial_check si (partial_check_ st);
|
||||
SI.on_cc_is_subterm si (on_subterm st);
|
||||
SI.on_cc_post_merge si
|
||||
(fun _ _ n1 n2 ->
|
||||
if A.has_ty_real (N.term n1) then (
|
||||
|
|
|
|||
|
|
@ -260,6 +260,7 @@ module Make (A: CC_ARG)
|
|||
mutable on_new_term: ev_on_new_term list;
|
||||
mutable on_conflict: ev_on_conflict list;
|
||||
mutable on_propagate: ev_on_propagate list;
|
||||
mutable on_is_subterm : ev_on_is_subterm list;
|
||||
mutable new_merges: bool;
|
||||
bitgen: Bits.bitfield_gen;
|
||||
field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *)
|
||||
|
|
@ -281,6 +282,7 @@ module Make (A: CC_ARG)
|
|||
and ev_on_new_term = t -> N.t -> term -> unit
|
||||
and ev_on_conflict = t -> th:bool -> lit list -> unit
|
||||
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
|
||||
and ev_on_is_subterm = N.t -> term -> unit
|
||||
|
||||
let[@inline] size_ (r:repr) = r.n_size
|
||||
let[@inline] n_true cc = Lazy.force cc.true_
|
||||
|
|
@ -535,6 +537,7 @@ module Make (A: CC_ARG)
|
|||
on_backtrack self (fun () -> sub_r.n_parents <- old_parents);
|
||||
sub_r.n_parents <- Bag.cons n sub_r.n_parents;
|
||||
end;
|
||||
List.iter (fun f -> f sub u) self.on_is_subterm;
|
||||
sub
|
||||
in
|
||||
let[@inline] return x = Some x in
|
||||
|
|
@ -847,9 +850,11 @@ module Make (A: CC_ARG)
|
|||
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
|
||||
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
|
||||
let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate
|
||||
let on_is_subterm cc f = cc.on_is_subterm <- f :: cc.on_is_subterm
|
||||
|
||||
let create ?(stat=Stat.global)
|
||||
?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
|
||||
?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[])
|
||||
?(on_conflict=[]) ?(on_propagate=[]) ?(on_is_subterm=[])
|
||||
?(size=`Big)
|
||||
(tst:term_state) : t =
|
||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||
|
|
@ -865,6 +870,7 @@ module Make (A: CC_ARG)
|
|||
on_new_term;
|
||||
on_conflict;
|
||||
on_propagate;
|
||||
on_is_subterm;
|
||||
pending=Vec.create();
|
||||
combine=Vec.create();
|
||||
undo=Backtrack_stack.create();
|
||||
|
|
|
|||
|
|
@ -224,6 +224,7 @@ module type CC_S = sig
|
|||
type ev_on_new_term = t -> N.t -> term -> unit
|
||||
type ev_on_conflict = t -> th:bool -> lit list -> unit
|
||||
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
|
||||
type ev_on_is_subterm = N.t -> term -> unit
|
||||
|
||||
val create :
|
||||
?stat:Stat.t ->
|
||||
|
|
@ -232,6 +233,7 @@ module type CC_S = sig
|
|||
?on_new_term:ev_on_new_term list ->
|
||||
?on_conflict:ev_on_conflict list ->
|
||||
?on_propagate:ev_on_propagate list ->
|
||||
?on_is_subterm:ev_on_is_subterm list ->
|
||||
?size:[`Small | `Big] ->
|
||||
term_state ->
|
||||
t
|
||||
|
|
@ -261,6 +263,9 @@ module type CC_S = sig
|
|||
val on_propagate : t -> ev_on_propagate -> unit
|
||||
(** Called when the congruence closure propagates a literal *)
|
||||
|
||||
val on_is_subterm : t -> ev_on_is_subterm -> unit
|
||||
(** Called on terms that are subterms of function symbols *)
|
||||
|
||||
val set_as_lit : t -> N.t -> lit -> unit
|
||||
(** map the given node to a literal. *)
|
||||
|
||||
|
|
@ -475,6 +480,10 @@ module type SOLVER_INTERNAL = sig
|
|||
(** Callback to add data on terms when they are added to the congruence
|
||||
closure *)
|
||||
|
||||
val on_cc_is_subterm : t -> (CC.N.t -> term -> unit) -> unit
|
||||
(** Callback for when a term is a subterm of another term in the
|
||||
congruence closure *)
|
||||
|
||||
val on_cc_conflict : t -> (CC.t -> th:bool -> lit list -> unit) -> unit
|
||||
(** Callback called on every CC conflict *)
|
||||
|
||||
|
|
|
|||
|
|
@ -280,6 +280,7 @@ module Make(A : ARG)
|
|||
let on_cc_post_merge self f = CC.on_post_merge (cc self) f
|
||||
let on_cc_conflict self f = CC.on_conflict (cc self) f
|
||||
let on_cc_propagate self f = CC.on_propagate (cc self) f
|
||||
let on_cc_is_subterm self f = CC.on_is_subterm (cc self) f
|
||||
|
||||
let cc_add_term self t = CC.add_term (cc self) t
|
||||
let cc_mem_term self t = CC.mem_term (cc self) t
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue