From dc68a60151de8cd124cbf469fce59d05a17bec35 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 21 Jul 2022 23:21:07 -0400 Subject: [PATCH] feat(cc): remove callbacks, return list of actions --- src/cc/Sidekick_cc.ml | 556 ++++++++++++++++------------ src/cc/plugin/sidekick_cc_plugin.ml | 71 ++-- 2 files changed, 351 insertions(+), 276 deletions(-) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 32792284..121d7ba6 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -11,7 +11,8 @@ module type S = sig @param term_store used to be able to create new terms. All terms interacting with this congruence closure must belong in this term state - as well. *) + as well. + *) (**/**) @@ -25,6 +26,48 @@ end module type ARG = ARG +(* small bitfield *) +module Bits : sig + type t = private int + type field + type bitfield_gen + + val empty : t + val equal : t -> t -> bool + val mk_field : bitfield_gen -> field + val mk_gen : unit -> bitfield_gen + val get : field -> t -> bool + val set : field -> bool -> t -> t + val merge : t -> t -> t +end = struct + type bitfield_gen = int ref + + let max_width = Sys.word_size - 2 + let mk_gen () = ref 0 + + type t = int + type field = int + + let empty : t = 0 + + let mk_field (gen : bitfield_gen) : field = + let n = !gen in + if n > max_width then Error.errorf "maximum number of CC bitfields reached"; + incr gen; + 1 lsl n + + let[@inline] get field x = x land field <> 0 + + let[@inline] set field b x = + if b then + x lor field + else + x land lnot field + + let merge = ( lor ) + let equal : t -> t -> bool = CCEqual.poly +end + module Make (A : ARG) : S with module T = A.T @@ -50,63 +93,14 @@ module Make (A : ARG) : type proof_trace = A.Proof_trace.t type step_id = A.Proof_trace.A.step_id - type actions = - (module DYN_ACTIONS - with type term = T.Term.t - and type lit = Lit.t - and type proof_trace = proof_trace - and type step_id = step_id) - - module Bits : sig - type t = private int - type field - type bitfield_gen - - val empty : t - val equal : t -> t -> bool - val mk_field : bitfield_gen -> field - val mk_gen : unit -> bitfield_gen - val get : field -> t -> bool - val set : field -> bool -> t -> t - val merge : t -> t -> t - end = struct - type bitfield_gen = int ref - - let max_width = Sys.word_size - 2 - let mk_gen () = ref 0 - - type t = int - type field = int - - let empty : t = 0 - - let mk_field (gen : bitfield_gen) : field = - let n = !gen in - if n > max_width then - Error.errorf "maximum number of CC bitfields reached"; - incr gen; - 1 lsl n - - let[@inline] get field x = x land field <> 0 - - let[@inline] set field b x = - if b then - x lor field - else - x land lnot field - - let merge = ( lor ) - let equal : t -> t -> bool = CCEqual.poly - end - - type node = { + type e_node = { n_term: term; mutable n_sig0: signature option; (* initial signature *) mutable n_bits: Bits.t; (* bitfield for various properties *) - mutable n_parents: node Bag.t; (* parent terms of this node *) - mutable n_root: node; + mutable n_parents: e_node Bag.t; (* parent terms of this node *) + mutable n_root: e_node; (* representative of congruence class (itself if a representative) *) - mutable n_next: node; (* pointer to next element of congruence class *) + mutable n_next: e_node; (* pointer to next element of congruence class *) mutable n_size: int; (* size of the class *) mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) @@ -117,27 +111,27 @@ module Make (A : ARG) : An equivalence class is represented by its "root" element, the representative. *) - and signature = (fun_, node, node list) View.t + and signature = (fun_, e_node, e_node list) View.t and explanation_forest_link = | FL_none - | FL_some of { next: node; expl: explanation } + | FL_some of { next: e_node; expl: explanation } (* atomic explanation in the congruence closure *) and explanation = | E_trivial (* by pure reduction, tautologically equal *) | E_lit of lit (* because of this literal *) - | E_merge of node * node + | E_merge of e_node * e_node | E_merge_t of term * term - | E_congruence of node * node (* caused by normal congruence *) + | E_congruence of e_node * e_node (* caused by normal congruence *) | E_and of explanation * explanation | E_theory of term * term * (term * term * explanation list) list * step_id - | E_same_val of node * node + | E_same_val of e_node * e_node - type repr = node + type repr = e_node - module Class = struct - type t = node + module E_node = struct + type t = e_node let[@inline] equal (n1 : t) n2 = n1 == n2 let[@inline] hash n = Term.hash n.n_term @@ -162,10 +156,10 @@ module Make (A : ARG) : in n - let[@inline] is_root (n : node) : bool = n.n_root == n + let[@inline] is_root (n : e_node) : bool = n.n_root == n (* traverse the equivalence class of [n] *) - let iter_class_ (n : node) : node Iter.t = + let iter_class_ (n : e_node) : e_node Iter.t = fun yield -> let rec aux u = yield u; @@ -177,7 +171,7 @@ module Make (A : ARG) : assert (is_root n); iter_class_ n - let[@inline] iter_parents (n : node) : node Iter.t = + let[@inline] iter_parents (n : e_node) : e_node Iter.t = assert (is_root n); Bag.to_iter n.n_parents @@ -188,13 +182,13 @@ module Make (A : ARG) : end (* non-recursive, inlinable function for [find] *) - let[@inline] find_ (n : node) : repr = + let[@inline] find_ (n : e_node) : repr = let n2 = n.n_root in - assert (Class.is_root n2); + assert (E_node.is_root n2); n2 - let[@inline] same_class (n1 : node) (n2 : node) : bool = - Class.equal (find_ n1) (find_ n2) + let[@inline] same_class (n1 : e_node) (n2 : e_node) : bool = + E_node.equal (find_ n1) (find_ n2) let[@inline] find _ n = find_ n @@ -206,9 +200,9 @@ module Make (A : ARG) : | E_trivial -> Fmt.string out "reduction" | E_lit lit -> Lit.pp out lit | E_congruence (n1, n2) -> - Fmt.fprintf out "(@[congruence@ %a@ %a@])" Class.pp n1 Class.pp n2 + Fmt.fprintf out "(@[congruence@ %a@ %a@])" E_node.pp n1 E_node.pp n2 | E_merge (a, b) -> - Fmt.fprintf out "(@[merge@ %a@ %a@])" Class.pp a Class.pp b + Fmt.fprintf out "(@[merge@ %a@ %a@])" E_node.pp a E_node.pp b | E_merge_t (a, b) -> Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp a Term.pp b @@ -219,13 +213,13 @@ module Make (A : ARG) : es | E_and (a, b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b | E_same_val (n1, n2) -> - Fmt.fprintf out "(@[same-value@ %a@ %a@])" Class.pp n1 Class.pp n2 + Fmt.fprintf out "(@[same-value@ %a@ %a@])" E_node.pp n1 E_node.pp n2 let mk_trivial : t = E_trivial let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2) let[@inline] mk_merge a b : t = - if Class.equal a b then + if E_node.equal a b then mk_trivial else E_merge (a, b) @@ -240,7 +234,7 @@ module Make (A : ARG) : let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr) let[@inline] mk_same_value t u = - if Class.equal t u then + if E_node.equal t u then mk_trivial else E_same_val (t, u) @@ -259,7 +253,7 @@ module Make (A : ARG) : module Resolved_expl = struct type t = { lits: lit list; - same_value: (Class.t * Class.t) list; + same_value: (E_node.t * E_node.t) list; pr: proof_trace -> step_id; } @@ -276,11 +270,26 @@ module Make (A : ARG) : let { lits; same_value; pr = _ } = self in Fmt.fprintf out "(@[resolved-expl@ (@[%a@])@ :same-val (@[%a@])@])" (Util.pp_list Lit.pp) lits - (Util.pp_list @@ Fmt.Dump.pair Class.pp Class.pp) + (Util.pp_list @@ Fmt.Dump.pair E_node.pp E_node.pp) same_value ) end + type propagation_reason = unit -> lit list * step_id + + type action = + | Act_merge of E_node.t * E_node.t * Expl.t + | Act_propagate of { lit: lit; reason: propagation_reason } + + type conflict = + | Conflict of lit list * step_id + (** [raise_conflict (c,pr)] declares that [c] is a tautology of + the theory of congruence. + @param pr the proof of [c] being a tautology *) + | Conflict_expl of Expl.t + + type actions_or_confl = (action list, conflict) result + (** A signature is a shallow term shape where immediate subterms are representative *) module Signature = struct @@ -291,14 +300,14 @@ module Make (A : ARG) : | Bool b1, Bool b2 -> b1 = b2 | App_fun (f1, []), App_fun (f2, []) -> Fun.equal f1 f2 | App_fun (f1, l1), App_fun (f2, l2) -> - Fun.equal f1 f2 && CCList.equal Class.equal l1 l2 + Fun.equal f1 f2 && CCList.equal E_node.equal l1 l2 | App_ho (f1, a1), App_ho (f2, a2) -> - Class.equal f1 f2 && Class.equal a1 a2 - | Not a, Not b -> Class.equal a b + E_node.equal f1 f2 && E_node.equal a1 a2 + | Not a, Not b -> E_node.equal a b | If (a1, b1, c1), If (a2, b2, c2) -> - Class.equal a1 a2 && Class.equal b1 b2 && Class.equal c1 c2 - | Eq (a1, b1), Eq (a2, b2) -> Class.equal a1 a2 && Class.equal b1 b2 - | Opaque u1, Opaque u2 -> Class.equal u1 u2 + E_node.equal a1 a2 && E_node.equal b1 b2 && E_node.equal c1 c2 + | Eq (a1, b1), Eq (a2, b2) -> E_node.equal a1 a2 && E_node.equal b1 b2 + | Opaque u1, Opaque u2 -> E_node.equal u1 u2 | Bool _, _ | App_fun _, _ | App_ho _, _ @@ -312,25 +321,26 @@ module Make (A : ARG) : let module H = CCHash in match s with | Bool b -> H.combine2 10 (H.bool b) - | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list Class.hash l) - | App_ho (f, a) -> H.combine3 30 (Class.hash f) (Class.hash a) - | Eq (a, b) -> H.combine3 40 (Class.hash a) (Class.hash b) - | Opaque u -> H.combine2 50 (Class.hash u) + | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list E_node.hash l) + | App_ho (f, a) -> H.combine3 30 (E_node.hash f) (E_node.hash a) + | Eq (a, b) -> H.combine3 40 (E_node.hash a) (E_node.hash b) + | Opaque u -> H.combine2 50 (E_node.hash u) | If (a, b, c) -> - H.combine4 60 (Class.hash a) (Class.hash b) (Class.hash c) - | Not u -> H.combine2 70 (Class.hash u) + H.combine4 60 (E_node.hash a) (E_node.hash b) (E_node.hash c) + | Not u -> H.combine2 70 (E_node.hash u) let pp out = function | Bool b -> Fmt.bool out b | App_fun (f, []) -> Fun.pp out f | App_fun (f, l) -> - Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Class.pp) l - | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" Class.pp f Class.pp a - | Opaque t -> Class.pp out t - | Not u -> Fmt.fprintf out "(@[not@ %a@])" Class.pp u - | Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Class.pp a Class.pp b + Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list E_node.pp) l + | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" E_node.pp f E_node.pp a + | Opaque t -> E_node.pp out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" E_node.pp u + | Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" E_node.pp a E_node.pp b | If (a, b, c) -> - Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Class.pp a Class.pp b Class.pp c + Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" E_node.pp a E_node.pp b + E_node.pp c end module Sig_tbl = CCHashtbl.Make (Signature) @@ -338,40 +348,44 @@ module Make (A : ARG) : module T_b_tbl = Backtrackable_tbl.Make (Term) type combine_task = - | CT_merge of node * node * explanation - | CT_set_val of node * value + | CT_merge of e_node * e_node * explanation + | CT_set_val of e_node * value + | CT_act of action type t = { tst: term_store; proof: proof_trace; - tbl: node T_tbl.t; (* internalization [term -> node] *) - signatures_tbl: node Sig_tbl.t; - (* map a signature to the corresponding node in some equivalence class. + tbl: e_node T_tbl.t; (* internalization [term -> e_node] *) + signatures_tbl: e_node Sig_tbl.t; + (* map a signature to the corresponding e_node in some equivalence class. A signature is a [term_cell] in which every immediate subterm that participates in the congruence/evaluation relation is normalized (i.e. is its own representative). The critical property is that all members of an equivalence class that have the same "shape" (including head symbol) have the same signature *) - pending: node Vec.t; + pending: e_node Vec.t; combine: combine_task Vec.t; - t_to_val: (node * value) T_b_tbl.t; + t_to_val: (e_node * value) T_b_tbl.t; (* TODO: remove this, make it a plugin/EGG instead *) (* [repr -> (t,val)] where [repr = t] and [t := val] in the model *) - val_to_t: node T_b_tbl.t; (* [val -> t] where [t := val] in the model *) + val_to_t: e_node T_b_tbl.t; (* [val -> t] where [t := val] in the model *) undo: (unit -> unit) Backtrack_stack.t; bitgen: Bits.bitfield_gen; field_marked_explain: Bits.field; (* used to mark traversed nodes when looking for a common ancestor *) - true_: node lazy_t; - false_: node lazy_t; + true_: e_node lazy_t; + false_: e_node lazy_t; mutable model_mode: bool; - on_pre_merge: (t * actions * Class.t * Class.t * Expl.t) Event.Emitter.t; - on_post_merge: (t * actions * Class.t * Class.t) Event.Emitter.t; - on_new_term: (t * Class.t * term) Event.Emitter.t; - on_conflict: ev_on_conflict Event.Emitter.t; - on_propagate: (t * lit * (unit -> lit list * step_id)) Event.Emitter.t; - on_is_subterm: (t * Class.t * term) Event.Emitter.t; + mutable in_loop: bool; (* currently being modified? *) + res_acts: action Vec.t; (* to return *) + on_pre_merge: + (t * E_node.t * E_node.t * Expl.t, actions_or_confl) Event.Emitter.t; + on_post_merge: (t * E_node.t * E_node.t, action list) Event.Emitter.t; + on_new_term: (t * E_node.t * term, action list) Event.Emitter.t; + on_conflict: (ev_on_conflict, unit) Event.Emitter.t; + on_propagate: (t * lit * propagation_reason, action list) Event.Emitter.t; + on_is_subterm: (t * E_node.t * term, action list) Event.Emitter.t; count_conflict: int Stat.counter; count_props: int Stat.counter; count_merge: int Stat.counter; @@ -405,13 +419,13 @@ module Make (A : ARG) : let[@inline] on_backtrack self f : unit = Backtrack_stack.push_if_nonzero_level self.undo f - let[@inline] get_bitfield _cc field n = Class.get_field field n + let[@inline] get_bitfield _cc field n = E_node.get_field field n let set_bitfield self field b n = - let old = Class.get_field field n in + let old = E_node.get_field field n in if old <> b then ( - on_backtrack self (fun () -> Class.set_field field old n); - Class.set_field field b n + on_backtrack self (fun () -> E_node.set_field field old n); + E_node.set_field field b n ) (* check if [t] is in the congruence closure. @@ -421,25 +435,25 @@ module Make (A : ARG) : module Debug_ = struct (* print full state *) let pp out (self : t) : unit = - let pp_next out n = Fmt.fprintf out "@ :next %a" Class.pp n.n_next in + let pp_next out n = Fmt.fprintf out "@ :next %a" E_node.pp n.n_next in let pp_root out n = - if Class.is_root n then + if E_node.is_root n then Fmt.string out " :is-root" else - Fmt.fprintf out "@ :root %a" Class.pp n.n_root + Fmt.fprintf out "@ :root %a" E_node.pp n.n_root in let pp_expl out n = match n.n_expl with | FL_none -> () | FL_some e -> - Fmt.fprintf out " (@[:forest %a :expl %a@])" Class.pp e.next Expl.pp + Fmt.fprintf out " (@[:forest %a :expl %a@])" E_node.pp e.next Expl.pp e.expl in let pp_n out n = Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n and pp_sig_e out (s, n) = - Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s Class.pp n + Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s E_node.pp n pp_root n in Fmt.fprintf out @@ -461,29 +475,34 @@ module Make (A : ARG) : Sig_tbl.get cc.signatures_tbl s (* add to signature table. Assume it's not present already *) - let add_signature self (s : signature) (n : node) : unit = + let add_signature self (s : signature) (n : e_node) : unit = assert (not @@ Sig_tbl.mem self.signatures_tbl s); Log.debugf 50 (fun k -> - k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s Class.pp n); + k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s E_node.pp n); on_backtrack self (fun () -> Sig_tbl.remove self.signatures_tbl s); Sig_tbl.add self.signatures_tbl s n let push_pending self t : unit = - Log.debugf 50 (fun k -> k "(@[cc.push-pending@ %a@])" Class.pp t); + Log.debugf 50 (fun k -> k "(@[cc.push-pending@ %a@])" E_node.pp t); Vec.push self.pending t + let push_action self (a : action) : unit = Vec.push self.combine (CT_act a) + + let push_action_l self (l : action list) : unit = + List.iter (push_action self) l + let merge_classes self t u e : unit = if t != u && not (same_class t u) then ( Log.debugf 50 (fun k -> - k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" Class.pp t Class.pp - u Expl.pp e); + k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" E_node.pp t + E_node.pp u Expl.pp e); Vec.push self.combine @@ CT_merge (t, u, e) ) (* re-root the explanation tree of the equivalence class of [n] so that it points to [n]. postcondition: [n.n_expl = None] *) - let[@unroll 2] rec reroot_expl (self : t) (n : node) : unit = + let[@unroll 2] rec reroot_expl (self : t) (n : e_node) : unit = match n.n_expl with | FL_none -> () (* already root *) | FL_some { next = u; expl = e_n_u } -> @@ -492,33 +511,33 @@ module Make (A : ARG) : u.n_expl <- FL_some { next = n; expl = e_n_u }; n.n_expl <- FL_none - let raise_conflict_ (cc : t) ~th (acts : actions) (e : lit list) (p : step_id) - : _ = + exception E_confl of conflict + + let raise_conflict_ (cc : t) ~th (e : lit list) (p : step_id) : _ = Profile.instant "cc.conflict"; (* clear tasks queue *) Vec.clear cc.pending; Vec.clear cc.combine; Event.emit cc.on_conflict { cc; th; c = e }; Stat.incr cc.count_conflict; - let (module A) = acts in - A.raise_conflict e p + raise (E_confl (Conflict (e, p))) let[@inline] all_classes self : repr Iter.t = - T_tbl.values self.tbl |> Iter.filter Class.is_root + T_tbl.values self.tbl |> Iter.filter E_node.is_root (* find the closest common ancestor of [a] and [b] in the proof forest. Precond: - [a] and [b] are in the same class - - no node has the flag [field_marked_explain] on + - no e_node has the flag [field_marked_explain] on Invariants: - if [n] is marked, then all the predecessors of [n] from [a] or [b] are marked too. *) - let find_common_ancestor self (a : node) (b : node) : node = - (* catch up to the other node *) + let find_common_ancestor self (a : e_node) (b : e_node) : e_node = + (* catch up to the other e_node *) let rec find1 a = - if Class.get_field self.field_marked_explain a then + if E_node.get_field self.field_marked_explain a then a else ( match a.n_expl with @@ -527,15 +546,15 @@ module Make (A : ARG) : ) in let rec find2 a b = - if Class.equal a b then + if E_node.equal a b then a - else if Class.get_field self.field_marked_explain a then + else if E_node.get_field self.field_marked_explain a then a - else if Class.get_field self.field_marked_explain b then + else if E_node.get_field self.field_marked_explain b then b else ( - Class.set_field self.field_marked_explain true a; - Class.set_field self.field_marked_explain true b; + E_node.set_field self.field_marked_explain true a; + E_node.set_field self.field_marked_explain true b; match a.n_expl, b.n_expl with | FL_some r1, FL_some r2 -> find2 r1.next r2.next | FL_some r, FL_none -> find1 r.next @@ -547,8 +566,8 @@ module Make (A : ARG) : (* cleanup tags on nodes traversed in [find2] *) let rec cleanup_ n = - if Class.get_field self.field_marked_explain n then ( - Class.set_field self.field_marked_explain false n; + if E_node.get_field self.field_marked_explain n then ( + E_node.set_field self.field_marked_explain false n; match n.n_expl with | FL_none -> () | FL_some { next; _ } -> cleanup_ next @@ -562,7 +581,7 @@ module Make (A : ARG) : module Expl_state = struct type t = { mutable lits: Lit.t list; - mutable same_val: (Class.t * Class.t) list; + mutable same_val: (E_node.t * E_node.t) list; mutable th_lemmas: (Lit.t * (Lit.t * Lit.t list) list * step_id) list; } @@ -671,7 +690,8 @@ module Make (A : ARG) : (match T_tbl.find self.tbl a, T_tbl.find self.tbl b with | a, b -> explain_equal_rec_ self st a b | exception Not_found -> - Error.errorf "expl: cannot find node(s) for %a, %a" Term.pp a Term.pp b) + Error.errorf "expl: cannot find e_node(s) for %a, %a" Term.pp a Term.pp + b) | E_and (a, b) -> explain_decompose_expl self st a; explain_decompose_expl self st b @@ -681,19 +701,19 @@ module Make (A : ARG) : List.iter (explain_decompose_expl self st) es; st - and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : node) (b : node) : - unit = + and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) + : unit = Log.debugf 5 (fun k -> - k "(@[cc.explain_loop.at@ %a@ =?= %a@])" Class.pp a Class.pp b); - assert (Class.equal (find_ a) (find_ b)); + k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b); + assert (E_node.equal (find_ a) (find_ b)); let ancestor = find_common_ancestor cc a b in explain_along_path cc st a ancestor; explain_along_path cc st b ancestor (* explain why [a = parent_a], where [a -> ... -> target] in the proof forest *) - and explain_along_path self (st : Expl_state.t) (a : node) (target : node) : - unit = + and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) + : unit = let rec aux n = if n == target then () @@ -709,16 +729,16 @@ module Make (A : ARG) : aux a (* add a term *) - let[@inline] rec add_term_rec_ self t : node = + let[@inline] rec add_term_rec_ self t : e_node = match T_tbl.find self.tbl t with | n -> n | exception Not_found -> add_new_term_ self t (* add [t] when not present already *) - and add_new_term_ self (t : term) : node = + and add_new_term_ self (t : term) : e_node = assert (not @@ mem self t); Log.debugf 15 (fun k -> k "(@[cc.add-term@ %a@])" Term.pp t); - let n = Class.make t in + let n = E_node.make t in (* register sub-terms, add [t] to their parent list, and return the corresponding initial signature *) let sig0 = compute_sig0 self n in @@ -732,22 +752,24 @@ module Make (A : ARG) : if Option.is_some sig0 then (* [n] might be merged with other equiv classes *) push_pending self n; - if not self.model_mode then Event.emit self.on_new_term (self, n, t); + if not self.model_mode then + Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self); n - (* compute the initial signature of the given node *) - and compute_sig0 (self : t) (n : node) : Signature.t option = + (* compute the initial signature of the given e_node *) + and compute_sig0 (self : t) (n : e_node) : Signature.t option = (* add sub-term to [cc], and register [n] to its parents. Note that we return the exact sub-term, to get proper explanations, but we add to the sub-term's root's parent list. *) - let deref_sub (u : term) : node = + let deref_sub (u : term) : e_node = let sub = add_term_rec_ self u in (* add [n] to [sub.root]'s parent list *) (let sub_r = find_ sub in let old_parents = sub_r.n_parents in if Bag.is_empty old_parents && not self.model_mode then (* first time it has parents: tell watchers that this is a subterm *) - Event.emit self.on_is_subterm (self, sub, u); + Event.emit_iter self.on_is_subterm (self, sub, u) + ~f:(push_action_l self); on_backtrack self (fun () -> sub_r.n_parents <- old_parents); sub_r.n_parents <- Bag.cons n sub_r.n_parents); sub @@ -772,21 +794,21 @@ module Make (A : ARG) : return @@ App_ho (f, a) | If (a, b, c) -> return @@ If (deref_sub a, deref_sub b, deref_sub c) - let[@inline] add_term self t : node = add_term_rec_ self t + let[@inline] add_term self t : e_node = add_term_rec_ self t let mem_term = mem - let set_as_lit self (n : node) (lit : lit) : unit = + let set_as_lit self (n : e_node) (lit : lit) : unit = match n.n_as_lit with | Some _ -> () | None -> Log.debugf 15 (fun k -> - k "(@[cc.set-as-lit@ %a@ %a@])" Class.pp n Lit.pp lit); + k "(@[cc.set-as-lit@ %a@ %a@])" E_node.pp n Lit.pp lit); on_backtrack self (fun () -> n.n_as_lit <- None); n.n_as_lit <- Some lit (* is [n] true or false? *) let n_is_bool_value (self : t) n : bool = - Class.equal n (n_true self) || Class.equal n (n_false self) + E_node.equal n (n_true self) || E_node.equal n (n_false self) (* gather a pair [lits, pr], where [lits] is the set of asserted literals needed in the explanation (which is useful for @@ -801,17 +823,17 @@ module Make (A : ARG) : (* main CC algo: add terms from [pending] to the signature table, check for collisions *) - let rec update_tasks (self : t) (acts : actions) : unit = + let rec update_tasks (self : t) : unit = while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do while not @@ Vec.is_empty self.pending do task_pending_ self (Vec.pop_exn self.pending) done; while not @@ Vec.is_empty self.combine do - task_combine_ self acts (Vec.pop_exn self.combine) + task_combine_ self (Vec.pop_exn self.combine) done done - and task_pending_ self (n : node) : unit = + and task_pending_ self (n : e_node) : unit = (* check if some parent collided *) match n.n_sig0 with | None -> () (* no-op *) @@ -820,28 +842,28 @@ module Make (A : ARG) : if same_class a b then ( let expl = Expl.mk_merge a b in Log.debugf 5 (fun k -> - k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" Class.pp n Class.pp a - Class.pp b); + k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" E_node.pp n E_node.pp a + E_node.pp b); merge_classes self n (n_true self) expl ) | Some (Not u) -> (* [u = bool ==> not u = not bool] *) let r_u = find_ u in - if Class.equal r_u (n_true self) then ( + if E_node.equal r_u (n_true self) then ( let expl = Expl.mk_merge u (n_true self) in merge_classes self n (n_false self) expl - ) else if Class.equal r_u (n_false self) then ( + ) else if E_node.equal r_u (n_false self) then ( let expl = Expl.mk_merge u (n_false self) in merge_classes self n (n_true self) expl ) | Some s0 -> - (* update the signature by using [find] on each sub-node *) + (* update the signature by using [find] on each sub-e_node *) let s = update_sig s0 in (match find_signature self s with | None -> (* add to the signature table [sig(n) --> n] *) add_signature self s n - | Some u when Class.equal n u -> () + | Some u when E_node.equal n u -> () | Some u -> (* [t1] and [t2] must be applications of the same symbol to arguments that are pairwise equal *) @@ -849,11 +871,15 @@ module Make (A : ARG) : let expl = Expl.mk_congruence n u in merge_classes self n u expl) - and task_combine_ self acts = function - | CT_merge (a, b, e_ab) -> task_merge_ self acts a b e_ab - | CT_set_val (n, v) -> task_set_val_ self acts n v + and task_combine_ self = function + | CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab + | CT_set_val (n, v) -> task_set_val_ self n v + | CT_act (Act_merge (t, u, e)) -> task_merge_ self t u e + | CT_act (Act_propagate _ as a) -> + (* will return this propagation to the caller *) + Vec.push self.res_acts a - and task_set_val_ self acts n v = + and task_set_val_ self n v = let repr_n = find_ n in (* - if repr(n) has value [v], do nothing - else if repr(n) has value [v'], semantic conflict @@ -872,11 +898,15 @@ module Make (A : ARG) : k "(@[cc.semantic-conflict.set-val@ (@[set-val %a@ := %a@])@ \ (@[existing-val %a@ := %a@])@])" - Class.pp n Term.pp v Class.pp n' Term.pp v'); + E_node.pp n Term.pp v E_node.pp n' Term.pp v'); Stat.incr self.count_semantic_conflict; - let (module A) = acts in - A.raise_semantic_conflict lits tuples + (* FIXME + raise (E_confl(Conflict lits)) + let (module A) = acts in + A.raise_semantic_conflict lits tuples + *) + assert false | Some _ -> () | None -> T_b_tbl.add self.t_to_val repr_n.n_term (n, v)); (* now for the reverse map, look in self.val_to_t for [v]. @@ -890,23 +920,23 @@ module Make (A : ARG) : (* main CC algo: merge equivalence classes in [st.combine]. @raise Exn_unsat if merge fails *) - and task_merge_ self acts a b e_ab : unit = + and task_merge_ self a b e_ab : unit = let ra = find_ a in let rb = find_ b in - if not @@ Class.equal ra rb then ( - assert (Class.is_root ra); - assert (Class.is_root rb); + if not @@ E_node.equal ra rb then ( + assert (E_node.is_root ra); + assert (E_node.is_root rb); Stat.incr self.count_merge; (* check we're not merging [true] and [false] *) if - (Class.equal ra (n_true self) && Class.equal rb (n_false self)) - || (Class.equal rb (n_true self) && Class.equal ra (n_false self)) + (E_node.equal ra (n_true self) && E_node.equal rb (n_false self)) + || (E_node.equal rb (n_true self) && E_node.equal ra (n_false self)) then ( Log.debugf 5 (fun k -> k "(@[cc.merge.true_false_conflict@ @[:r1 %a@ :t1 %a@]@ @[:r2 \ %a@ :t2 %a@]@ :e_ab %a@])" - Class.pp ra Class.pp a Class.pp rb Class.pp b Expl.pp e_ab); + E_node.pp ra E_node.pp a E_node.pp rb E_node.pp b Expl.pp e_ab); let th = ref false in (* TODO: C1: P.true_neq_false @@ -923,16 +953,19 @@ module Make (A : ARG) : let lits = expl_st.lits in let same_val = expl_st.same_val - |> List.rev_map (fun (t, u) -> true, Class.term t, Class.term u) + |> List.rev_map (fun (t, u) -> true, E_node.term t, E_node.term u) in assert (same_val <> []); Stat.incr self.count_semantic_conflict; - let (module A) = acts in - A.raise_semantic_conflict lits same_val + (* FIXME + let (module A) = acts in + A.raise_semantic_conflict lits same_val + *) + assert false ) else ( (* regular conflict *) let lits, pr = lits_and_proof_of_expl self expl_st in - raise_conflict_ self ~th:!th acts (List.rev_map Lit.neg lits) pr + raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr ) ); (* We will merge [r_from] into [r_into]. @@ -950,10 +983,10 @@ module Make (A : ARG) : in (* when merging terms with [true] or [false], possibly propagate them to SAT *) let merge_bool r1 t1 r2 t2 = - if Class.equal r1 (n_true self) then - propagate_bools self acts r2 t2 r1 t1 e_ab true - else if Class.equal r1 (n_false self) then - propagate_bools self acts r2 t2 r1 t1 e_ab false + if E_node.equal r1 (n_true self) then + propagate_bools self r2 t2 r1 t1 e_ab true + else if E_node.equal r1 (n_false self) then + propagate_bools self r2 t2 r1 t1 e_ab false in if not self.model_mode then ( @@ -963,7 +996,8 @@ module Make (A : ARG) : (* perform [union r_from r_into] *) Log.debugf 15 (fun k -> - k "(@[cc.merge@ :from %a@ :into %a@])" Class.pp r_from Class.pp r_into); + k "(@[cc.merge@ :from %a@ :into %a@])" E_node.pp r_from E_node.pp + r_into); (* call [on_pre_merge] functions, and merge theory data items *) if not self.model_mode then ( @@ -971,13 +1005,18 @@ module Make (A : ARG) : let expl = Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ] in - Event.emit self.on_pre_merge (self, acts, r_into, r_from, expl) + Event.emit_iter self.on_pre_merge (self, r_into, r_from, expl) + ~f:(function + | Ok l -> push_action_l self l + | Error c -> raise (E_confl c)) ); + (* TODO: merge plugin data here, _after_ the pre-merge hooks are called, + so they have a chance of observing pre-merge plugin data *) ((* parents might have a different signature, check for collisions *) - Class.iter_parents r_from (fun parent -> push_pending self parent); - (* for each node in [r_from]'s class, make it point to [r_into] *) - Class.iter_class r_from (fun u -> + E_node.iter_parents r_from (fun parent -> push_pending self parent); + (* for each e_node in [r_from]'s class, make it point to [r_into] *) + E_node.iter_class r_from (fun u -> assert (u.n_root == r_from); u.n_root <- r_into); (* capture current state *) @@ -994,15 +1033,15 @@ module Make (A : ARG) : (* on backtrack, unmerge classes and restore the pointers to [r_from] *) on_backtrack self (fun () -> Log.debugf 30 (fun k -> - k "(@[cc.undo_merge@ :from %a@ :into %a@])" Class.pp r_from - Class.pp r_into); + k "(@[cc.undo_merge@ :from %a@ :into %a@])" E_node.pp r_from + E_node.pp r_into); r_into.n_bits <- r_into_old_bits; r_into.n_next <- r_into_old_next; r_from.n_next <- r_from_old_next; r_into.n_parents <- r_into_old_parents; (* NOTE: this must come after the restoration of [next] pointers, otherwise we'd iterate on too big a class *) - Class.iter_class_ r_from (fun u -> u.n_root <- r_from); + E_node.iter_class_ r_from (fun u -> u.n_root <- r_from); r_into.n_size <- r_into.n_size - r_from.n_size)); (* check for semantic values, update the one of [r_into] @@ -1030,11 +1069,14 @@ module Make (A : ARG) : k "(@[cc.semantic-conflict.post-merge@ (@[n-from %a@ := %a@])@ \ (@[n-into %a@ := %a@])@])" - Class.pp n_from Term.pp v_from Class.pp n_into Term.pp v_into); + E_node.pp n_from Term.pp v_from E_node.pp n_into Term.pp v_into); Stat.incr self.count_semantic_conflict; - let (module A) = acts in - A.raise_semantic_conflict lits tuples + (* FIXME + let (module A) = acts in + A.raise_semantic_conflict lits tuples + *) + assert false | Some _ -> ())); (* update explanations (a -> b), arbitrarily. @@ -1046,20 +1088,21 @@ module Make (A : ARG) : that bridges between [a] and [b] *) on_backtrack self (fun () -> match a.n_expl, b.n_expl with - | FL_some e, _ when Class.equal e.next b -> a.n_expl <- FL_none - | _, FL_some e when Class.equal e.next a -> b.n_expl <- FL_none + | FL_some e, _ when E_node.equal e.next b -> a.n_expl <- FL_none + | _, FL_some e when E_node.equal e.next a -> b.n_expl <- FL_none | _ -> assert false); a.n_expl <- FL_some { next = b; expl = e_ab }; (* call [on_post_merge] *) if not self.model_mode then - Event.emit self.on_post_merge (self, acts, r_into, r_from) + Event.emit_iter self.on_post_merge (self, r_into, r_from) + ~f:(push_action_l self) ) (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] in the equiv class of [r1] that is a known literal back to the SAT solver and which is not the one initially merged. We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) - and propagate_bools self acts r1 t1 r2 t2 (e_12 : explanation) sign : unit = + and propagate_bools self r1 t1 r2 t2 (e_12 : explanation) sign : unit = (* explanation for [t1 =e= t2 = r2] *) let half_expl_and_pr = lazy @@ -1070,14 +1113,14 @@ module Make (A : ARG) : in (* TODO: flag per class, `or`-ed on merge, to indicate if the class contains at least one lit *) - Class.iter_class r1 (fun u1 -> + E_node.iter_class r1 (fun u1 -> (* propagate if: - [u1] is a proper literal - [t2 != r2], because that can only happen after an explicit merge (no way to obtain that by propagation) *) - match Class.as_lit u1 with - | Some lit when not (Class.equal r2 t2) -> + match E_node.as_lit u1 with + | Some lit when not (E_node.equal r2 t2) -> let lit = if sign then lit @@ -1102,21 +1145,23 @@ module Make (A : ARG) : let _, pr = lits_and_proof_of_expl self st in guard, pr in - Event.emit self.on_propagate (self, lit, reason); - Stat.incr self.count_props; - let (module A) = acts in - A.propagate lit ~reason + push_action self (Act_propagate { lit; reason }); + Event.emit_iter self.on_propagate (self, lit, reason) + ~f:(push_action_l self); + Stat.incr self.count_props ) | _ -> ()) let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t) - let[@inline] push_level (self : t) : unit = + let push_level (self : t) : unit = + assert (not self.in_loop); Backtrack_stack.push_level self.undo; T_b_tbl.push_level self.t_to_val; T_b_tbl.push_level self.val_to_t let pop_levels (self : t) n : unit = + assert (not self.in_loop); Vec.clear self.pending; Vec.clear self.combine; Log.debugf 15 (fun k -> @@ -1127,6 +1172,7 @@ module Make (A : ARG) : T_b_tbl.pop_levels self.val_to_t n; () + (* FIXME: remove *) (* run [f] in a local congruence closure level *) let with_model_mode self f = assert (not self.model_mode); @@ -1141,22 +1187,26 @@ module Make (A : ARG) : all_classes self |> Iter.filter_map (fun repr -> match T_b_tbl.get self.t_to_val repr.n_term with - | Some (_, v) -> Some (repr, Class.iter_class repr, v) + | Some (_, v) -> Some (repr, E_node.iter_class repr, v) | None -> None) + let assert_eq self t u expl : unit = + assert (not self.in_loop); + let t = add_term self t in + let u = add_term self u in + (* merge [a] and [b] *) + merge_classes self t u expl + (* assert that this boolean literal holds. if a lit is [= a b], merge [a] and [b]; otherwise merge the atom with true/false *) let assert_lit self lit : unit = + assert (not self.in_loop); let t = Lit.term lit in Log.debugf 15 (fun k -> k "(@[cc.assert-lit@ %a@])" Lit.pp lit); let sign = Lit.sign lit in match A.view_as_cc t with - | Eq (a, b) when sign -> - let a = add_term self a in - let b = add_term self b in - (* merge [a] and [b] *) - merge_classes self a b (Expl.mk_lit lit) + | Eq (a, b) when sign -> assert_eq self a b (Expl.mk_lit lit) | _ -> (* equate t and true/false *) let rhs = n_bool self sign in @@ -1168,23 +1218,28 @@ module Make (A : ARG) : (* TODO: use oriented merge (force direction [n -> rhs]) *) merge_classes self n rhs (Expl.mk_lit lit) - let[@inline] assert_lits self lits : unit = Iter.iter (assert_lit self) lits + let[@inline] assert_lits self lits : unit = + assert (not self.in_loop); + Iter.iter (assert_lit self) lits - (* raise a conflict *) - let raise_conflict_from_expl self (acts : actions) expl = - Log.debugf 5 (fun k -> - k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); - let st = Expl_state.create () in - explain_decompose_expl self st expl; - let lits, pr = lits_and_proof_of_expl self st in - let c = List.rev_map Lit.neg lits in - let th = st.th_lemmas <> [] in - raise_conflict_ self ~th acts c pr + (* FIXME: remove? + (* raise a conflict *) + let raise_conflict_from_expl self (acts : actions_or_confl) expl = + Log.debugf 5 (fun k -> + k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); + let st = Expl_state.create () in + explain_decompose_expl self st expl; + let lits, pr = lits_and_proof_of_expl self st in + let c = List.rev_map Lit.neg lits in + let th = st.th_lemmas <> [] in + raise_conflict_ self ~th c pr + *) let merge self n1 n2 expl = + assert (not self.in_loop); Log.debugf 5 (fun k -> - k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" Class.pp n1 Class.pp - n2 Expl.pp expl); + k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" E_node.pp n1 + E_node.pp n2 Expl.pp expl); assert (T.Ty.equal (T.Term.ty n1.n_term) (T.Term.ty n2.n_term)); merge_classes self n1 n2 expl @@ -1192,6 +1247,7 @@ module Make (A : ARG) : merge self (add_term self t1) (add_term self t2) expl let set_model_value (self : t) (t : term) (v : value) : unit = + assert (not self.in_loop); assert self.model_mode; (* only valid in model mode *) match T_tbl.find_opt self.tbl t with @@ -1241,6 +1297,8 @@ module Make (A : ARG) : undo = Backtrack_stack.create (); true_; false_; + in_loop = false; + res_acts = Vec.create (); field_marked_explain; count_conflict = Stat.mk_int stat "cc.conflicts"; count_props = Stat.mk_int stat "cc.propagations"; @@ -1249,17 +1307,31 @@ module Make (A : ARG) : } and true_ = lazy (add_term cc (Term.bool tst true)) and false_ = lazy (add_term cc (Term.bool tst false)) in - ignore (Lazy.force true_ : node); - ignore (Lazy.force false_ : node); + ignore (Lazy.force true_ : e_node); + ignore (Lazy.force false_ : e_node); cc let[@inline] find_t self t : repr = let n = T_tbl.find self.tbl t in find_ n - let[@inline] check self acts : unit = + let pop_acts_ self = + let rec loop acc = + match Vec.pop self.res_acts with + | None -> acc + | Some x -> loop (x :: acc) + in + loop [] + + let check self : actions_or_confl = Log.debug 5 "(cc.check)"; - update_tasks self acts + self.in_loop <- true; + let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in + try + update_tasks self; + let l = pop_acts_ self in + Ok l + with E_confl c -> Error c let check_inv_enabled_ = true (* XXX NUDGE *) @@ -1268,7 +1340,7 @@ module Make (A : ARG) : if check_inv_enabled_ then ( Log.debug 2 "(cc.check-invariants)"; all_classes self - |> Iter.flat_map Class.iter_class + |> Iter.flat_map E_node.iter_class |> Iter.iter (fun n -> match n.n_sig0 with | None -> () @@ -1277,16 +1349,16 @@ module Make (A : ARG) : let ok = match find_signature self s' with | None -> false - | Some r -> Class.equal r n.n_root + | Some r -> E_node.equal r n.n_root in if not ok then Log.debugf 0 (fun k -> k "(@[cc.check.fail@ :n %a@ :sig %a@ :actual-sig %a@])" - Class.pp n Signature.pp s Signature.pp s')) + E_node.pp n Signature.pp s Signature.pp s')) ) (* model: return all the classes *) let get_model (self : t) : repr Iter.t Iter.t = check_inv_ self; - all_classes self |> Iter.map Class.iter_class + all_classes self |> Iter.map E_node.iter_class end diff --git a/src/cc/plugin/sidekick_cc_plugin.ml b/src/cc/plugin/sidekick_cc_plugin.ml index 211fd6be..2ddb9389 100644 --- a/src/cc/plugin/sidekick_cc_plugin.ml +++ b/src/cc/plugin/sidekick_cc_plugin.ml @@ -3,11 +3,11 @@ open Sidekick_sigs_cc module type EXTENDED_PLUGIN_BUILDER = sig include MONOID_PLUGIN_BUILDER - val mem : t -> M.CC.Class.t -> bool - (** Does the CC Class.t have a monoid value? *) + val mem : t -> M.CC.E_node.t -> bool + (** Does the CC E_node.t have a monoid value? *) - val get : t -> M.CC.Class.t -> M.t option - (** Get monoid value for this CC Class.t, if any *) + val get : t -> M.CC.E_node.t -> M.t option + (** Get monoid value for this CC E_node.t, if any *) val iter_all : t -> (M.CC.repr * M.t) Iter.t @@ -19,8 +19,8 @@ module Make (M : MONOID_PLUGIN_ARG) : EXTENDED_PLUGIN_BUILDER with module M = M = struct module M = M module CC = M.CC - module Class = CC.Class - module Cls_tbl = Backtrackable_tbl.Make (Class) + module E_node = CC.E_node + module Cls_tbl = Backtrackable_tbl.Make (E_node) module Expl = CC.Expl type term = CC.term @@ -41,7 +41,7 @@ module Make (M : MONOID_PLUGIN_ARG) : let values : M.t Cls_tbl.t = Cls_tbl.create ?size () (* bit in CC to filter out quickly classes without value *) - let field_has_value : CC.Class.bitfield = + let field_has_value : CC.E_node.bitfield = CC.allocate_bitfield ~descr:("monoid." ^ M.name ^ ".has-value") cc let push_level () = Cls_tbl.push_level values @@ -69,8 +69,8 @@ module Make (M : MONOID_PLUGIN_ARG) : (match maybe_m with | Some v -> Log.debugf 20 (fun k -> - k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])" M.name Class.pp n - M.pp v); + k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])" M.name E_node.pp + n M.pp v); CC.set_bitfield cc field_has_value true n; Cls_tbl.add values n v | None -> ()); @@ -78,25 +78,25 @@ module Make (M : MONOID_PLUGIN_ARG) : (fun (n_u, m_u) -> Log.debugf 20 (fun k -> k "(@[monoid[%s].on-new-term.sub@ :n %a@ :sub-t %a@ :value %a@])" - M.name Class.pp n Class.pp n_u M.pp m_u); + M.name E_node.pp n E_node.pp n_u M.pp m_u); let n_u = CC.find cc n_u in if CC.get_bitfield cc field_has_value n_u then ( let m_u' = try Cls_tbl.find values n_u with Not_found -> - Error.errorf "node %a has bitfield but no value" Class.pp n_u + Error.errorf "node %a has bitfield but no value" E_node.pp n_u in match M.merge cc n_u m_u n_u m_u' (Expl.mk_list []) with | Error expl -> Error.errorf "when merging@ @[for node %a@],@ values %a and %a:@ conflict %a" - Class.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl + E_node.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl | Ok m_u_merged -> Log.debugf 20 (fun k -> k "(@[monoid[%s].on-new-term.sub.merged@ :n %a@ :sub-t %a@ \ :value %a@])" - M.name Class.pp n Class.pp n_u M.pp m_u_merged); + M.name E_node.pp n E_node.pp n_u M.pp m_u_merged); Cls_tbl.add values n_u m_u_merged ) else ( (* just add to [n_u] *) @@ -108,30 +108,33 @@ module Make (M : MONOID_PLUGIN_ARG) : let iter_all : _ Iter.t = Cls_tbl.to_iter values - let on_pre_merge cc acts n1 n2 e_n1_n2 : unit = - match get n1, get n2 with - | Some v1, Some v2 -> - Log.debugf 5 (fun k -> - k - "(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 %a@ \ - :val2 %a@])@])" - M.name Class.pp n1 M.pp v1 Class.pp n2 M.pp v2); - (match M.merge cc n1 v1 n2 v2 e_n1_n2 with - | Ok v' -> - Cls_tbl.remove values n2; - (* only keep repr *) - Cls_tbl.add values n1 v' - | Error expl -> CC.raise_conflict_from_expl cc acts expl) - | None, Some cr -> - CC.set_bitfield cc field_has_value true n1; - Cls_tbl.add values n1 cr; - Cls_tbl.remove values n2 (* only keep reprs *) - | Some _, None -> () (* already there on the left *) - | None, None -> () + let on_pre_merge cc n1 n2 e_n1_n2 : CC.actions = + let exception E of M.CC.conflict in + try + match get n1, get n2 with + | Some v1, Some v2 -> + Log.debugf 5 (fun k -> + k + "(@[monoid[%s].on_pre_merge@ (@[:n1 %a@ :val1 %a@])@ (@[:n2 \ + %a@ :val2 %a@])@])" + M.name E_node.pp n1 M.pp v1 E_node.pp n2 M.pp v2); + (match M.merge cc n1 v1 n2 v2 e_n1_n2 with + | Ok v' -> + Cls_tbl.remove values n2; + (* only keep repr *) + Cls_tbl.add values n1 v' + | Error expl -> raise (E (CC.Conflict_expl expl))) + | None, Some cr -> + CC.set_bitfield cc field_has_value true n1; + Cls_tbl.add values n1 cr; + Cls_tbl.remove values n2 (* only keep reprs *) + | Some _, None -> () (* already there on the left *) + | None, None -> () + with E c -> Error c let pp out () : unit = let pp_e out (t, v) = - Fmt.fprintf out "(@[%a@ :has %a@])" Class.pp t M.pp v + Fmt.fprintf out "(@[%a@ :has %a@])" E_node.pp t M.pp v in Fmt.fprintf out "(@[%a@])" (Fmt.iter pp_e) iter_all