diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 997f66fe..cabc3e6a 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -213,6 +213,7 @@ module Make(CC_A: ARG) = struct mutable on_merge: ev_on_merge list; mutable on_new_term: ev_on_new_term list; mutable on_conflict: ev_on_conflict list; + mutable on_propagate: ev_on_propagate list; mutable ps_lits: lit list; (* TODO: thread it around instead? *) (* proof state *) ps_queue: (node*node) Vec.t; @@ -233,6 +234,7 @@ module Make(CC_A: ARG) = struct and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit and ev_on_new_term = t -> N.t -> term -> unit and ev_on_conflict = t -> lit list -> unit + and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit let[@inline] size_ (r:repr) = r.n_size let[@inline] true_ cc = Lazy.force cc.true_ @@ -710,6 +712,7 @@ module Make(CC_A: ARG) = struct let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in fun () -> Lazy.force e in + List.iter (fun f -> f cc lit reason) cc.on_propagate; Stat.incr cc.count_props; CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default | _ -> ()) @@ -793,6 +796,7 @@ module Make(CC_A: ARG) = struct basically, just have [n] point to true/false and thus acquire the corresponding value, so its superterms (like [ite]) can evaluate properly *) + (* TODO: use oriented merge (force direction [n -> rhs]) *) merge_classes cc n rhs (Expl.mk_lit lit) end @@ -816,9 +820,11 @@ module Make(CC_A: ARG) = struct let on_merge cc f = cc.on_merge <- f :: cc.on_merge let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict + let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate let create ?(stat=Stat.global) - ?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(size=`Big) + ?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[]) + ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in let rec cc = { @@ -828,6 +834,7 @@ module Make(CC_A: ARG) = struct on_merge; on_new_term; on_conflict; + on_propagate; pending=Vec.create(); combine=Vec.create(); ps_lits=[]; diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 346ec14e..044b03eb 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -221,12 +221,14 @@ module type CC_S = sig type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit type ev_on_new_term = t -> N.t -> term -> unit type ev_on_conflict = t -> lit list -> unit + type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit val create : ?stat:Stat.t -> ?on_merge:ev_on_merge list -> ?on_new_term:ev_on_new_term list -> ?on_conflict:ev_on_conflict list -> + ?on_propagate:ev_on_propagate list -> ?size:[`Small | `Big] -> term_state -> t @@ -242,6 +244,9 @@ module type CC_S = sig val on_conflict : t -> ev_on_conflict -> unit (** Called when the congruence closure finds a conflict *) + val on_propagate : t -> ev_on_propagate -> unit + (** Called when the congruence closure propagates a literal *) + val set_as_lit : t -> N.t -> lit -> unit (** map the given node to a literal. *) @@ -412,6 +417,9 @@ module type SOLVER_INTERNAL = sig val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit (** Callback called on every CC conflict *) + val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit + (** Callback called on every CC propagation *) + val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit (** Register callbacked to be called with the slice of literals newly added on the trail. diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index ed7841bb..1f85faac 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -210,6 +210,7 @@ module Make(A : ARG) let on_cc_new_term self f = CC.on_new_term (cc self) f let on_cc_merge self f = CC.on_merge (cc self) f let on_cc_conflict self f = CC.on_conflict (cc self) f + let on_cc_propagate self f = CC.on_propagate (cc self) f let cc_add_term self t = CC.add_term (cc self) t let cc_find self n = CC.find (cc self) n