feat(cc): callback on propagations

This commit is contained in:
Simon Cruanes 2019-06-07 14:54:24 -05:00
parent 357dc73426
commit ef1110925f
3 changed files with 17 additions and 1 deletions

View file

@ -213,6 +213,7 @@ module Make(CC_A: ARG) = struct
mutable on_merge: ev_on_merge list;
mutable on_new_term: ev_on_new_term list;
mutable on_conflict: ev_on_conflict list;
mutable on_propagate: ev_on_propagate list;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
@ -233,6 +234,7 @@ module Make(CC_A: ARG) = struct
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> lit list -> unit
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
@ -710,6 +712,7 @@ module Make(CC_A: ARG) = struct
let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in
fun () -> Lazy.force e
in
List.iter (fun f -> f cc lit reason) cc.on_propagate;
Stat.incr cc.count_props;
CC_A.Actions.propagate acts lit ~reason CC_A.A.Proof.default
| _ -> ())
@ -793,6 +796,7 @@ module Make(CC_A: ARG) = struct
basically, just have [n] point to true/false and thus acquire
the corresponding value, so its superterms (like [ite]) can evaluate
properly *)
(* TODO: use oriented merge (force direction [n -> rhs]) *)
merge_classes cc n rhs (Expl.mk_lit lit)
end
@ -816,9 +820,11 @@ module Make(CC_A: ARG) = struct
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate
let create ?(stat=Stat.global)
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(size=`Big)
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
?(size=`Big)
(tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = {
@ -828,6 +834,7 @@ module Make(CC_A: ARG) = struct
on_merge;
on_new_term;
on_conflict;
on_propagate;
pending=Vec.create();
combine=Vec.create();
ps_lits=[];

View file

@ -221,12 +221,14 @@ module type CC_S = sig
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> lit list -> unit
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
val create :
?stat:Stat.t ->
?on_merge:ev_on_merge list ->
?on_new_term:ev_on_new_term list ->
?on_conflict:ev_on_conflict list ->
?on_propagate:ev_on_propagate list ->
?size:[`Small | `Big] ->
term_state ->
t
@ -242,6 +244,9 @@ module type CC_S = sig
val on_conflict : t -> ev_on_conflict -> unit
(** Called when the congruence closure finds a conflict *)
val on_propagate : t -> ev_on_propagate -> unit
(** Called when the congruence closure propagates a literal *)
val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *)
@ -412,6 +417,9 @@ module type SOLVER_INTERNAL = sig
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit
(** Callback called on every CC conflict *)
val on_cc_propagate : t -> (CC.t -> lit -> (unit -> lit list) -> unit) -> unit
(** Callback called on every CC propagation *)
val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit
(** Register callbacked to be called with the slice of literals
newly added on the trail.

View file

@ -210,6 +210,7 @@ module Make(A : ARG)
let on_cc_new_term self f = CC.on_new_term (cc self) f
let on_cc_merge self f = CC.on_merge (cc self) f
let on_cc_conflict self f = CC.on_conflict (cc self) f
let on_cc_propagate self f = CC.on_propagate (cc self) f
let cc_add_term self t = CC.add_term (cc self) t
let cc_find self n = CC.find (cc self) n