split on_merge into two events: pre and post merge

This commit is contained in:
Alexander Bentkamp 2019-08-21 10:31:43 +02:00 committed by Simon Cruanes
parent cde983df86
commit 7fe6f07c0b
4 changed files with 33 additions and 16 deletions

View file

@ -247,7 +247,8 @@ module Make(CC_A: ARG) = struct
pending: node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
mutable on_merge: ev_on_merge list;
mutable on_pre_merge: ev_on_pre_merge list;
mutable on_post_merge: ev_on_post_merge list;
mutable on_new_term: ev_on_new_term list;
mutable on_conflict: ev_on_conflict list;
mutable on_propagate: ev_on_propagate list;
@ -266,7 +267,8 @@ module Make(CC_A: ARG) = struct
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> lit list -> unit
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
@ -638,11 +640,11 @@ module Make(CC_A: ARG) = struct
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call [on_merge] functions, and merge theory data items *)
(* call [on_pre_merge] functions, and merge theory data items *)
begin
(* explanation is [a=ra & e_ab & b=rb] *)
let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in
List.iter (fun f -> f cc acts r_into r_from expl) cc.on_merge;
List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge;
end;
begin
(* parents might have a different signature, check for collisions *)
@ -695,6 +697,10 @@ module Make(CC_A: ARG) = struct
| _ -> assert false);
a.n_expl <- FL_some {next=b; expl=e_ab};
end;
(* call [on_post_merge] *)
begin
List.iter (fun f -> f cc acts r_into r_from) cc.on_post_merge;
end;
)
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
@ -797,13 +803,14 @@ module Make(CC_A: ARG) = struct
let[@inline] merge_t cc t1 t2 expl =
merge cc (add_term cc t1) (add_term cc t2) expl
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
let on_pre_merge cc f = cc.on_pre_merge <- f :: cc.on_pre_merge
let on_post_merge cc f = cc.on_post_merge <- f :: cc.on_post_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate
let create ?(stat=Stat.global)
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
?(size=`Big)
(tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
@ -814,7 +821,8 @@ module Make(CC_A: ARG) = struct
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
bitgen;
on_merge;
on_pre_merge;
on_post_merge;
on_new_term;
on_conflict;
on_propagate;

View file

@ -210,14 +210,16 @@ module type CC_S = sig
(** Add the term to the congruence closure, if not present already.
Will be backtracked. *)
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> lit list -> unit
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
val create :
?stat:Stat.t ->
?on_merge:ev_on_merge list ->
?on_pre_merge:ev_on_pre_merge list ->
?on_post_merge:ev_on_post_merge list ->
?on_new_term:ev_on_new_term list ->
?on_conflict:ev_on_conflict list ->
?on_propagate:ev_on_propagate list ->
@ -231,7 +233,10 @@ module type CC_S = sig
See {!N.bitfield}. *)
(* TODO: remove? this is managed by the solver anyway? *)
val on_merge : t -> ev_on_merge -> unit
val on_pre_merge : t -> ev_on_pre_merge -> unit
(** Add a function to be called when two classes are merged *)
val on_post_merge : t -> ev_on_post_merge -> unit
(** Add a function to be called when two classes are merged *)
val on_new_term : t -> ev_on_new_term -> unit
@ -418,8 +423,11 @@ module type SOLVER_INTERNAL = sig
(** Add/retrieve congruence closure node for this term.
To be used in theories *)
val on_cc_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> CC.Expl.t -> unit) -> unit
(** Callback for when two classes containing data for this key are merged *)
val on_cc_pre_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> CC.Expl.t -> unit) -> unit
(** Callback for when two classes containing data for this key are merged (called before) *)
val on_cc_post_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> unit) -> unit
(** Callback for when two classes containing data for this key are merged (called after)*)
val on_cc_new_term : t -> (CC.t -> CC.N.t -> term -> unit) -> unit
(** Callback to add data on terms when they are added to the congruence

View file

@ -251,7 +251,8 @@ module Make(A : ARG)
let on_final_check self f = self.on_final_check <- f :: self.on_final_check
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
let on_cc_new_term self f = CC.on_new_term (cc self) f
let on_cc_merge self f = CC.on_merge (cc self) f
let on_cc_pre_merge self f = CC.on_pre_merge (cc self) f
let on_cc_post_merge self f = CC.on_post_merge (cc self) f
let on_cc_conflict self f = CC.on_conflict (cc self) f
let on_cc_propagate self f = CC.on_propagate (cc self) f

View file

@ -40,9 +40,9 @@ module Make(A : ARG) : S with module A = A = struct
(* TODO: also allocate a bit in CC to filter out quickly classes without cstors *)
}
let on_merge (solver:SI.t) n1 tc1 n2 tc2 e_n1_n2 : unit =
let on_pre_merge (solver:SI.t) n1 tc1 n2 tc2 e_n1_n2 : unit =
Log.debugf 5
(fun k->k "(@[th-cstor.on_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
(fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
N.pp n1 T.pp tc1.t N.pp n2 T.pp tc2.t);
let expl = Expl.mk_list [e_n1_n2; Expl.mk_merge n1 tc1.n; Expl.mk_merge n2 tc2.n] in
match A.view_as_cstor tc1.t, A.view_as_cstor tc2.t with
@ -71,7 +71,7 @@ module Make(A : ARG) : S with module A = A = struct
cstors=N_tbl.create 32;
} in
(* TODO
SI.on_cc_merge solver on_merge;
SI.on_cc_pre_merge solver on_pre_merge;
SI.on_cc_new_term solver on_new_term;
*)
self