mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-10 21:24:06 -05:00
feat(cc): callback on propagations
This commit is contained in:
parent
357dc73426
commit
ef1110925f
3 changed files with 17 additions and 1 deletions
|
|
@ -213,6 +213,7 @@ module Make(CC_A: ARG) = struct
|
||||||
mutable on_merge: ev_on_merge list;
|
mutable on_merge: ev_on_merge list;
|
||||||
mutable on_new_term: ev_on_new_term list;
|
mutable on_new_term: ev_on_new_term list;
|
||||||
mutable on_conflict: ev_on_conflict list;
|
mutable on_conflict: ev_on_conflict list;
|
||||||
|
mutable on_propagate: ev_on_propagate list;
|
||||||
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
|
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
|
||||||
(* proof state *)
|
(* proof state *)
|
||||||
ps_queue: (node*node) Vec.t;
|
ps_queue: (node*node) Vec.t;
|
||||||
|
|
@ -233,6 +234,7 @@ module Make(CC_A: ARG) = struct
|
||||||
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
|
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
|
||||||
and ev_on_new_term = t -> N.t -> term -> unit
|
and ev_on_new_term = t -> N.t -> term -> unit
|
||||||
and ev_on_conflict = t -> lit list -> unit
|
and ev_on_conflict = t -> lit list -> unit
|
||||||
|
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
|
||||||
|
|
||||||
let[@inline] size_ (r:repr) = r.n_size
|
let[@inline] size_ (r:repr) = r.n_size
|
||||||
let[@inline] true_ cc = Lazy.force cc.true_
|
let[@inline] true_ cc = Lazy.force cc.true_
|
||||||
|
|
@ -710,6 +712,7 @@ module Make(CC_A: ARG) = struct
|
||||||
let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in
|
let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in
|
||||||
fun () -> Lazy.force e
|
fun () -> Lazy.force e
|
||||||
in
|
in
|
||||||
|
List.iter (fun f -> f cc lit reason) cc.on_propagate;
|
||||||
Stat.incr cc.count_props;
|
Stat.incr cc.count_props;
|
||||||
CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default
|
CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default
|
||||||
| _ -> ())
|
| _ -> ())
|
||||||
|
|
@ -793,6 +796,7 @@ module Make(CC_A: ARG) = struct
|
||||||
basically, just have [n] point to true/false and thus acquire
|
basically, just have [n] point to true/false and thus acquire
|
||||||
the corresponding value, so its superterms (like [ite]) can evaluate
|
the corresponding value, so its superterms (like [ite]) can evaluate
|
||||||
properly *)
|
properly *)
|
||||||
|
(* TODO: use oriented merge (force direction [n -> rhs]) *)
|
||||||
merge_classes cc n rhs (Expl.mk_lit lit)
|
merge_classes cc n rhs (Expl.mk_lit lit)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
@ -816,9 +820,11 @@ module Make(CC_A: ARG) = struct
|
||||||
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
|
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
|
||||||
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
|
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
|
||||||
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
|
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
|
||||||
|
let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate
|
||||||
|
|
||||||
let create ?(stat=Stat.global)
|
let create ?(stat=Stat.global)
|
||||||
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(size=`Big)
|
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
|
||||||
|
?(size=`Big)
|
||||||
(tst:term_state) : t =
|
(tst:term_state) : t =
|
||||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||||
let rec cc = {
|
let rec cc = {
|
||||||
|
|
@ -828,6 +834,7 @@ module Make(CC_A: ARG) = struct
|
||||||
on_merge;
|
on_merge;
|
||||||
on_new_term;
|
on_new_term;
|
||||||
on_conflict;
|
on_conflict;
|
||||||
|
on_propagate;
|
||||||
pending=Vec.create();
|
pending=Vec.create();
|
||||||
combine=Vec.create();
|
combine=Vec.create();
|
||||||
ps_lits=[];
|
ps_lits=[];
|
||||||
|
|
|
||||||
|
|
@ -221,12 +221,14 @@ module type CC_S = sig
|
||||||
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
|
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
|
||||||
type ev_on_new_term = t -> N.t -> term -> unit
|
type ev_on_new_term = t -> N.t -> term -> unit
|
||||||
type ev_on_conflict = t -> lit list -> unit
|
type ev_on_conflict = t -> lit list -> unit
|
||||||
|
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
|
||||||
|
|
||||||
val create :
|
val create :
|
||||||
?stat:Stat.t ->
|
?stat:Stat.t ->
|
||||||
?on_merge:ev_on_merge list ->
|
?on_merge:ev_on_merge list ->
|
||||||
?on_new_term:ev_on_new_term list ->
|
?on_new_term:ev_on_new_term list ->
|
||||||
?on_conflict:ev_on_conflict list ->
|
?on_conflict:ev_on_conflict list ->
|
||||||
|
?on_propagate:ev_on_propagate list ->
|
||||||
?size:[`Small | `Big] ->
|
?size:[`Small | `Big] ->
|
||||||
term_state ->
|
term_state ->
|
||||||
t
|
t
|
||||||
|
|
@ -242,6 +244,9 @@ module type CC_S = sig
|
||||||
val on_conflict : t -> ev_on_conflict -> unit
|
val on_conflict : t -> ev_on_conflict -> unit
|
||||||
(** Called when the congruence closure finds a conflict *)
|
(** Called when the congruence closure finds a conflict *)
|
||||||
|
|
||||||
|
val on_propagate : t -> ev_on_propagate -> unit
|
||||||
|
(** Called when the congruence closure propagates a literal *)
|
||||||
|
|
||||||
val set_as_lit : t -> N.t -> lit -> unit
|
val set_as_lit : t -> N.t -> lit -> unit
|
||||||
(** map the given node to a literal. *)
|
(** map the given node to a literal. *)
|
||||||
|
|
||||||
|
|
@ -412,6 +417,9 @@ module type SOLVER_INTERNAL = sig
|
||||||
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit
|
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit
|
||||||
(** Callback called on every CC conflict *)
|
(** Callback called on every CC conflict *)
|
||||||
|
|
||||||
|
val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit
|
||||||
|
(** Callback called on every CC propagation *)
|
||||||
|
|
||||||
val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit
|
val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit
|
||||||
(** Register callbacked to be called with the slice of literals
|
(** Register callbacked to be called with the slice of literals
|
||||||
newly added on the trail.
|
newly added on the trail.
|
||||||
|
|
|
||||||
|
|
@ -210,6 +210,7 @@ module Make(A : ARG)
|
||||||
let on_cc_new_term self f = CC.on_new_term (cc self) f
|
let on_cc_new_term self f = CC.on_new_term (cc self) f
|
||||||
let on_cc_merge self f = CC.on_merge (cc self) f
|
let on_cc_merge self f = CC.on_merge (cc self) f
|
||||||
let on_cc_conflict self f = CC.on_conflict (cc self) f
|
let on_cc_conflict self f = CC.on_conflict (cc self) f
|
||||||
|
let on_cc_propagate self f = CC.on_propagate (cc self) f
|
||||||
|
|
||||||
let cc_add_term self t = CC.add_term (cc self) t
|
let cc_add_term self t = CC.add_term (cc self) t
|
||||||
let cc_find self n = CC.find (cc self) n
|
let cc_find self n = CC.find (cc self) n
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue