split on_merge into two events: pre and post merge

This commit is contained in:
Alexander Bentkamp 2019-08-21 10:31:43 +02:00 committed by Simon Cruanes
parent cde983df86
commit 7fe6f07c0b
4 changed files with 33 additions and 16 deletions

View file

@ -247,7 +247,8 @@ module Make(CC_A: ARG) = struct
pending: node Vec.t; pending: node Vec.t;
combine: combine_task Vec.t; combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t; undo: (unit -> unit) Backtrack_stack.t;
mutable on_merge: ev_on_merge list; mutable on_pre_merge: ev_on_pre_merge list;
mutable on_post_merge: ev_on_post_merge list;
mutable on_new_term: ev_on_new_term list; mutable on_new_term: ev_on_new_term list;
mutable on_conflict: ev_on_conflict list; mutable on_conflict: ev_on_conflict list;
mutable on_propagate: ev_on_propagate list; mutable on_propagate: ev_on_propagate list;
@ -266,7 +267,8 @@ module Make(CC_A: ARG) = struct
several times. several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> lit list -> unit and ev_on_conflict = t -> lit list -> unit
and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
@ -638,11 +640,11 @@ module Make(CC_A: ARG) = struct
merge_bool rb b ra a; merge_bool rb b ra a;
(* perform [union r_from r_into] *) (* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call [on_merge] functions, and merge theory data items *) (* call [on_pre_merge] functions, and merge theory data items *)
begin begin
(* explanation is [a=ra & e_ab & b=rb] *) (* explanation is [a=ra & e_ab & b=rb] *)
let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in
List.iter (fun f -> f cc acts r_into r_from expl) cc.on_merge; List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge;
end; end;
begin begin
(* parents might have a different signature, check for collisions *) (* parents might have a different signature, check for collisions *)
@ -695,6 +697,10 @@ module Make(CC_A: ARG) = struct
| _ -> assert false); | _ -> assert false);
a.n_expl <- FL_some {next=b; expl=e_ab}; a.n_expl <- FL_some {next=b; expl=e_ab};
end; end;
(* call [on_post_merge] *)
begin
List.iter (fun f -> f cc acts r_into r_from) cc.on_post_merge;
end;
) )
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
@ -797,13 +803,14 @@ module Make(CC_A: ARG) = struct
let[@inline] merge_t cc t1 t2 expl = let[@inline] merge_t cc t1 t2 expl =
merge cc (add_term cc t1) (add_term cc t2) expl merge cc (add_term cc t1) (add_term cc t2) expl
let on_merge cc f = cc.on_merge <- f :: cc.on_merge let on_pre_merge cc f = cc.on_pre_merge <- f :: cc.on_pre_merge
let on_post_merge cc f = cc.on_post_merge <- f :: cc.on_post_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate
let create ?(stat=Stat.global) let create ?(stat=Stat.global)
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[]) ?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[])
?(size=`Big) ?(size=`Big)
(tst:term_state) : t = (tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in let size = match size with `Small -> 128 | `Big -> 2048 in
@ -814,7 +821,8 @@ module Make(CC_A: ARG) = struct
tbl = T_tbl.create size; tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size; signatures_tbl = Sig_tbl.create size;
bitgen; bitgen;
on_merge; on_pre_merge;
on_post_merge;
on_new_term; on_new_term;
on_conflict; on_conflict;
on_propagate; on_propagate;

View file

@ -210,14 +210,16 @@ module type CC_S = sig
(** Add the term to the congruence closure, if not present already. (** Add the term to the congruence closure, if not present already.
Will be backtracked. *) Will be backtracked. *)
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> lit list -> unit type ev_on_conflict = t -> lit list -> unit
type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit
val create : val create :
?stat:Stat.t -> ?stat:Stat.t ->
?on_merge:ev_on_merge list -> ?on_pre_merge:ev_on_pre_merge list ->
?on_post_merge:ev_on_post_merge list ->
?on_new_term:ev_on_new_term list -> ?on_new_term:ev_on_new_term list ->
?on_conflict:ev_on_conflict list -> ?on_conflict:ev_on_conflict list ->
?on_propagate:ev_on_propagate list -> ?on_propagate:ev_on_propagate list ->
@ -231,7 +233,10 @@ module type CC_S = sig
See {!N.bitfield}. *) See {!N.bitfield}. *)
(* TODO: remove? this is managed by the solver anyway? *) (* TODO: remove? this is managed by the solver anyway? *)
val on_merge : t -> ev_on_merge -> unit val on_pre_merge : t -> ev_on_pre_merge -> unit
(** Add a function to be called when two classes are merged *)
val on_post_merge : t -> ev_on_post_merge -> unit
(** Add a function to be called when two classes are merged *) (** Add a function to be called when two classes are merged *)
val on_new_term : t -> ev_on_new_term -> unit val on_new_term : t -> ev_on_new_term -> unit
@ -418,8 +423,11 @@ module type SOLVER_INTERNAL = sig
(** Add/retrieve congruence closure node for this term. (** Add/retrieve congruence closure node for this term.
To be used in theories *) To be used in theories *)
val on_cc_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> CC.Expl.t -> unit) -> unit val on_cc_pre_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> CC.Expl.t -> unit) -> unit
(** Callback for when two classes containing data for this key are merged *) (** Callback for when two classes containing data for this key are merged (called before) *)
val on_cc_post_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> unit) -> unit
(** Callback for when two classes containing data for this key are merged (called after)*)
val on_cc_new_term : t -> (CC.t -> CC.N.t -> term -> unit) -> unit val on_cc_new_term : t -> (CC.t -> CC.N.t -> term -> unit) -> unit
(** Callback to add data on terms when they are added to the congruence (** Callback to add data on terms when they are added to the congruence

View file

@ -251,7 +251,8 @@ module Make(A : ARG)
let on_final_check self f = self.on_final_check <- f :: self.on_final_check let on_final_check self f = self.on_final_check <- f :: self.on_final_check
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
let on_cc_new_term self f = CC.on_new_term (cc self) f let on_cc_new_term self f = CC.on_new_term (cc self) f
let on_cc_merge self f = CC.on_merge (cc self) f let on_cc_pre_merge self f = CC.on_pre_merge (cc self) f
let on_cc_post_merge self f = CC.on_post_merge (cc self) f
let on_cc_conflict self f = CC.on_conflict (cc self) f let on_cc_conflict self f = CC.on_conflict (cc self) f
let on_cc_propagate self f = CC.on_propagate (cc self) f let on_cc_propagate self f = CC.on_propagate (cc self) f

View file

@ -40,9 +40,9 @@ module Make(A : ARG) : S with module A = A = struct
(* TODO: also allocate a bit in CC to filter out quickly classes without cstors *) (* TODO: also allocate a bit in CC to filter out quickly classes without cstors *)
} }
let on_merge (solver:SI.t) n1 tc1 n2 tc2 e_n1_n2 : unit = let on_pre_merge (solver:SI.t) n1 tc1 n2 tc2 e_n1_n2 : unit =
Log.debugf 5 Log.debugf 5
(fun k->k "(@[th-cstor.on_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])" (fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
N.pp n1 T.pp tc1.t N.pp n2 T.pp tc2.t); N.pp n1 T.pp tc1.t N.pp n2 T.pp tc2.t);
let expl = Expl.mk_list [e_n1_n2; Expl.mk_merge n1 tc1.n; Expl.mk_merge n2 tc2.n] in let expl = Expl.mk_list [e_n1_n2; Expl.mk_merge n1 tc1.n; Expl.mk_merge n2 tc2.n] in
match A.view_as_cstor tc1.t, A.view_as_cstor tc2.t with match A.view_as_cstor tc1.t, A.view_as_cstor tc2.t with
@ -71,7 +71,7 @@ module Make(A : ARG) : S with module A = A = struct
cstors=N_tbl.create 32; cstors=N_tbl.create 32;
} in } in
(* TODO (* TODO
SI.on_cc_merge solver on_merge; SI.on_cc_pre_merge solver on_pre_merge;
SI.on_cc_new_term solver on_new_term; SI.on_cc_new_term solver on_new_term;
*) *)
self self