From 7fe6f07c0bfe34b2b1b34d2f5a53e6f0204930ae Mon Sep 17 00:00:00 2001 From: Alexander Bentkamp Date: Wed, 21 Aug 2019 10:31:43 +0200 Subject: [PATCH] split on_merge into two events: pre and post merge --- src/cc/Sidekick_cc.ml | 22 +++++++++++++++------- src/core/Sidekick_core.ml | 18 +++++++++++++----- src/msat-solver/Sidekick_msat_solver.ml | 3 ++- src/th-cstor/Sidekick_th_cstor.ml | 6 +++--- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index dec46ea5..91905900 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -247,7 +247,8 @@ module Make(CC_A: ARG) = struct pending: node Vec.t; combine: combine_task Vec.t; undo: (unit -> unit) Backtrack_stack.t; - mutable on_merge: ev_on_merge list; + mutable on_pre_merge: ev_on_pre_merge list; + mutable on_post_merge: ev_on_post_merge list; mutable on_new_term: ev_on_new_term list; mutable on_conflict: ev_on_conflict list; mutable on_propagate: ev_on_propagate list; @@ -266,7 +267,8 @@ module Make(CC_A: ARG) = struct several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) - and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit + and ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit + and ev_on_post_merge = t -> actions -> N.t -> N.t -> unit and ev_on_new_term = t -> N.t -> term -> unit and ev_on_conflict = t -> lit list -> unit and ev_on_propagate = t -> lit -> (unit -> lit list) -> unit @@ -638,11 +640,11 @@ module Make(CC_A: ARG) = struct merge_bool rb b ra a; (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); - (* call [on_merge] functions, and merge theory data items *) + (* call [on_pre_merge] functions, and merge theory data items *) begin (* explanation is [a=ra & e_ab & b=rb] *) let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in - List.iter (fun f -> f cc acts r_into r_from expl) cc.on_merge; + List.iter (fun f -> f cc acts r_into r_from expl) cc.on_pre_merge; end; begin (* parents might have a different signature, check for collisions *) @@ -695,6 +697,10 @@ module Make(CC_A: ARG) = struct | _ -> assert false); a.n_expl <- FL_some {next=b; expl=e_ab}; end; + (* call [on_post_merge] *) + begin + List.iter (fun f -> f cc acts r_into r_from) cc.on_post_merge; + end; ) (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] @@ -797,13 +803,14 @@ module Make(CC_A: ARG) = struct let[@inline] merge_t cc t1 t2 expl = merge cc (add_term cc t1) (add_term cc t2) expl - let on_merge cc f = cc.on_merge <- f :: cc.on_merge + let on_pre_merge cc f = cc.on_pre_merge <- f :: cc.on_pre_merge + let on_post_merge cc f = cc.on_post_merge <- f :: cc.on_post_merge let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict let on_propagate cc f = cc.on_propagate <- f :: cc.on_propagate let create ?(stat=Stat.global) - ?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[]) + ?(on_pre_merge=[]) ?(on_post_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(on_propagate=[]) ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in @@ -814,7 +821,8 @@ module Make(CC_A: ARG) = struct tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; bitgen; - on_merge; + on_pre_merge; + on_post_merge; on_new_term; on_conflict; on_propagate; diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 2937b6ab..5d5dfad9 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -210,14 +210,16 @@ module type CC_S = sig (** Add the term to the congruence closure, if not present already. Will be backtracked. *) - type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit + type ev_on_pre_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit + type ev_on_post_merge = t -> actions -> N.t -> N.t -> unit type ev_on_new_term = t -> N.t -> term -> unit type ev_on_conflict = t -> lit list -> unit type ev_on_propagate = t -> lit -> (unit -> lit list) -> unit val create : ?stat:Stat.t -> - ?on_merge:ev_on_merge list -> + ?on_pre_merge:ev_on_pre_merge list -> + ?on_post_merge:ev_on_post_merge list -> ?on_new_term:ev_on_new_term list -> ?on_conflict:ev_on_conflict list -> ?on_propagate:ev_on_propagate list -> @@ -231,7 +233,10 @@ module type CC_S = sig See {!N.bitfield}. *) (* TODO: remove? this is managed by the solver anyway? *) - val on_merge : t -> ev_on_merge -> unit + val on_pre_merge : t -> ev_on_pre_merge -> unit + (** Add a function to be called when two classes are merged *) + + val on_post_merge : t -> ev_on_post_merge -> unit (** Add a function to be called when two classes are merged *) val on_new_term : t -> ev_on_new_term -> unit @@ -418,8 +423,11 @@ module type SOLVER_INTERNAL = sig (** Add/retrieve congruence closure node for this term. To be used in theories *) - val on_cc_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> CC.Expl.t -> unit) -> unit - (** Callback for when two classes containing data for this key are merged *) + val on_cc_pre_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> CC.Expl.t -> unit) -> unit + (** Callback for when two classes containing data for this key are merged (called before) *) + + val on_cc_post_merge : t -> (CC.t -> actions -> CC.N.t -> CC.N.t -> unit) -> unit + (** Callback for when two classes containing data for this key are merged (called after)*) val on_cc_new_term : t -> (CC.t -> CC.N.t -> term -> unit) -> unit (** Callback to add data on terms when they are added to the congruence diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index e0524f57..6fa9e92d 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -251,7 +251,8 @@ module Make(A : ARG) let on_final_check self f = self.on_final_check <- f :: self.on_final_check let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check let on_cc_new_term self f = CC.on_new_term (cc self) f - let on_cc_merge self f = CC.on_merge (cc self) f + let on_cc_pre_merge self f = CC.on_pre_merge (cc self) f + let on_cc_post_merge self f = CC.on_post_merge (cc self) f let on_cc_conflict self f = CC.on_conflict (cc self) f let on_cc_propagate self f = CC.on_propagate (cc self) f diff --git a/src/th-cstor/Sidekick_th_cstor.ml b/src/th-cstor/Sidekick_th_cstor.ml index 11c11754..721d657e 100644 --- a/src/th-cstor/Sidekick_th_cstor.ml +++ b/src/th-cstor/Sidekick_th_cstor.ml @@ -40,9 +40,9 @@ module Make(A : ARG) : S with module A = A = struct (* TODO: also allocate a bit in CC to filter out quickly classes without cstors *) } - let on_merge (solver:SI.t) n1 tc1 n2 tc2 e_n1_n2 : unit = + let on_pre_merge (solver:SI.t) n1 tc1 n2 tc2 e_n1_n2 : unit = Log.debugf 5 - (fun k->k "(@[th-cstor.on_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])" + (fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])" N.pp n1 T.pp tc1.t N.pp n2 T.pp tc2.t); let expl = Expl.mk_list [e_n1_n2; Expl.mk_merge n1 tc1.n; Expl.mk_merge n2 tc2.n] in match A.view_as_cstor tc1.t, A.view_as_cstor tc2.t with @@ -71,7 +71,7 @@ module Make(A : ARG) : S with module A = A = struct cstors=N_tbl.create 32; } in (* TODO - SI.on_cc_merge solver on_merge; + SI.on_cc_pre_merge solver on_pre_merge; SI.on_cc_new_term solver on_new_term; *) self