From 464bc6647454666e3dab31761f9e3fd4e480213f Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 29 Jul 2022 00:02:27 -0400 Subject: [PATCH] wip: refactor(cc): remove layers of functorization --- src/cc/Sidekick_cc.ml | 1211 +------------------------- src/cc/Sidekick_cc.mli | 16 +- src/cc/bits.ml | 26 + src/cc/bits.mli | 13 + src/cc/core_cc.ml | 1136 ++++++++++++++++++++++++ src/cc/dune | 6 +- src/cc/mini/Sidekick_mini_cc.ml | 55 +- src/cc/mini/Sidekick_mini_cc.mli | 29 +- src/cc/mini/dune | 2 +- src/cc/plugin/dune | 4 +- src/cc/plugin/sidekick_cc_plugin.ml | 2 +- src/cc/plugin/sidekick_cc_plugin.mli | 4 +- src/cc/sigs.ml | 506 +++++++++++ src/cc/view.ml | 38 + src/cc/view.mli | 33 + 15 files changed, 1805 insertions(+), 1276 deletions(-) create mode 100644 src/cc/bits.ml create mode 100644 src/cc/bits.mli create mode 100644 src/cc/core_cc.ml create mode 100644 src/cc/sigs.ml create mode 100644 src/cc/view.ml create mode 100644 src/cc/view.mli diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index ae1562ae..7648562c 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -1,1208 +1,9 @@ -open Sidekick_sigs_cc +open Sidekick_core module View = View -open View -module type S = sig - include S +module type ARG = Sigs.ARG +module type S = Sigs.S +module type MONOID_PLUGIN_ARG = Sigs.MONOID_PLUGIN_ARG +module type MONOID_PLUGIN_BUILDER = Sigs.MONOID_PLUGIN_BUILDER - val create : - ?stat:Stat.t -> ?size:[ `Small | `Big ] -> term_store -> proof_trace -> t - (** Create a new congruence closure. - - @param term_store used to be able to create new terms. All terms - interacting with this congruence closure must belong in this term state - as well. - *) - - (**/**) - - module Debug_ : sig - val pp : t Fmt.printer - (** Print the whole CC *) - end - - (**/**) -end - -module type ARG = ARG - -(* small bitfield *) -module Bits : sig - type t = private int - type field - type bitfield_gen - - val empty : t - val equal : t -> t -> bool - val mk_field : bitfield_gen -> field - val mk_gen : unit -> bitfield_gen - val get : field -> t -> bool - val set : field -> bool -> t -> t - val merge : t -> t -> t -end = struct - type bitfield_gen = int ref - - let max_width = Sys.word_size - 2 - let mk_gen () = ref 0 - - type t = int - type field = int - - let empty : t = 0 - - let mk_field (gen : bitfield_gen) : field = - let n = !gen in - if n > max_width then Error.errorf "maximum number of CC bitfields reached"; - incr gen; - 1 lsl n - - let[@inline] get field x = x land field <> 0 - - let[@inline] set field b x = - if b then - x lor field - else - x land lnot field - - let merge = ( lor ) - let equal : t -> t -> bool = CCEqual.poly -end - -module Make (A : ARG) : - S - with module T = A.T - and module Lit = A.Lit - and module Proof_trace = A.Proof_trace = struct - module T = A.T - module Lit = A.Lit - module Proof_trace = A.Proof_trace - module Term = T.Term - module Fun = T.Fun - - open struct - (* proof rules *) - module Rules_ = A.Rule_core - module P = Proof_trace - end - - type term = T.Term.t - type value = term - type term_store = T.Term.store - type lit = Lit.t - type fun_ = T.Fun.t - type proof_trace = A.Proof_trace.t - type step_id = A.Proof_trace.A.step_id - - type e_node = { - n_term: term; - mutable n_sig0: signature option; (* initial signature *) - mutable n_bits: Bits.t; (* bitfield for various properties *) - mutable n_parents: e_node Bag.t; (* parent terms of this node *) - mutable n_root: e_node; - (* representative of congruence class (itself if a representative) *) - mutable n_next: e_node; (* pointer to next element of congruence class *) - mutable n_size: int; (* size of the class *) - mutable n_as_lit: lit option; - (* TODO: put into payload? and only in root? *) - mutable n_expl: explanation_forest_link; - (* the rooted forest for explanations *) - } - (** A node of the congruence closure. - An equivalence class is represented by its "root" element, - the representative. *) - - and signature = (fun_, e_node, e_node list) View.t - - and explanation_forest_link = - | FL_none - | FL_some of { next: e_node; expl: explanation } - - (* atomic explanation in the congruence closure *) - and explanation = - | E_trivial (* by pure reduction, tautologically equal *) - | E_lit of lit (* because of this literal *) - | E_merge of e_node * e_node - | E_merge_t of term * term - | E_congruence of e_node * e_node (* caused by normal congruence *) - | E_and of explanation * explanation - | E_theory of term * term * (term * term * explanation list) list * step_id - - type repr = e_node - - module E_node = struct - type t = e_node - - let[@inline] equal (n1 : t) n2 = n1 == n2 - let[@inline] hash n = Term.hash n.n_term - let[@inline] term n = n.n_term - let[@inline] pp out n = Term.pp out n.n_term - let[@inline] as_lit n = n.n_as_lit - - let make (t : term) : t = - let rec n = - { - n_term = t; - n_sig0 = None; - n_bits = Bits.empty; - n_parents = Bag.empty; - n_as_lit = None; - (* TODO: provide a method to do it *) - n_root = n; - n_expl = FL_none; - n_next = n; - n_size = 1; - } - in - n - - let[@inline] is_root (n : e_node) : bool = n.n_root == n - - (* traverse the equivalence class of [n] *) - let iter_class_ (n : e_node) : e_node Iter.t = - fun yield -> - let rec aux u = - yield u; - if u.n_next != n then aux u.n_next - in - aux n - - let[@inline] iter_class n = - assert (is_root n); - iter_class_ n - - let[@inline] iter_parents (n : e_node) : e_node Iter.t = - assert (is_root n); - Bag.to_iter n.n_parents - - type bitfield = Bits.field - - let[@inline] get_field f t = Bits.get f t.n_bits - let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits - end - - (* non-recursive, inlinable function for [find] *) - let[@inline] find_ (n : e_node) : repr = - let n2 = n.n_root in - assert (E_node.is_root n2); - n2 - - let[@inline] same_class (n1 : e_node) (n2 : e_node) : bool = - E_node.equal (find_ n1) (find_ n2) - - let[@inline] find _ n = find_ n - - module Expl = struct - type t = explanation - - let rec pp out (e : explanation) = - match e with - | E_trivial -> Fmt.string out "reduction" - | E_lit lit -> Lit.pp out lit - | E_congruence (n1, n2) -> - Fmt.fprintf out "(@[congruence@ %a@ %a@])" E_node.pp n1 E_node.pp n2 - | E_merge (a, b) -> - Fmt.fprintf out "(@[merge@ %a@ %a@])" E_node.pp a E_node.pp b - | E_merge_t (a, b) -> - Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp a - Term.pp b - | E_theory (t, u, es, _) -> - Fmt.fprintf out "(@[th@ :t `%a`@ :u `%a`@ :expl_sets %a@])" Term.pp t - Term.pp u - (Util.pp_list @@ Fmt.Dump.triple Term.pp Term.pp (Fmt.Dump.list pp)) - es - | E_and (a, b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b - - let mk_trivial : t = E_trivial - let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2) - - let[@inline] mk_merge a b : t = - if E_node.equal a b then - mk_trivial - else - E_merge (a, b) - - let[@inline] mk_merge_t a b : t = - if Term.equal a b then - mk_trivial - else - E_merge_t (a, b) - - let[@inline] mk_lit l : t = E_lit l - let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr) - - let rec mk_list l = - match l with - | [] -> mk_trivial - | [ x ] -> x - | E_trivial :: tl -> mk_list tl - | x :: y -> - (match mk_list y with - | E_trivial -> x - | y' -> E_and (x, y')) - end - - module Resolved_expl = struct - type t = { lits: lit list; pr: proof_trace -> step_id } - - let pp out (self : t) = - Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits - end - - (** A signature is a shallow term shape where immediate subterms - are representative *) - module Signature = struct - type t = signature - - let equal (s1 : t) s2 : bool = - match s1, s2 with - | Bool b1, Bool b2 -> b1 = b2 - | App_fun (f1, []), App_fun (f2, []) -> Fun.equal f1 f2 - | App_fun (f1, l1), App_fun (f2, l2) -> - Fun.equal f1 f2 && CCList.equal E_node.equal l1 l2 - | App_ho (f1, a1), App_ho (f2, a2) -> - E_node.equal f1 f2 && E_node.equal a1 a2 - | Not a, Not b -> E_node.equal a b - | If (a1, b1, c1), If (a2, b2, c2) -> - E_node.equal a1 a2 && E_node.equal b1 b2 && E_node.equal c1 c2 - | Eq (a1, b1), Eq (a2, b2) -> E_node.equal a1 a2 && E_node.equal b1 b2 - | Opaque u1, Opaque u2 -> E_node.equal u1 u2 - | Bool _, _ - | App_fun _, _ - | App_ho _, _ - | If _, _ - | Eq _, _ - | Opaque _, _ - | Not _, _ -> - false - - let hash (s : t) : int = - let module H = CCHash in - match s with - | Bool b -> H.combine2 10 (H.bool b) - | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list E_node.hash l) - | App_ho (f, a) -> H.combine3 30 (E_node.hash f) (E_node.hash a) - | Eq (a, b) -> H.combine3 40 (E_node.hash a) (E_node.hash b) - | Opaque u -> H.combine2 50 (E_node.hash u) - | If (a, b, c) -> - H.combine4 60 (E_node.hash a) (E_node.hash b) (E_node.hash c) - | Not u -> H.combine2 70 (E_node.hash u) - - let pp out = function - | Bool b -> Fmt.bool out b - | App_fun (f, []) -> Fun.pp out f - | App_fun (f, l) -> - Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list E_node.pp) l - | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" E_node.pp f E_node.pp a - | Opaque t -> E_node.pp out t - | Not u -> Fmt.fprintf out "(@[not@ %a@])" E_node.pp u - | Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" E_node.pp a E_node.pp b - | If (a, b, c) -> - Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" E_node.pp a E_node.pp b - E_node.pp c - end - - module Sig_tbl = CCHashtbl.Make (Signature) - module T_tbl = CCHashtbl.Make (Term) - - type propagation_reason = unit -> lit list * step_id - - module Handler_action = struct - type t = - | Act_merge of E_node.t * E_node.t * Expl.t - | Act_propagate of lit * propagation_reason - - type conflict = Conflict of Expl.t [@@unboxed] - type or_conflict = (t list, conflict) result - end - - module Result_action = struct - type t = Act_propagate of { lit: lit; reason: propagation_reason } - type conflict = Conflict of lit list * step_id - type or_conflict = (t list, conflict) result - end - - type combine_task = - | CT_merge of e_node * e_node * explanation - | CT_act of Handler_action.t - - type t = { - tst: term_store; - proof: proof_trace; - tbl: e_node T_tbl.t; (* internalization [term -> e_node] *) - signatures_tbl: e_node Sig_tbl.t; - (* map a signature to the corresponding e_node in some equivalence class. - A signature is a [term_cell] in which every immediate subterm - that participates in the congruence/evaluation relation - is normalized (i.e. is its own representative). - The critical property is that all members of an equivalence class - that have the same "shape" (including head symbol) - have the same signature *) - pending: e_node Vec.t; - combine: combine_task Vec.t; - undo: (unit -> unit) Backtrack_stack.t; - bitgen: Bits.bitfield_gen; - field_marked_explain: Bits.field; - (* used to mark traversed nodes when looking for a common ancestor *) - true_: e_node lazy_t; - false_: e_node lazy_t; - mutable in_loop: bool; (* currently being modified? *) - res_acts: Result_action.t Vec.t; (* to return *) - on_pre_merge: - ( t * E_node.t * E_node.t * Expl.t, - Handler_action.or_conflict ) - Event.Emitter.t; - on_pre_merge2: - ( t * E_node.t * E_node.t * Expl.t, - Handler_action.or_conflict ) - Event.Emitter.t; - on_post_merge: - (t * E_node.t * E_node.t, Handler_action.t list) Event.Emitter.t; - on_new_term: (t * E_node.t * term, Handler_action.t list) Event.Emitter.t; - on_conflict: (ev_on_conflict, unit) Event.Emitter.t; - on_propagate: - (t * lit * propagation_reason, Handler_action.t list) Event.Emitter.t; - on_is_subterm: (t * E_node.t * term, Handler_action.t list) Event.Emitter.t; - count_conflict: int Stat.counter; - count_props: int Stat.counter; - count_merge: int Stat.counter; - } - (* TODO: an additional union-find to keep track, for each term, - of the terms they are known to be equal to, according - to the current explanation. That allows not to prove some equality - several times. - See "fast congruence closure and extensions", Nieuwenhuis&al, page 14 *) - - and ev_on_conflict = { cc: t; th: bool; c: lit list } - - let[@inline] size_ (r : repr) = r.n_size - let[@inline] n_true self = Lazy.force self.true_ - let[@inline] n_false self = Lazy.force self.false_ - - let n_bool self b = - if b then - n_true self - else - n_false self - - let[@inline] term_store self = self.tst - let[@inline] proof self = self.proof - - let allocate_bitfield self ~descr = - Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr); - Bits.mk_field self.bitgen - - let[@inline] on_backtrack self f : unit = - Backtrack_stack.push_if_nonzero_level self.undo f - - let[@inline] get_bitfield _cc field n = E_node.get_field field n - - let set_bitfield self field b n = - let old = E_node.get_field field n in - if old <> b then ( - on_backtrack self (fun () -> E_node.set_field field old n); - E_node.set_field field b n - ) - - (* check if [t] is in the congruence closure. - Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) - let[@inline] mem (self : t) (t : term) : bool = T_tbl.mem self.tbl t - - module Debug_ = struct - (* print full state *) - let pp out (self : t) : unit = - let pp_next out n = Fmt.fprintf out "@ :next %a" E_node.pp n.n_next in - let pp_root out n = - if E_node.is_root n then - Fmt.string out " :is-root" - else - Fmt.fprintf out "@ :root %a" E_node.pp n.n_root - in - let pp_expl out n = - match n.n_expl with - | FL_none -> () - | FL_some e -> - Fmt.fprintf out " (@[:forest %a :expl %a@])" E_node.pp e.next Expl.pp - e.expl - in - let pp_n out n = - Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n - pp_expl n - and pp_sig_e out (s, n) = - Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s E_node.pp n - pp_root n - in - Fmt.fprintf out - "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ \ - %a@])@])" - (Util.pp_iter ~sep:" " pp_n) - (T_tbl.values self.tbl) - (Util.pp_iter ~sep:" " pp_sig_e) - (Sig_tbl.to_iter self.signatures_tbl) - end - - (* compute up-to-date signature *) - let update_sig (s : signature) : Signature.t = - View.map_view s ~f_f:(fun x -> x) ~f_t:find_ ~f_ts:(List.map find_) - - (* find whether the given (parent) term corresponds to some signature - in [signatures_] *) - let[@inline] find_signature cc (s : signature) : repr option = - Sig_tbl.get cc.signatures_tbl s - - (* add to signature table. Assume it's not present already *) - let add_signature self (s : signature) (n : e_node) : unit = - assert (not @@ Sig_tbl.mem self.signatures_tbl s); - Log.debugf 50 (fun k -> - k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s E_node.pp n); - on_backtrack self (fun () -> Sig_tbl.remove self.signatures_tbl s); - Sig_tbl.add self.signatures_tbl s n - - let push_pending self t : unit = - Log.debugf 50 (fun k -> k "(@[cc.push-pending@ %a@])" E_node.pp t); - Vec.push self.pending t - - let push_action self (a : Handler_action.t) : unit = - Vec.push self.combine (CT_act a) - - let push_action_l self (l : _ list) : unit = List.iter (push_action self) l - - let merge_classes self t u e : unit = - if t != u && not (same_class t u) then ( - Log.debugf 50 (fun k -> - k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" E_node.pp t - E_node.pp u Expl.pp e); - Vec.push self.combine @@ CT_merge (t, u, e) - ) - - (* re-root the explanation tree of the equivalence class of [n] - so that it points to [n]. - postcondition: [n.n_expl = None] *) - let[@unroll 2] rec reroot_expl (self : t) (n : e_node) : unit = - match n.n_expl with - | FL_none -> () (* already root *) - | FL_some { next = u; expl = e_n_u } -> - (* reroot to [u], then invert link between [u] and [n] *) - reroot_expl self u; - u.n_expl <- FL_some { next = n; expl = e_n_u }; - n.n_expl <- FL_none - - exception E_confl of Result_action.conflict - - let raise_conflict_ (cc : t) ~th (e : lit list) (p : step_id) : _ = - Profile.instant "cc.conflict"; - (* clear tasks queue *) - Vec.clear cc.pending; - Vec.clear cc.combine; - Event.emit cc.on_conflict { cc; th; c = e }; - Stat.incr cc.count_conflict; - raise (E_confl (Conflict (e, p))) - - let[@inline] all_classes self : repr Iter.t = - T_tbl.values self.tbl |> Iter.filter E_node.is_root - - (* find the closest common ancestor of [a] and [b] in the proof forest. - - Precond: - - [a] and [b] are in the same class - - no e_node has the flag [field_marked_explain] on - Invariants: - - if [n] is marked, then all the predecessors of [n] - from [a] or [b] are marked too. - *) - let find_common_ancestor self (a : e_node) (b : e_node) : e_node = - (* catch up to the other e_node *) - let rec find1 a = - if E_node.get_field self.field_marked_explain a then - a - else ( - match a.n_expl with - | FL_none -> assert false - | FL_some r -> find1 r.next - ) - in - let rec find2 a b = - if E_node.equal a b then - a - else if E_node.get_field self.field_marked_explain a then - a - else if E_node.get_field self.field_marked_explain b then - b - else ( - E_node.set_field self.field_marked_explain true a; - E_node.set_field self.field_marked_explain true b; - match a.n_expl, b.n_expl with - | FL_some r1, FL_some r2 -> find2 r1.next r2.next - | FL_some r, FL_none -> find1 r.next - | FL_none, FL_some r -> find1 r.next - | FL_none, FL_none -> assert false - (* no common ancestor *) - ) - in - - (* cleanup tags on nodes traversed in [find2] *) - let rec cleanup_ n = - if E_node.get_field self.field_marked_explain n then ( - E_node.set_field self.field_marked_explain false n; - match n.n_expl with - | FL_none -> () - | FL_some { next; _ } -> cleanup_ next - ) - in - let n = find2 a b in - cleanup_ a; - cleanup_ b; - n - - module Expl_state = struct - type t = { - mutable lits: Lit.t list; - mutable th_lemmas: (Lit.t * (Lit.t * Lit.t list) list * step_id) list; - } - - let create () : t = { lits = []; th_lemmas = [] } - let[@inline] copy self : t = { self with lits = self.lits } - let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits - - let[@inline] add_th (self : t) lit hyps pr : unit = - self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas - - let merge self other = - let { lits = o_lits; th_lemmas = o_lemmas } = other in - self.lits <- List.rev_append o_lits self.lits; - self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas; - () - - (* proof of [\/_i ¬lits[i]] *) - let proof_of_th_lemmas (self : t) (proof : proof_trace) : step_id = - let p_lits1 = Iter.of_list self.lits |> Iter.map Lit.neg in - let p_lits2 = - Iter.of_list self.th_lemmas - |> Iter.map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u) - in - let p_cc = - P.add_step proof @@ Rules_.lemma_cc (Iter.append p_lits1 p_lits2) - in - let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) = - (* pr_th: [sub_proofs |- t=u]. - now resolve away [sub_proofs] to get literals that were - asserted in the congruence closure *) - let pr_th = - List.fold_left - (fun pr_th (lit_i, hyps_i) -> - (* [hyps_i |- lit_i] *) - let lemma_i = - P.add_step proof - @@ Rules_.lemma_cc - Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) - in - (* resolve [lit_i] away. *) - P.add_step proof - @@ Rules_.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th) - pr_th sub_proofs - in - P.add_step proof @@ Rules_.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr - in - (* resolve with theory proofs responsible for some merges, if any. *) - List.fold_left resolve_with_th_proof p_cc self.th_lemmas - - let to_resolved_expl (self : t) : Resolved_expl.t = - (* FIXME: package the th lemmas too *) - let { lits; th_lemmas = _ } = self in - let s2 = copy self in - let pr proof = proof_of_th_lemmas s2 proof in - { Resolved_expl.lits; pr } - end - - (* decompose explanation [e] into a list of literals added to [acc] *) - let rec explain_decompose_expl self (st : Expl_state.t) (e : explanation) : - unit = - Log.debugf 5 (fun k -> k "(@[cc.decompose_expl@ %a@])" Expl.pp e); - match e with - | E_trivial -> () - | E_congruence (n1, n2) -> - (match n1.n_sig0, n2.n_sig0 with - | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> - assert (Fun.equal f1 f2); - assert (List.length a1 = List.length a2); - List.iter2 (explain_equal_rec_ self st) a1 a2 - | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> - explain_equal_rec_ self st f1 f2; - explain_equal_rec_ self st a1 a2 - | Some (If (a1, b1, c1)), Some (If (a2, b2, c2)) -> - explain_equal_rec_ self st a1 a2; - explain_equal_rec_ self st b1 b2; - explain_equal_rec_ self st c1 c2 - | _ -> assert false) - | E_lit lit -> Expl_state.add_lit st lit - | E_theory (t, u, expl_sets, pr) -> - let sub_proofs = - List.map - (fun (t_i, u_i, expls_i) -> - let lit_i = A.mk_lit_eq self.tst t_i u_i in - (* use a separate call to [explain_expls] for each set *) - let sub = explain_expls self expls_i in - Expl_state.merge st sub; - lit_i, sub.lits) - expl_sets - in - let lit_t_u = A.mk_lit_eq self.tst t u in - Expl_state.add_th st lit_t_u sub_proofs pr - | E_merge (a, b) -> explain_equal_rec_ self st a b - | E_merge_t (a, b) -> - (* find nodes for [a] and [b] on the fly *) - (match T_tbl.find self.tbl a, T_tbl.find self.tbl b with - | a, b -> explain_equal_rec_ self st a b - | exception Not_found -> - Error.errorf "expl: cannot find e_node(s) for %a, %a" Term.pp a Term.pp - b) - | E_and (a, b) -> - explain_decompose_expl self st a; - explain_decompose_expl self st b - - and explain_expls self (es : explanation list) : Expl_state.t = - let st = Expl_state.create () in - List.iter (explain_decompose_expl self st) es; - st - - and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) - : unit = - Log.debugf 5 (fun k -> - k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b); - assert (E_node.equal (find_ a) (find_ b)); - let ancestor = find_common_ancestor cc a b in - explain_along_path cc st a ancestor; - explain_along_path cc st b ancestor - - (* explain why [a = parent_a], where [a -> ... -> target] in the - proof forest *) - and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) - : unit = - let rec aux n = - if n == target then - () - else ( - match n.n_expl with - | FL_none -> assert false - | FL_some { next = next_n; expl } -> - explain_decompose_expl self st expl; - (* now prove [next_n = target] *) - aux next_n - ) - in - aux a - - (* add a term *) - let[@inline] rec add_term_rec_ self t : e_node = - match T_tbl.find self.tbl t with - | n -> n - | exception Not_found -> add_new_term_ self t - - (* add [t] when not present already *) - and add_new_term_ self (t : term) : e_node = - assert (not @@ mem self t); - Log.debugf 15 (fun k -> k "(@[cc.add-term@ %a@])" Term.pp t); - let n = E_node.make t in - (* register sub-terms, add [t] to their parent list, and return the - corresponding initial signature *) - let sig0 = compute_sig0 self n in - n.n_sig0 <- sig0; - (* remove term when we backtrack *) - on_backtrack self (fun () -> - Log.debugf 30 (fun k -> k "(@[cc.remove-term@ %a@])" Term.pp t); - T_tbl.remove self.tbl t); - (* add term to the table *) - T_tbl.add self.tbl t n; - if Option.is_some sig0 then - (* [n] might be merged with other equiv classes *) - push_pending self n; - Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self); - n - - (* compute the initial signature of the given e_node *) - and compute_sig0 (self : t) (n : e_node) : Signature.t option = - (* add sub-term to [cc], and register [n] to its parents. - Note that we return the exact sub-term, to get proper - explanations, but we add to the sub-term's root's parent list. *) - let deref_sub (u : term) : e_node = - let sub = add_term_rec_ self u in - (* add [n] to [sub.root]'s parent list *) - (let sub_r = find_ sub in - let old_parents = sub_r.n_parents in - if Bag.is_empty old_parents then - (* first time it has parents: tell watchers that this is a subterm *) - Event.emit_iter self.on_is_subterm (self, sub, u) - ~f:(push_action_l self); - on_backtrack self (fun () -> sub_r.n_parents <- old_parents); - sub_r.n_parents <- Bag.cons n sub_r.n_parents); - sub - in - let[@inline] return x = Some x in - match A.view_as_cc n.n_term with - | Bool _ | Opaque _ -> None - | Eq (a, b) -> - let a = deref_sub a in - let b = deref_sub b in - return @@ Eq (a, b) - | Not u -> return @@ Not (deref_sub u) - | App_fun (f, args) -> - let args = args |> Iter.map deref_sub |> Iter.to_list in - if args <> [] then - return @@ App_fun (f, args) - else - None - | App_ho (f, a) -> - let f = deref_sub f in - let a = deref_sub a in - return @@ App_ho (f, a) - | If (a, b, c) -> return @@ If (deref_sub a, deref_sub b, deref_sub c) - - let[@inline] add_term self t : e_node = add_term_rec_ self t - let mem_term = mem - - let set_as_lit self (n : e_node) (lit : lit) : unit = - match n.n_as_lit with - | Some _ -> () - | None -> - Log.debugf 15 (fun k -> - k "(@[cc.set-as-lit@ %a@ %a@])" E_node.pp n Lit.pp lit); - on_backtrack self (fun () -> n.n_as_lit <- None); - n.n_as_lit <- Some lit - - (* is [n] true or false? *) - let n_is_bool_value (self : t) n : bool = - E_node.equal n (n_true self) || E_node.equal n (n_false self) - - (* gather a pair [lits, pr], where [lits] is the set of - asserted literals needed in the explanation (which is useful for - the SAT solver), and [pr] is a proof, including sub-proofs for theory - merges. *) - let lits_and_proof_of_expl (self : t) (st : Expl_state.t) : - Lit.t list * step_id = - let { Expl_state.lits; th_lemmas = _ } = st in - let pr = Expl_state.proof_of_th_lemmas st self.proof in - lits, pr - - (* main CC algo: add terms from [pending] to the signature table, - check for collisions *) - let rec update_tasks (self : t) : unit = - while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do - while not @@ Vec.is_empty self.pending do - task_pending_ self (Vec.pop_exn self.pending) - done; - while not @@ Vec.is_empty self.combine do - task_combine_ self (Vec.pop_exn self.combine) - done - done - - and task_pending_ self (n : e_node) : unit = - (* check if some parent collided *) - match n.n_sig0 with - | None -> () (* no-op *) - | Some (Eq (a, b)) -> - (* if [a=b] is now true, merge [(a=b)] and [true] *) - if same_class a b then ( - let expl = Expl.mk_merge a b in - Log.debugf 5 (fun k -> - k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" E_node.pp n E_node.pp a - E_node.pp b); - merge_classes self n (n_true self) expl - ) - | Some (Not u) -> - (* [u = bool ==> not u = not bool] *) - let r_u = find_ u in - if E_node.equal r_u (n_true self) then ( - let expl = Expl.mk_merge u (n_true self) in - merge_classes self n (n_false self) expl - ) else if E_node.equal r_u (n_false self) then ( - let expl = Expl.mk_merge u (n_false self) in - merge_classes self n (n_true self) expl - ) - | Some s0 -> - (* update the signature by using [find] on each sub-e_node *) - let s = update_sig s0 in - (match find_signature self s with - | None -> - (* add to the signature table [sig(n) --> n] *) - add_signature self s n - | Some u when E_node.equal n u -> () - | Some u -> - (* [t1] and [t2] must be applications of the same symbol to - arguments that are pairwise equal *) - assert (n != u); - let expl = Expl.mk_congruence n u in - merge_classes self n u expl) - - and task_combine_ self = function - | CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab - | CT_act (Handler_action.Act_merge (t, u, e)) -> task_merge_ self t u e - | CT_act (Handler_action.Act_propagate (lit, reason)) -> - (* will return this propagation to the caller *) - Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }) - - (* main CC algo: merge equivalence classes in [st.combine]. - @raise Exn_unsat if merge fails *) - and task_merge_ self a b e_ab : unit = - let ra = find_ a in - let rb = find_ b in - if not @@ E_node.equal ra rb then ( - assert (E_node.is_root ra); - assert (E_node.is_root rb); - Stat.incr self.count_merge; - (* check we're not merging [true] and [false] *) - if - (E_node.equal ra (n_true self) && E_node.equal rb (n_false self)) - || (E_node.equal rb (n_true self) && E_node.equal ra (n_false self)) - then ( - Log.debugf 5 (fun k -> - k - "(@[cc.merge.true_false_conflict@ @[:r1 %a@ :t1 %a@]@ @[:r2 \ - %a@ :t2 %a@]@ :e_ab %a@])" - E_node.pp ra E_node.pp a E_node.pp rb E_node.pp b Expl.pp e_ab); - let th = ref false in - (* TODO: - C1: P.true_neq_false - C2: lemma [lits |- true=false] (and resolve on theory proofs) - C3: r1 C1 C2 - *) - let expl_st = Expl_state.create () in - explain_decompose_expl self expl_st e_ab; - explain_equal_rec_ self expl_st a ra; - explain_equal_rec_ self expl_st b rb; - - (* regular conflict *) - let lits, pr = lits_and_proof_of_expl self expl_st in - raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr - ); - (* We will merge [r_from] into [r_into]. - we try to ensure that [size ra <= size rb] in general, but always - keep values as representative *) - let r_from, r_into = - if n_is_bool_value self ra then - rb, ra - else if n_is_bool_value self rb then - ra, rb - else if size_ ra > size_ rb then - rb, ra - else - ra, rb - in - (* when merging terms with [true] or [false], possibly propagate them to SAT *) - let merge_bool r1 t1 r2 t2 = - if E_node.equal r1 (n_true self) then - propagate_bools self r2 t2 r1 t1 e_ab true - else if E_node.equal r1 (n_false self) then - propagate_bools self r2 t2 r1 t1 e_ab false - in - - merge_bool ra a rb b; - merge_bool rb b ra a; - - (* perform [union r_from r_into] *) - Log.debugf 15 (fun k -> - k "(@[cc.merge@ :from %a@ :into %a@])" E_node.pp r_from E_node.pp - r_into); - - (* call [on_pre_merge] functions, and merge theory data items *) - (* explanation is [a=ra & e_ab & b=rb] *) - (let expl = - Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ] - in - - let handle_act = function - | Ok l -> push_action_l self l - | Error (Handler_action.Conflict expl) -> - raise_conflict_from_expl self expl - in - - Event.emit_iter self.on_pre_merge - (self, r_into, r_from, expl) - ~f:handle_act; - Event.emit_iter self.on_pre_merge2 - (self, r_into, r_from, expl) - ~f:handle_act); - - (* TODO: merge plugin data here, _after_ the pre-merge hooks are called, - so they have a chance of observing pre-merge plugin data *) - ((* parents might have a different signature, check for collisions *) - E_node.iter_parents r_from (fun parent -> push_pending self parent); - (* for each e_node in [r_from]'s class, make it point to [r_into] *) - E_node.iter_class r_from (fun u -> - assert (u.n_root == r_from); - u.n_root <- r_into); - (* capture current state *) - let r_into_old_next = r_into.n_next in - let r_from_old_next = r_from.n_next in - let r_into_old_parents = r_into.n_parents in - let r_into_old_bits = r_into.n_bits in - (* swap [into.next] and [from.next], merging the classes *) - r_into.n_next <- r_from_old_next; - r_from.n_next <- r_into_old_next; - r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; - r_into.n_size <- r_into.n_size + r_from.n_size; - r_into.n_bits <- Bits.merge r_into.n_bits r_from.n_bits; - (* on backtrack, unmerge classes and restore the pointers to [r_from] *) - on_backtrack self (fun () -> - Log.debugf 30 (fun k -> - k "(@[cc.undo_merge@ :from %a@ :into %a@])" E_node.pp r_from - E_node.pp r_into); - r_into.n_bits <- r_into_old_bits; - r_into.n_next <- r_into_old_next; - r_from.n_next <- r_from_old_next; - r_into.n_parents <- r_into_old_parents; - (* NOTE: this must come after the restoration of [next] pointers, - otherwise we'd iterate on too big a class *) - E_node.iter_class_ r_from (fun u -> u.n_root <- r_from); - r_into.n_size <- r_into.n_size - r_from.n_size)); - - (* update explanations (a -> b), arbitrarily. - Note that here we merge the classes by adding a bridge between [a] - and [b], not their roots. *) - reroot_expl self a; - assert (a.n_expl = FL_none); - (* on backtracking, link may be inverted, but we delete the one - that bridges between [a] and [b] *) - on_backtrack self (fun () -> - match a.n_expl, b.n_expl with - | FL_some e, _ when E_node.equal e.next b -> a.n_expl <- FL_none - | _, FL_some e when E_node.equal e.next a -> b.n_expl <- FL_none - | _ -> assert false); - a.n_expl <- FL_some { next = b; expl = e_ab }; - (* call [on_post_merge] *) - Event.emit_iter self.on_post_merge (self, r_into, r_from) - ~f:(push_action_l self) - ) - - (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] - in the equiv class of [r1] that is a known literal back to the SAT solver - and which is not the one initially merged. - We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) - and propagate_bools self r1 t1 r2 t2 (e_12 : explanation) sign : unit = - (* explanation for [t1 =e= t2 = r2] *) - let half_expl_and_pr = - lazy - (let st = Expl_state.create () in - explain_decompose_expl self st e_12; - explain_equal_rec_ self st r2 t2; - st) - in - (* TODO: flag per class, `or`-ed on merge, to indicate if the class - contains at least one lit *) - E_node.iter_class r1 (fun u1 -> - (* propagate if: - - [u1] is a proper literal - - [t2 != r2], because that can only happen - after an explicit merge (no way to obtain that by propagation) - *) - match E_node.as_lit u1 with - | Some lit when not (E_node.equal r2 t2) -> - let lit = - if sign then - lit - else - Lit.neg lit - in - (* apply sign *) - Log.debugf 5 (fun k -> k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); - (* complete explanation with the [u1=t1] chunk *) - let (lazy st) = half_expl_and_pr in - let st = Expl_state.copy st in - (* do not modify shared st *) - explain_equal_rec_ self st u1 t1; - - (* propagate only if this doesn't depend on some semantic values *) - let reason () = - (* true literals explaining why t1=t2 *) - let guard = st.lits in - (* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *) - Expl_state.add_lit st (Lit.neg lit); - let _, pr = lits_and_proof_of_expl self st in - guard, pr - in - Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }); - Event.emit_iter self.on_propagate (self, lit, reason) - ~f:(push_action_l self); - Stat.incr self.count_props - | _ -> ()) - - (* raise a conflict from an explanation, typically from an event handler. - Raises E_confl with a result conflict. *) - and raise_conflict_from_expl self (expl : Expl.t) : 'a = - Log.debugf 5 (fun k -> - k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); - let st = Expl_state.create () in - explain_decompose_expl self st expl; - let lits, pr = lits_and_proof_of_expl self st in - let c = List.rev_map Lit.neg lits in - let th = st.th_lemmas <> [] in - raise_conflict_ self ~th c pr - - let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t) - - let push_level (self : t) : unit = - assert (not self.in_loop); - Backtrack_stack.push_level self.undo - - let pop_levels (self : t) n : unit = - assert (not self.in_loop); - Vec.clear self.pending; - Vec.clear self.combine; - Log.debugf 15 (fun k -> - k "(@[cc.pop-levels %d@ :n-lvls %d@])" n - (Backtrack_stack.n_levels self.undo)); - Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f ()); - () - - let assert_eq self t u expl : unit = - assert (not self.in_loop); - let t = add_term self t in - let u = add_term self u in - (* merge [a] and [b] *) - merge_classes self t u expl - - (* assert that this boolean literal holds. - if a lit is [= a b], merge [a] and [b]; - otherwise merge the atom with true/false *) - let assert_lit self lit : unit = - assert (not self.in_loop); - let t = Lit.term lit in - Log.debugf 15 (fun k -> k "(@[cc.assert-lit@ %a@])" Lit.pp lit); - let sign = Lit.sign lit in - match A.view_as_cc t with - | Eq (a, b) when sign -> assert_eq self a b (Expl.mk_lit lit) - | _ -> - (* equate t and true/false *) - let rhs = n_bool self sign in - let n = add_term self t in - (* TODO: ensure that this is O(1). - basically, just have [n] point to true/false and thus acquire - the corresponding value, so its superterms (like [ite]) can evaluate - properly *) - (* TODO: use oriented merge (force direction [n -> rhs]) *) - merge_classes self n rhs (Expl.mk_lit lit) - - let[@inline] assert_lits self lits : unit = - assert (not self.in_loop); - Iter.iter (assert_lit self) lits - - let merge self n1 n2 expl = - assert (not self.in_loop); - Log.debugf 5 (fun k -> - k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" E_node.pp n1 - E_node.pp n2 Expl.pp expl); - assert (T.Ty.equal (T.Term.ty n1.n_term) (T.Term.ty n2.n_term)); - merge_classes self n1 n2 expl - - let merge_t self t1 t2 expl = - merge self (add_term self t1) (add_term self t2) expl - - let explain_eq self n1 n2 : Resolved_expl.t = - let st = Expl_state.create () in - explain_equal_rec_ self st n1 n2; - (* FIXME: also need to return the proof? *) - Expl_state.to_resolved_expl st - - let explain_expl (self : t) expl : Resolved_expl.t = - let expl_st = Expl_state.create () in - explain_decompose_expl self expl_st expl; - Expl_state.to_resolved_expl expl_st - - let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge - let[@inline] on_pre_merge2 self = Event.of_emitter self.on_pre_merge2 - let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge - let[@inline] on_new_term self = Event.of_emitter self.on_new_term - let[@inline] on_conflict self = Event.of_emitter self.on_conflict - let[@inline] on_propagate self = Event.of_emitter self.on_propagate - let[@inline] on_is_subterm self = Event.of_emitter self.on_is_subterm - - let create ?(stat = Stat.global) ?(size = `Big) (tst : term_store) - (proof : proof_trace) : t = - let size = - match size with - | `Small -> 128 - | `Big -> 2048 - in - let bitgen = Bits.mk_gen () in - let field_marked_explain = Bits.mk_field bitgen in - let rec cc = - { - tst; - proof; - tbl = T_tbl.create size; - signatures_tbl = Sig_tbl.create size; - bitgen; - on_pre_merge = Event.Emitter.create (); - on_pre_merge2 = Event.Emitter.create (); - on_post_merge = Event.Emitter.create (); - on_new_term = Event.Emitter.create (); - on_conflict = Event.Emitter.create (); - on_propagate = Event.Emitter.create (); - on_is_subterm = Event.Emitter.create (); - pending = Vec.create (); - combine = Vec.create (); - undo = Backtrack_stack.create (); - true_; - false_; - in_loop = false; - res_acts = Vec.create (); - field_marked_explain; - count_conflict = Stat.mk_int stat "cc.conflicts"; - count_props = Stat.mk_int stat "cc.propagations"; - count_merge = Stat.mk_int stat "cc.merges"; - } - and true_ = lazy (add_term cc (Term.bool tst true)) - and false_ = lazy (add_term cc (Term.bool tst false)) in - ignore (Lazy.force true_ : e_node); - ignore (Lazy.force false_ : e_node); - cc - - let[@inline] find_t self t : repr = - let n = T_tbl.find self.tbl t in - find_ n - - let pop_acts_ self = - let rec loop acc = - match Vec.pop self.res_acts with - | None -> acc - | Some x -> loop (x :: acc) - in - loop [] - - let check self : Result_action.or_conflict = - Log.debug 5 "(cc.check)"; - self.in_loop <- true; - let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in - try - update_tasks self; - let l = pop_acts_ self in - Ok l - with E_confl c -> Error c - - let check_inv_enabled_ = true (* XXX NUDGE *) - - (* check some internal invariants *) - let check_inv_ (self : t) : unit = - if check_inv_enabled_ then ( - Log.debug 2 "(cc.check-invariants)"; - all_classes self - |> Iter.flat_map E_node.iter_class - |> Iter.iter (fun n -> - match n.n_sig0 with - | None -> () - | Some s -> - let s' = update_sig s in - let ok = - match find_signature self s' with - | None -> false - | Some r -> E_node.equal r n.n_root - in - if not ok then - Log.debugf 0 (fun k -> - k "(@[cc.check.fail@ :n %a@ :sig %a@ :actual-sig %a@])" - E_node.pp n Signature.pp s Signature.pp s')) - ) - - (* model: return all the classes *) - let get_model (self : t) : repr Iter.t Iter.t = - check_inv_ self; - all_classes self |> Iter.map E_node.iter_class -end +module Make (A : ARG) : S = Core_cc.Make (A) diff --git a/src/cc/Sidekick_cc.mli b/src/cc/Sidekick_cc.mli index 2ecc963d..0eb9def5 100644 --- a/src/cc/Sidekick_cc.mli +++ b/src/cc/Sidekick_cc.mli @@ -1,15 +1,15 @@ (** Congruence Closure Implementation *) -module View = Sidekick_sigs_cc.View -open Sidekick_sigs_cc +open Sidekick_core +module View = View -module type ARG = ARG +module type ARG = Sigs.ARG module type S = sig - include S + include Sigs.S val create : - ?stat:Stat.t -> ?size:[ `Small | `Big ] -> term_store -> proof_trace -> t + ?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t (** Create a new congruence closure. @param term_store used to be able to create new terms. All terms @@ -26,8 +26,4 @@ module type S = sig (**/**) end -module Make (A : ARG) : - S - with module T = A.T - and module Lit = A.Lit - and module Proof_trace = A.Proof_trace +module Make (_ : ARG) : S diff --git a/src/cc/bits.ml b/src/cc/bits.ml new file mode 100644 index 00000000..3e376b56 --- /dev/null +++ b/src/cc/bits.ml @@ -0,0 +1,26 @@ +type bitfield_gen = int ref + +let max_width = Sys.word_size - 2 +let mk_gen () = ref 0 + +type t = int +type field = int + +let empty : t = 0 + +let mk_field (gen : bitfield_gen) : field = + let n = !gen in + if n > max_width then Error.errorf "maximum number of CC bitfields reached"; + incr gen; + 1 lsl n + +let[@inline] get field x = x land field <> 0 + +let[@inline] set field b x = + if b then + x lor field + else + x land lnot field + +let merge = ( lor ) +let equal : t -> t -> bool = CCEqual.poly diff --git a/src/cc/bits.mli b/src/cc/bits.mli new file mode 100644 index 00000000..1460ed8f --- /dev/null +++ b/src/cc/bits.mli @@ -0,0 +1,13 @@ +(** Basic bitfield *) + +type t = private int +type field +type bitfield_gen + +val empty : t +val equal : t -> t -> bool +val mk_field : bitfield_gen -> field +val mk_gen : unit -> bitfield_gen +val get : field -> t -> bool +val set : field -> bool -> t -> t +val merge : t -> t -> t diff --git a/src/cc/core_cc.ml b/src/cc/core_cc.ml new file mode 100644 index 00000000..0df1b40a --- /dev/null +++ b/src/cc/core_cc.ml @@ -0,0 +1,1136 @@ +(* actual implementation *) + +open Sidekick_core +open View + +module type ARG = Sigs.ARG + +module Make (A : ARG) : Sigs.S = struct + open struct + (* proof rules *) + module Rules_ = Proof_core + module P = Proof_trace + end + + type e_node = { + n_term: Term.t; + mutable n_sig0: signature option; (* initial signature *) + mutable n_bits: Bits.t; (* bitfield for various properties *) + mutable n_parents: e_node Bag.t; (* parent terms of this node *) + mutable n_root: e_node; + (* representative of congruence class (itself if a representative) *) + mutable n_next: e_node; (* pointer to next element of congruence class *) + mutable n_size: int; (* size of the class *) + mutable n_as_lit: Lit.t option; + (* TODO: put into payload? and only in root? *) + mutable n_expl: explanation_forest_link; + (* the rooted forest for explanations *) + } + (** A node of the congruence closure. + An equivalence class is represented by its "root" element, + the representative. *) + + and signature = (Const.t, e_node, e_node list) View.t + + and explanation_forest_link = + | FL_none + | FL_some of { next: e_node; expl: explanation } + + (* atomic explanation in the congruence closure *) + and explanation = + | E_trivial (* by pure reduction, tautologically equal *) + | E_lit of Lit.t (* because of this literal *) + | E_merge of e_node * e_node + | E_merge_t of Term.t * Term.t + | E_congruence of e_node * e_node (* caused by normal congruence *) + | E_and of explanation * explanation + | E_theory of + Term.t + * Term.t + * (Term.t * Term.t * explanation list) list + * Proof_term.step_id + + type repr = e_node + + module E_node = struct + type t = e_node + + let[@inline] equal (n1 : t) n2 = n1 == n2 + let[@inline] hash n = Term.hash n.n_term + let[@inline] term n = n.n_term + let[@inline] pp out n = Term.pp_debug out n.n_term + let[@inline] as_lit n = n.n_as_lit + + let make (t : Term.t) : t = + let rec n = + { + n_term = t; + n_sig0 = None; + n_bits = Bits.empty; + n_parents = Bag.empty; + n_as_lit = None; + (* TODO: provide a method to do it *) + n_root = n; + n_expl = FL_none; + n_next = n; + n_size = 1; + } + in + n + + let[@inline] is_root (n : e_node) : bool = n.n_root == n + + (* traverse the equivalence class of [n] *) + let iter_class_ (n : e_node) : e_node Iter.t = + fun yield -> + let rec aux u = + yield u; + if u.n_next != n then aux u.n_next + in + aux n + + let[@inline] iter_class n = + assert (is_root n); + iter_class_ n + + let[@inline] iter_parents (n : e_node) : e_node Iter.t = + assert (is_root n); + Bag.to_iter n.n_parents + + type bitfield = Bits.field + + let[@inline] get_field f t = Bits.get f t.n_bits + let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits + end + + (* non-recursive, inlinable function for [find] *) + let[@inline] find_ (n : e_node) : repr = + let n2 = n.n_root in + assert (E_node.is_root n2); + n2 + + let[@inline] same_class (n1 : e_node) (n2 : e_node) : bool = + E_node.equal (find_ n1) (find_ n2) + + let[@inline] find _ n = find_ n + + module Expl = struct + type t = explanation + + let rec pp out (e : explanation) = + match e with + | E_trivial -> Fmt.string out "reduction" + | E_lit lit -> Lit.pp out lit + | E_congruence (n1, n2) -> + Fmt.fprintf out "(@[congruence@ %a@ %a@])" E_node.pp n1 E_node.pp n2 + | E_merge (a, b) -> + Fmt.fprintf out "(@[merge@ %a@ %a@])" E_node.pp a E_node.pp b + | E_merge_t (a, b) -> + Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp_debug + a Term.pp_debug b + | E_theory (t, u, es, _) -> + Fmt.fprintf out "(@[th@ :t `%a`@ :u `%a`@ :expl_sets %a@])" + Term.pp_debug t Term.pp_debug u + (Util.pp_list + @@ Fmt.Dump.triple Term.pp_debug Term.pp_debug (Fmt.Dump.list pp)) + es + | E_and (a, b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b + + let mk_trivial : t = E_trivial + let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2) + + let[@inline] mk_merge a b : t = + if E_node.equal a b then + mk_trivial + else + E_merge (a, b) + + let[@inline] mk_merge_t a b : t = + if Term.equal a b then + mk_trivial + else + E_merge_t (a, b) + + let[@inline] mk_lit l : t = E_lit l + let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr) + + let rec mk_list l = + match l with + | [] -> mk_trivial + | [ x ] -> x + | E_trivial :: tl -> mk_list tl + | x :: y -> + (match mk_list y with + | E_trivial -> x + | y' -> E_and (x, y')) + end + + module Resolved_expl = struct + type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id } + + let pp out (self : t) = + Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits + end + + (** A signature is a shallow term shape where immediate subterms + are representative *) + module Signature = struct + type t = signature + + let equal (s1 : t) s2 : bool = + match s1, s2 with + | Bool b1, Bool b2 -> b1 = b2 + | App_fun (f1, []), App_fun (f2, []) -> Const.equal f1 f2 + | App_fun (f1, l1), App_fun (f2, l2) -> + Const.equal f1 f2 && CCList.equal E_node.equal l1 l2 + | App_ho (f1, a1), App_ho (f2, a2) -> + E_node.equal f1 f2 && E_node.equal a1 a2 + | Not a, Not b -> E_node.equal a b + | If (a1, b1, c1), If (a2, b2, c2) -> + E_node.equal a1 a2 && E_node.equal b1 b2 && E_node.equal c1 c2 + | Eq (a1, b1), Eq (a2, b2) -> E_node.equal a1 a2 && E_node.equal b1 b2 + | Opaque u1, Opaque u2 -> E_node.equal u1 u2 + | Bool _, _ + | App_fun _, _ + | App_ho _, _ + | If _, _ + | Eq _, _ + | Opaque _, _ + | Not _, _ -> + false + + let hash (s : t) : int = + let module H = CCHash in + match s with + | Bool b -> H.combine2 10 (H.bool b) + | App_fun (f, l) -> H.combine3 20 (Const.hash f) (H.list E_node.hash l) + | App_ho (f, a) -> H.combine3 30 (E_node.hash f) (E_node.hash a) + | Eq (a, b) -> H.combine3 40 (E_node.hash a) (E_node.hash b) + | Opaque u -> H.combine2 50 (E_node.hash u) + | If (a, b, c) -> + H.combine4 60 (E_node.hash a) (E_node.hash b) (E_node.hash c) + | Not u -> H.combine2 70 (E_node.hash u) + + let pp out = function + | Bool b -> Fmt.bool out b + | App_fun (f, []) -> Const.pp out f + | App_fun (f, l) -> + Fmt.fprintf out "(@[%a@ %a@])" Const.pp f (Util.pp_list E_node.pp) l + | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" E_node.pp f E_node.pp a + | Opaque t -> E_node.pp out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" E_node.pp u + | Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" E_node.pp a E_node.pp b + | If (a, b, c) -> + Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" E_node.pp a E_node.pp b + E_node.pp c + end + + module Sig_tbl = CCHashtbl.Make (Signature) + module T_tbl = CCHashtbl.Make (Term) + + type propagation_reason = unit -> Lit.t list * Proof_term.step_id + + module Handler_action = struct + type t = + | Act_merge of E_node.t * E_node.t * Expl.t + | Act_propagate of Lit.t * propagation_reason + + type conflict = Conflict of Expl.t [@@unboxed] + type or_conflict = (t list, conflict) result + end + + module Result_action = struct + type t = Act_propagate of { lit: Lit.t; reason: propagation_reason } + type conflict = Conflict of Lit.t list * Proof_term.step_id + type or_conflict = (t list, conflict) result + end + + type combine_task = + | CT_merge of e_node * e_node * explanation + | CT_act of Handler_action.t + + type t = { + tst: Term.store; + proof: Proof_trace.t; + tbl: e_node T_tbl.t; (* internalization [term -> e_node] *) + signatures_tbl: e_node Sig_tbl.t; + (* map a signature to the corresponding e_node in some equivalence class. + A signature is a [term_cell] in which every immediate subterm + that participates in the congruence/evaluation relation + is normalized (i.e. is its own representative). + The critical property is that all members of an equivalence class + that have the same "shape" (including head symbol) + have the same signature *) + pending: e_node Vec.t; + combine: combine_task Vec.t; + undo: (unit -> unit) Backtrack_stack.t; + bitgen: Bits.bitfield_gen; + field_marked_explain: Bits.field; + (* used to mark traversed nodes when looking for a common ancestor *) + true_: e_node lazy_t; + false_: e_node lazy_t; + mutable in_loop: bool; (* currently being modified? *) + res_acts: Result_action.t Vec.t; (* to return *) + on_pre_merge: + ( t * E_node.t * E_node.t * Expl.t, + Handler_action.or_conflict ) + Event.Emitter.t; + on_pre_merge2: + ( t * E_node.t * E_node.t * Expl.t, + Handler_action.or_conflict ) + Event.Emitter.t; + on_post_merge: + (t * E_node.t * E_node.t, Handler_action.t list) Event.Emitter.t; + on_new_term: (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t; + on_conflict: (ev_on_conflict, unit) Event.Emitter.t; + on_propagate: + (t * Lit.t * propagation_reason, Handler_action.t list) Event.Emitter.t; + on_is_subterm: + (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t; + count_conflict: int Stat.counter; + count_props: int Stat.counter; + count_merge: int Stat.counter; + } + (* TODO: an additional union-find to keep track, for each term, + of the terms they are known to be equal to, according + to the current explanation. That allows not to prove some equality + several times. + See "fast congruence closure and extensions", Nieuwenhuis&al, page 14 *) + + and ev_on_conflict = { cc: t; th: bool; c: Lit.t list } + + let[@inline] size_ (r : repr) = r.n_size + let[@inline] n_true self = Lazy.force self.true_ + let[@inline] n_false self = Lazy.force self.false_ + + let n_bool self b = + if b then + n_true self + else + n_false self + + let[@inline] term_store self = self.tst + let[@inline] proof self = self.proof + + let allocate_bitfield self ~descr = + Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr); + Bits.mk_field self.bitgen + + let[@inline] on_backtrack self f : unit = + Backtrack_stack.push_if_nonzero_level self.undo f + + let[@inline] get_bitfield _cc field n = E_node.get_field field n + + let set_bitfield self field b n = + let old = E_node.get_field field n in + if old <> b then ( + on_backtrack self (fun () -> E_node.set_field field old n); + E_node.set_field field b n + ) + + (* check if [t] is in the congruence closure. + Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) + let[@inline] mem (self : t) (t : Term.t) : bool = T_tbl.mem self.tbl t + + module Debug_ = struct + (* print full state *) + let pp out (self : t) : unit = + let pp_next out n = Fmt.fprintf out "@ :next %a" E_node.pp n.n_next in + let pp_root out n = + if E_node.is_root n then + Fmt.string out " :is-root" + else + Fmt.fprintf out "@ :root %a" E_node.pp n.n_root + in + let pp_expl out n = + match n.n_expl with + | FL_none -> () + | FL_some e -> + Fmt.fprintf out " (@[:forest %a :expl %a@])" E_node.pp e.next Expl.pp + e.expl + in + let pp_n out n = + Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp_debug n.n_term pp_root n + pp_next n pp_expl n + and pp_sig_e out (s, n) = + Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s E_node.pp n + pp_root n + in + Fmt.fprintf out + "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ \ + %a@])@])" + (Util.pp_iter ~sep:" " pp_n) + (T_tbl.values self.tbl) + (Util.pp_iter ~sep:" " pp_sig_e) + (Sig_tbl.to_iter self.signatures_tbl) + end + + (* compute up-to-date signature *) + let update_sig (s : signature) : Signature.t = + View.map_view s ~f_f:(fun x -> x) ~f_t:find_ ~f_ts:(List.map find_) + + (* find whether the given (parent) term corresponds to some signature + in [signatures_] *) + let[@inline] find_signature cc (s : signature) : repr option = + Sig_tbl.get cc.signatures_tbl s + + (* add to signature table. Assume it's not present already *) + let add_signature self (s : signature) (n : e_node) : unit = + assert (not @@ Sig_tbl.mem self.signatures_tbl s); + Log.debugf 50 (fun k -> + k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s E_node.pp n); + on_backtrack self (fun () -> Sig_tbl.remove self.signatures_tbl s); + Sig_tbl.add self.signatures_tbl s n + + let push_pending self t : unit = + Log.debugf 50 (fun k -> k "(@[cc.push-pending@ %a@])" E_node.pp t); + Vec.push self.pending t + + let push_action self (a : Handler_action.t) : unit = + Vec.push self.combine (CT_act a) + + let push_action_l self (l : _ list) : unit = List.iter (push_action self) l + + let merge_classes self t u e : unit = + if t != u && not (same_class t u) then ( + Log.debugf 50 (fun k -> + k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" E_node.pp t + E_node.pp u Expl.pp e); + Vec.push self.combine @@ CT_merge (t, u, e) + ) + + (* re-root the explanation tree of the equivalence class of [n] + so that it points to [n]. + postcondition: [n.n_expl = None] *) + let[@unroll 2] rec reroot_expl (self : t) (n : e_node) : unit = + match n.n_expl with + | FL_none -> () (* already root *) + | FL_some { next = u; expl = e_n_u } -> + (* reroot to [u], then invert link between [u] and [n] *) + reroot_expl self u; + u.n_expl <- FL_some { next = n; expl = e_n_u }; + n.n_expl <- FL_none + + exception E_confl of Result_action.conflict + + let raise_conflict_ (cc : t) ~th (e : Lit.t list) (p : Proof_term.step_id) : _ + = + Profile.instant "cc.conflict"; + (* clear tasks queue *) + Vec.clear cc.pending; + Vec.clear cc.combine; + Event.emit cc.on_conflict { cc; th; c = e }; + Stat.incr cc.count_conflict; + raise (E_confl (Conflict (e, p))) + + let[@inline] all_classes self : repr Iter.t = + T_tbl.values self.tbl |> Iter.filter E_node.is_root + + (* find the closest common ancestor of [a] and [b] in the proof forest. + + Precond: + - [a] and [b] are in the same class + - no e_node has the flag [field_marked_explain] on + Invariants: + - if [n] is marked, then all the predecessors of [n] + from [a] or [b] are marked too. + *) + let find_common_ancestor self (a : e_node) (b : e_node) : e_node = + (* catch up to the other e_node *) + let rec find1 a = + if E_node.get_field self.field_marked_explain a then + a + else ( + match a.n_expl with + | FL_none -> assert false + | FL_some r -> find1 r.next + ) + in + let rec find2 a b = + if E_node.equal a b then + a + else if E_node.get_field self.field_marked_explain a then + a + else if E_node.get_field self.field_marked_explain b then + b + else ( + E_node.set_field self.field_marked_explain true a; + E_node.set_field self.field_marked_explain true b; + match a.n_expl, b.n_expl with + | FL_some r1, FL_some r2 -> find2 r1.next r2.next + | FL_some r, FL_none -> find1 r.next + | FL_none, FL_some r -> find1 r.next + | FL_none, FL_none -> assert false + (* no common ancestor *) + ) + in + + (* cleanup tags on nodes traversed in [find2] *) + let rec cleanup_ n = + if E_node.get_field self.field_marked_explain n then ( + E_node.set_field self.field_marked_explain false n; + match n.n_expl with + | FL_none -> () + | FL_some { next; _ } -> cleanup_ next + ) + in + let n = find2 a b in + cleanup_ a; + cleanup_ b; + n + + module Expl_state = struct + type t = { + mutable lits: Lit.t list; + mutable th_lemmas: + (Lit.t * (Lit.t * Lit.t list) list * Proof_term.step_id) list; + } + + let create () : t = { lits = []; th_lemmas = [] } + let[@inline] copy self : t = { self with lits = self.lits } + let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits + + let[@inline] add_th (self : t) lit hyps pr : unit = + self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas + + let merge self other = + let { lits = o_lits; th_lemmas = o_lemmas } = other in + self.lits <- List.rev_append o_lits self.lits; + self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas; + () + + (* proof of [\/_i ¬lits[i]] *) + let proof_of_th_lemmas (self : t) (proof : Proof_trace.t) : + Proof_term.step_id = + let p_lits1 = Iter.of_list self.lits |> Iter.map Lit.neg in + let p_lits2 = + Iter.of_list self.th_lemmas + |> Iter.map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u) + in + let p_cc = + P.add_step proof @@ Rules_.lemma_cc (Iter.append p_lits1 p_lits2) + in + let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) = + (* pr_th: [sub_proofs |- t=u]. + now resolve away [sub_proofs] to get literals that were + asserted in the congruence closure *) + let pr_th = + List.fold_left + (fun pr_th (lit_i, hyps_i) -> + (* [hyps_i |- lit_i] *) + let lemma_i = + P.add_step proof + @@ Rules_.lemma_cc + Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) + in + (* resolve [lit_i] away. *) + P.add_step proof + @@ Rules_.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th) + pr_th sub_proofs + in + P.add_step proof @@ Rules_.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr + in + (* resolve with theory proofs responsible for some merges, if any. *) + List.fold_left resolve_with_th_proof p_cc self.th_lemmas + + let to_resolved_expl (self : t) : Resolved_expl.t = + (* FIXME: package the th lemmas too *) + let { lits; th_lemmas = _ } = self in + let s2 = copy self in + let pr proof = proof_of_th_lemmas s2 proof in + { Resolved_expl.lits; pr } + end + + (* decompose explanation [e] into a list of literals added to [acc] *) + let rec explain_decompose_expl self (st : Expl_state.t) (e : explanation) : + unit = + Log.debugf 5 (fun k -> k "(@[cc.decompose_expl@ %a@])" Expl.pp e); + match e with + | E_trivial -> () + | E_congruence (n1, n2) -> + (match n1.n_sig0, n2.n_sig0 with + | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + assert (Const.equal f1 f2); + assert (List.length a1 = List.length a2); + List.iter2 (explain_equal_rec_ self st) a1 a2 + | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + explain_equal_rec_ self st f1 f2; + explain_equal_rec_ self st a1 a2 + | Some (If (a1, b1, c1)), Some (If (a2, b2, c2)) -> + explain_equal_rec_ self st a1 a2; + explain_equal_rec_ self st b1 b2; + explain_equal_rec_ self st c1 c2 + | _ -> assert false) + | E_lit lit -> Expl_state.add_lit st lit + | E_theory (t, u, expl_sets, pr) -> + let sub_proofs = + List.map + (fun (t_i, u_i, expls_i) -> + let lit_i = Lit.make_eq self.tst t_i u_i in + (* use a separate call to [explain_expls] for each set *) + let sub = explain_expls self expls_i in + Expl_state.merge st sub; + lit_i, sub.lits) + expl_sets + in + let lit_t_u = Lit.make_eq self.tst t u in + Expl_state.add_th st lit_t_u sub_proofs pr + | E_merge (a, b) -> explain_equal_rec_ self st a b + | E_merge_t (a, b) -> + (* find nodes for [a] and [b] on the fly *) + (match T_tbl.find self.tbl a, T_tbl.find self.tbl b with + | a, b -> explain_equal_rec_ self st a b + | exception Not_found -> + Error.errorf "expl: cannot find e_node(s) for %a, %a" Term.pp_debug a + Term.pp_debug b) + | E_and (a, b) -> + explain_decompose_expl self st a; + explain_decompose_expl self st b + + and explain_expls self (es : explanation list) : Expl_state.t = + let st = Expl_state.create () in + List.iter (explain_decompose_expl self st) es; + st + + and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) + : unit = + Log.debugf 5 (fun k -> + k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b); + assert (E_node.equal (find_ a) (find_ b)); + let ancestor = find_common_ancestor cc a b in + explain_along_path cc st a ancestor; + explain_along_path cc st b ancestor + + (* explain why [a = parent_a], where [a -> ... -> target] in the + proof forest *) + and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) + : unit = + let rec aux n = + if n == target then + () + else ( + match n.n_expl with + | FL_none -> assert false + | FL_some { next = next_n; expl } -> + explain_decompose_expl self st expl; + (* now prove [next_n = target] *) + aux next_n + ) + in + aux a + + (* add a term *) + let[@inline] rec add_term_rec_ self t : e_node = + match T_tbl.find self.tbl t with + | n -> n + | exception Not_found -> add_new_term_ self t + + (* add [t] when not present already *) + and add_new_term_ self (t : Term.t) : e_node = + assert (not @@ mem self t); + Log.debugf 15 (fun k -> k "(@[cc.add-term@ %a@])" Term.pp_debug t); + let n = E_node.make t in + (* register sub-terms, add [t] to their parent list, and return the + corresponding initial signature *) + let sig0 = compute_sig0 self n in + n.n_sig0 <- sig0; + (* remove term when we backtrack *) + on_backtrack self (fun () -> + Log.debugf 30 (fun k -> k "(@[cc.remove-term@ %a@])" Term.pp_debug t); + T_tbl.remove self.tbl t); + (* add term to the table *) + T_tbl.add self.tbl t n; + if Option.is_some sig0 then + (* [n] might be merged with other equiv classes *) + push_pending self n; + Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self); + n + + (* compute the initial signature of the given e_node *) + and compute_sig0 (self : t) (n : e_node) : Signature.t option = + (* add sub-term to [cc], and register [n] to its parents. + Note that we return the exact sub-term, to get proper + explanations, but we add to the sub-term's root's parent list. *) + let deref_sub (u : Term.t) : e_node = + let sub = add_term_rec_ self u in + (* add [n] to [sub.root]'s parent list *) + (let sub_r = find_ sub in + let old_parents = sub_r.n_parents in + if Bag.is_empty old_parents then + (* first time it has parents: tell watchers that this is a subterm *) + Event.emit_iter self.on_is_subterm (self, sub, u) + ~f:(push_action_l self); + on_backtrack self (fun () -> sub_r.n_parents <- old_parents); + sub_r.n_parents <- Bag.cons n sub_r.n_parents); + sub + in + let[@inline] return x = Some x in + match A.view_as_cc n.n_term with + | Bool _ | Opaque _ -> None + | Eq (a, b) -> + let a = deref_sub a in + let b = deref_sub b in + return @@ Eq (a, b) + | Not u -> return @@ Not (deref_sub u) + | App_fun (f, args) -> + let args = args |> Iter.map deref_sub |> Iter.to_list in + if args <> [] then + return @@ App_fun (f, args) + else + None + | App_ho (f, a) -> + let f = deref_sub f in + let a = deref_sub a in + return @@ App_ho (f, a) + | If (a, b, c) -> return @@ If (deref_sub a, deref_sub b, deref_sub c) + + let[@inline] add_term self t : e_node = add_term_rec_ self t + let mem_term = mem + + let set_as_lit self (n : e_node) (lit : Lit.t) : unit = + match n.n_as_lit with + | Some _ -> () + | None -> + Log.debugf 15 (fun k -> + k "(@[cc.set-as-lit@ %a@ %a@])" E_node.pp n Lit.pp lit); + on_backtrack self (fun () -> n.n_as_lit <- None); + n.n_as_lit <- Some lit + + (* is [n] true or false? *) + let n_is_bool_value (self : t) n : bool = + E_node.equal n (n_true self) || E_node.equal n (n_false self) + + (* gather a pair [lits, pr], where [lits] is the set of + asserted literals needed in the explanation (which is useful for + the SAT solver), and [pr] is a proof, including sub-proofs for theory + merges. *) + let lits_and_proof_of_expl (self : t) (st : Expl_state.t) : + Lit.t list * Proof_term.step_id = + let { Expl_state.lits; th_lemmas = _ } = st in + let pr = Expl_state.proof_of_th_lemmas st self.proof in + lits, pr + + (* main CC algo: add terms from [pending] to the signature table, + check for collisions *) + let rec update_tasks (self : t) : unit = + while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do + while not @@ Vec.is_empty self.pending do + task_pending_ self (Vec.pop_exn self.pending) + done; + while not @@ Vec.is_empty self.combine do + task_combine_ self (Vec.pop_exn self.combine) + done + done + + and task_pending_ self (n : e_node) : unit = + (* check if some parent collided *) + match n.n_sig0 with + | None -> () (* no-op *) + | Some (Eq (a, b)) -> + (* if [a=b] is now true, merge [(a=b)] and [true] *) + if same_class a b then ( + let expl = Expl.mk_merge a b in + Log.debugf 5 (fun k -> + k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" E_node.pp n E_node.pp a + E_node.pp b); + merge_classes self n (n_true self) expl + ) + | Some (Not u) -> + (* [u = bool ==> not u = not bool] *) + let r_u = find_ u in + if E_node.equal r_u (n_true self) then ( + let expl = Expl.mk_merge u (n_true self) in + merge_classes self n (n_false self) expl + ) else if E_node.equal r_u (n_false self) then ( + let expl = Expl.mk_merge u (n_false self) in + merge_classes self n (n_true self) expl + ) + | Some s0 -> + (* update the signature by using [find] on each sub-e_node *) + let s = update_sig s0 in + (match find_signature self s with + | None -> + (* add to the signature table [sig(n) --> n] *) + add_signature self s n + | Some u when E_node.equal n u -> () + | Some u -> + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) + assert (n != u); + let expl = Expl.mk_congruence n u in + merge_classes self n u expl) + + and task_combine_ self = function + | CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab + | CT_act (Handler_action.Act_merge (t, u, e)) -> task_merge_ self t u e + | CT_act (Handler_action.Act_propagate (lit, reason)) -> + (* will return this propagation to the caller *) + Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }) + + (* main CC algo: merge equivalence classes in [st.combine]. + @raise Exn_unsat if merge fails *) + and task_merge_ self a b e_ab : unit = + let ra = find_ a in + let rb = find_ b in + if not @@ E_node.equal ra rb then ( + assert (E_node.is_root ra); + assert (E_node.is_root rb); + Stat.incr self.count_merge; + (* check we're not merging [true] and [false] *) + if + (E_node.equal ra (n_true self) && E_node.equal rb (n_false self)) + || (E_node.equal rb (n_true self) && E_node.equal ra (n_false self)) + then ( + Log.debugf 5 (fun k -> + k + "(@[cc.merge.true_false_conflict@ @[:r1 %a@ :t1 %a@]@ @[:r2 \ + %a@ :t2 %a@]@ :e_ab %a@])" + E_node.pp ra E_node.pp a E_node.pp rb E_node.pp b Expl.pp e_ab); + let th = ref false in + (* TODO: + C1: P.true_neq_false + C2: lemma [lits |- true=false] (and resolve on theory proofs) + C3: r1 C1 C2 + *) + let expl_st = Expl_state.create () in + explain_decompose_expl self expl_st e_ab; + explain_equal_rec_ self expl_st a ra; + explain_equal_rec_ self expl_st b rb; + + (* regular conflict *) + let lits, pr = lits_and_proof_of_expl self expl_st in + raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr + ); + (* We will merge [r_from] into [r_into]. + we try to ensure that [size ra <= size rb] in general, but always + keep values as representative *) + let r_from, r_into = + if n_is_bool_value self ra then + rb, ra + else if n_is_bool_value self rb then + ra, rb + else if size_ ra > size_ rb then + rb, ra + else + ra, rb + in + (* when merging terms with [true] or [false], possibly propagate them to SAT *) + let merge_bool r1 t1 r2 t2 = + if E_node.equal r1 (n_true self) then + propagate_bools self r2 t2 r1 t1 e_ab true + else if E_node.equal r1 (n_false self) then + propagate_bools self r2 t2 r1 t1 e_ab false + in + + merge_bool ra a rb b; + merge_bool rb b ra a; + + (* perform [union r_from r_into] *) + Log.debugf 15 (fun k -> + k "(@[cc.merge@ :from %a@ :into %a@])" E_node.pp r_from E_node.pp + r_into); + + (* call [on_pre_merge] functions, and merge theory data items *) + (* explanation is [a=ra & e_ab & b=rb] *) + (let expl = + Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ] + in + + let handle_act = function + | Ok l -> push_action_l self l + | Error (Handler_action.Conflict expl) -> + raise_conflict_from_expl self expl + in + + Event.emit_iter self.on_pre_merge + (self, r_into, r_from, expl) + ~f:handle_act; + Event.emit_iter self.on_pre_merge2 + (self, r_into, r_from, expl) + ~f:handle_act); + + (* TODO: merge plugin data here, _after_ the pre-merge hooks are called, + so they have a chance of observing pre-merge plugin data *) + ((* parents might have a different signature, check for collisions *) + E_node.iter_parents r_from (fun parent -> push_pending self parent); + (* for each e_node in [r_from]'s class, make it point to [r_into] *) + E_node.iter_class r_from (fun u -> + assert (u.n_root == r_from); + u.n_root <- r_into); + (* capture current state *) + let r_into_old_next = r_into.n_next in + let r_from_old_next = r_from.n_next in + let r_into_old_parents = r_into.n_parents in + let r_into_old_bits = r_into.n_bits in + (* swap [into.next] and [from.next], merging the classes *) + r_into.n_next <- r_from_old_next; + r_from.n_next <- r_into_old_next; + r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; + r_into.n_size <- r_into.n_size + r_from.n_size; + r_into.n_bits <- Bits.merge r_into.n_bits r_from.n_bits; + (* on backtrack, unmerge classes and restore the pointers to [r_from] *) + on_backtrack self (fun () -> + Log.debugf 30 (fun k -> + k "(@[cc.undo_merge@ :from %a@ :into %a@])" E_node.pp r_from + E_node.pp r_into); + r_into.n_bits <- r_into_old_bits; + r_into.n_next <- r_into_old_next; + r_from.n_next <- r_from_old_next; + r_into.n_parents <- r_into_old_parents; + (* NOTE: this must come after the restoration of [next] pointers, + otherwise we'd iterate on too big a class *) + E_node.iter_class_ r_from (fun u -> u.n_root <- r_from); + r_into.n_size <- r_into.n_size - r_from.n_size)); + + (* update explanations (a -> b), arbitrarily. + Note that here we merge the classes by adding a bridge between [a] + and [b], not their roots. *) + reroot_expl self a; + assert (a.n_expl = FL_none); + (* on backtracking, link may be inverted, but we delete the one + that bridges between [a] and [b] *) + on_backtrack self (fun () -> + match a.n_expl, b.n_expl with + | FL_some e, _ when E_node.equal e.next b -> a.n_expl <- FL_none + | _, FL_some e when E_node.equal e.next a -> b.n_expl <- FL_none + | _ -> assert false); + a.n_expl <- FL_some { next = b; expl = e_ab }; + (* call [on_post_merge] *) + Event.emit_iter self.on_post_merge (self, r_into, r_from) + ~f:(push_action_l self) + ) + + (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] + in the equiv class of [r1] that is a known literal back to the SAT solver + and which is not the one initially merged. + We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) + and propagate_bools self r1 t1 r2 t2 (e_12 : explanation) sign : unit = + (* explanation for [t1 =e= t2 = r2] *) + let half_expl_and_pr = + lazy + (let st = Expl_state.create () in + explain_decompose_expl self st e_12; + explain_equal_rec_ self st r2 t2; + st) + in + (* TODO: flag per class, `or`-ed on merge, to indicate if the class + contains at least one lit *) + E_node.iter_class r1 (fun u1 -> + (* propagate if: + - [u1] is a proper literal + - [t2 != r2], because that can only happen + after an explicit merge (no way to obtain that by propagation) + *) + match E_node.as_lit u1 with + | Some lit when not (E_node.equal r2 t2) -> + let lit = + if sign then + lit + else + Lit.neg lit + in + (* apply sign *) + Log.debugf 5 (fun k -> k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); + (* complete explanation with the [u1=t1] chunk *) + let (lazy st) = half_expl_and_pr in + let st = Expl_state.copy st in + (* do not modify shared st *) + explain_equal_rec_ self st u1 t1; + + (* propagate only if this doesn't depend on some semantic values *) + let reason () = + (* true literals explaining why t1=t2 *) + let guard = st.lits in + (* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *) + Expl_state.add_lit st (Lit.neg lit); + let _, pr = lits_and_proof_of_expl self st in + guard, pr + in + Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }); + Event.emit_iter self.on_propagate (self, lit, reason) + ~f:(push_action_l self); + Stat.incr self.count_props + | _ -> ()) + + (* raise a conflict from an explanation, typically from an event handler. + Raises E_confl with a result conflict. *) + and raise_conflict_from_expl self (expl : Expl.t) : 'a = + Log.debugf 5 (fun k -> + k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); + let st = Expl_state.create () in + explain_decompose_expl self st expl; + let lits, pr = lits_and_proof_of_expl self st in + let c = List.rev_map Lit.neg lits in + let th = st.th_lemmas <> [] in + raise_conflict_ self ~th c pr + + let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t) + + let push_level (self : t) : unit = + assert (not self.in_loop); + Backtrack_stack.push_level self.undo + + let pop_levels (self : t) n : unit = + assert (not self.in_loop); + Vec.clear self.pending; + Vec.clear self.combine; + Log.debugf 15 (fun k -> + k "(@[cc.pop-levels %d@ :n-lvls %d@])" n + (Backtrack_stack.n_levels self.undo)); + Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f ()); + () + + let assert_eq self t u expl : unit = + assert (not self.in_loop); + let t = add_term self t in + let u = add_term self u in + (* merge [a] and [b] *) + merge_classes self t u expl + + (* assert that this boolean literal holds. + if a lit is [= a b], merge [a] and [b]; + otherwise merge the atom with true/false *) + let assert_lit self lit : unit = + assert (not self.in_loop); + let t = Lit.term lit in + Log.debugf 15 (fun k -> k "(@[cc.assert-lit@ %a@])" Lit.pp lit); + let sign = Lit.sign lit in + match A.view_as_cc t with + | Eq (a, b) when sign -> assert_eq self a b (Expl.mk_lit lit) + | _ -> + (* equate t and true/false *) + let rhs = n_bool self sign in + let n = add_term self t in + (* TODO: ensure that this is O(1). + basically, just have [n] point to true/false and thus acquire + the corresponding value, so its superterms (like [ite]) can evaluate + properly *) + (* TODO: use oriented merge (force direction [n -> rhs]) *) + merge_classes self n rhs (Expl.mk_lit lit) + + let[@inline] assert_lits self lits : unit = + assert (not self.in_loop); + Iter.iter (assert_lit self) lits + + let merge self n1 n2 expl = + assert (not self.in_loop); + Log.debugf 5 (fun k -> + k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" E_node.pp n1 + E_node.pp n2 Expl.pp expl); + assert (Term.equal (Term.ty n1.n_term) (Term.ty n2.n_term)); + merge_classes self n1 n2 expl + + let merge_t self t1 t2 expl = + merge self (add_term self t1) (add_term self t2) expl + + let explain_eq self n1 n2 : Resolved_expl.t = + let st = Expl_state.create () in + explain_equal_rec_ self st n1 n2; + (* FIXME: also need to return the proof? *) + Expl_state.to_resolved_expl st + + let explain_expl (self : t) expl : Resolved_expl.t = + let expl_st = Expl_state.create () in + explain_decompose_expl self expl_st expl; + Expl_state.to_resolved_expl expl_st + + let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge + let[@inline] on_pre_merge2 self = Event.of_emitter self.on_pre_merge2 + let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge + let[@inline] on_new_term self = Event.of_emitter self.on_new_term + let[@inline] on_conflict self = Event.of_emitter self.on_conflict + let[@inline] on_propagate self = Event.of_emitter self.on_propagate + let[@inline] on_is_subterm self = Event.of_emitter self.on_is_subterm + + let create ?(stat = Stat.global) ?(size = `Big) (tst : Term.store) + (proof : Proof_trace.t) : t = + let size = + match size with + | `Small -> 128 + | `Big -> 2048 + in + let bitgen = Bits.mk_gen () in + let field_marked_explain = Bits.mk_field bitgen in + let rec cc = + { + tst; + proof; + tbl = T_tbl.create size; + signatures_tbl = Sig_tbl.create size; + bitgen; + on_pre_merge = Event.Emitter.create (); + on_pre_merge2 = Event.Emitter.create (); + on_post_merge = Event.Emitter.create (); + on_new_term = Event.Emitter.create (); + on_conflict = Event.Emitter.create (); + on_propagate = Event.Emitter.create (); + on_is_subterm = Event.Emitter.create (); + pending = Vec.create (); + combine = Vec.create (); + undo = Backtrack_stack.create (); + true_; + false_; + in_loop = false; + res_acts = Vec.create (); + field_marked_explain; + count_conflict = Stat.mk_int stat "cc.conflicts"; + count_props = Stat.mk_int stat "cc.propagations"; + count_merge = Stat.mk_int stat "cc.merges"; + } + and true_ = lazy (add_term cc (Term.true_ tst)) + and false_ = lazy (add_term cc (Term.false_ tst)) in + ignore (Lazy.force true_ : e_node); + ignore (Lazy.force false_ : e_node); + cc + + let[@inline] find_t self t : repr = + let n = T_tbl.find self.tbl t in + find_ n + + let pop_acts_ self = + let rec loop acc = + match Vec.pop self.res_acts with + | None -> acc + | Some x -> loop (x :: acc) + in + loop [] + + let check self : Result_action.or_conflict = + Log.debug 5 "(cc.check)"; + self.in_loop <- true; + let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in + try + update_tasks self; + let l = pop_acts_ self in + Ok l + with E_confl c -> Error c + + let check_inv_enabled_ = true (* XXX NUDGE *) + + (* check some internal invariants *) + let check_inv_ (self : t) : unit = + if check_inv_enabled_ then ( + Log.debug 2 "(cc.check-invariants)"; + all_classes self + |> Iter.flat_map E_node.iter_class + |> Iter.iter (fun n -> + match n.n_sig0 with + | None -> () + | Some s -> + let s' = update_sig s in + let ok = + match find_signature self s' with + | None -> false + | Some r -> E_node.equal r n.n_root + in + if not ok then + Log.debugf 0 (fun k -> + k "(@[cc.check.fail@ :n %a@ :sig %a@ :actual-sig %a@])" + E_node.pp n Signature.pp s Signature.pp s')) + ) + + (* model: return all the classes *) + let get_model (self : t) : repr Iter.t Iter.t = + check_inv_ self; + all_classes self |> Iter.map E_node.iter_class +end diff --git a/src/cc/dune b/src/cc/dune index b33f850d..d249010d 100644 --- a/src/cc/dune +++ b/src/cc/dune @@ -1,5 +1,7 @@ (library (name Sidekick_cc) (public_name sidekick.cc) - (libraries containers iter sidekick.sigs sidekick.sigs.cc sidekick.util) - (flags :standard -warn-error -a+8 -w -32 -open Sidekick_util)) + (synopsis "main congruence closure implementation") + (private_modules core_cc) + (libraries containers iter sidekick.sigs sidekick.core sidekick.util) + (flags :standard -open Sidekick_util)) diff --git a/src/cc/mini/Sidekick_mini_cc.ml b/src/cc/mini/Sidekick_mini_cc.ml index 6decc650..059ad1b5 100644 --- a/src/cc/mini/Sidekick_mini_cc.ml +++ b/src/cc/mini/Sidekick_mini_cc.ml @@ -1,46 +1,34 @@ -module CC_view = Sidekick_sigs_cc.View - -module type TERM = Sidekick_sigs_term.S +open Sidekick_core +module CC_view = Sidekick_cc.View module type ARG = sig - module T : TERM - - val view_as_cc : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t + val view_as_cc : Term.t -> (Const.t, Term.t, Term.t Iter.t) CC_view.t end module type S = sig - type term - type fun_ - type term_store type t - val create : term_store -> t + val create : Term.store -> t val clear : t -> unit - val add_lit : t -> term -> bool -> unit + val add_lit : t -> Term.t -> bool -> unit val check_sat : t -> bool - val classes : t -> term Iter.t Iter.t + val classes : t -> Term.t Iter.t Iter.t end module Make (A : ARG) = struct open CC_view - module Fun = A.T.Fun - module T = A.T.Term - - type fun_ = A.T.Fun.t - type term = T.t - type term_store = T.store - - module T_tbl = CCHashtbl.Make (T) + module T = Term + module T_tbl = Term.Tbl type node = { - n_t: term; + n_t: Term.t; mutable n_next: node; (* next in class *) mutable n_size: int; (* size of class *) mutable n_parents: node list; mutable n_root: node; (* root of the class *) } - type signature = (fun_, node, node list) CC_view.t + type signature = (Const.t, node, node list) CC_view.t module Node = struct type t = node @@ -51,7 +39,7 @@ module Make (A : ARG) = struct let[@inline] is_root n = n == n.n_root let[@inline] root n = n.n_root let[@inline] term n = n.n_t - let pp out n = T.pp out n.n_t + let pp out n = T.pp_debug out n.n_t let add_parent (self : t) ~p : unit = self.n_parents <- p :: self.n_parents let make (t : T.t) : t = @@ -79,9 +67,9 @@ module Make (A : ARG) = struct let equal (s1 : t) s2 : bool = match s1, s2 with | Bool b1, Bool b2 -> b1 = b2 - | App_fun (f1, []), App_fun (f2, []) -> Fun.equal f1 f2 + | App_fun (f1, []), App_fun (f2, []) -> Const.equal f1 f2 | App_fun (f1, l1), App_fun (f2, l2) -> - Fun.equal f1 f2 && CCList.equal Node.equal l1 l2 + Const.equal f1 f2 && CCList.equal Node.equal l1 l2 | App_ho (f1, a1), App_ho (f2, a2) -> Node.equal f1 f2 && Node.equal a1 a2 | Not n1, Not n2 -> Node.equal n1 n2 | If (a1, b1, c1), If (a2, b2, c2) -> @@ -101,7 +89,7 @@ module Make (A : ARG) = struct let module H = CCHash in match s with | Bool b -> H.combine2 10 (H.bool b) - | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l) + | App_fun (f, l) -> H.combine3 20 (Const.hash f) (H.list Node.hash l) | App_ho (f, a) -> H.combine3 30 (Node.hash f) (Node.hash a) | Eq (a, b) -> H.combine3 40 (Node.hash a) (Node.hash b) | Opaque u -> H.combine2 50 (Node.hash u) @@ -110,9 +98,9 @@ module Make (A : ARG) = struct let pp out = function | Bool b -> Fmt.bool out b - | App_fun (f, []) -> Fun.pp out f + | App_fun (f, []) -> Const.pp out f | App_fun (f, l) -> - Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l + Fmt.fprintf out "(@[%a@ %a@])" Const.pp f (Util.pp_list Node.pp) l | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f Node.pp a | Opaque t -> Node.pp out t | Not u -> Fmt.fprintf out "(@[not@ %a@])" Node.pp u @@ -134,8 +122,8 @@ module Make (A : ARG) = struct } let create tst : t = - let true_ = T.bool tst true in - let false_ = T.bool tst false in + let true_ = Term.true_ tst in + let false_ = Term.false_ tst in let self = { ok = true; @@ -180,7 +168,7 @@ module Make (A : ARG) = struct k b; k c - let rec add_t (self : t) (t : term) : node = + let rec add_t (self : t) (t : Term.t) : node = match T_tbl.find self.tbl t with | n -> n | exception Not_found -> @@ -194,9 +182,10 @@ module Make (A : ARG) = struct self.pending <- node :: self.pending; node - let find_t_ (self : t) (t : term) : node = + let find_t_ (self : t) (t : Term.t) : node = try T_tbl.find self.tbl t |> Node.root - with Not_found -> Error.errorf "mini-cc.find_t: no node for %a" T.pp t + with Not_found -> + Error.errorf "mini-cc.find_t: no node for %a" T.pp_debug t exception E_unsat diff --git a/src/cc/mini/Sidekick_mini_cc.mli b/src/cc/mini/Sidekick_mini_cc.mli index 413d2518..fd4b4493 100644 --- a/src/cc/mini/Sidekick_mini_cc.mli +++ b/src/cc/mini/Sidekick_mini_cc.mli @@ -5,35 +5,28 @@ It just decides the satisfiability of a set of (dis)equations. *) -module CC_view = Sidekick_sigs_cc.View - -module type TERM = Sidekick_sigs_term.S +open Sidekick_core +module CC_view = Sidekick_cc.View (** Argument for the functor {!Make} - It only requires a term structure, and a congruence-oriented view. *) + It only requires a Term.t structure, and a congruence-oriented view. *) module type ARG = sig - module T : TERM - - val view_as_cc : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t + val view_as_cc : Term.t -> (Const.t, Term.t, Term.t Iter.t) CC_view.t end (** Main signature for an instance of the mini congruence closure *) module type S = sig - type term - type fun_ - type term_store - type t (** An instance of the congruence closure. Mutable *) - val create : term_store -> t + val create : Term.store -> t (** New instance *) val clear : t -> unit (** Fully reset the congruence closure's state *) - val add_lit : t -> term -> bool -> unit + val add_lit : t -> Term.t -> bool -> unit (** [add_lit cc p sign] asserts that [p] is true if [sign], or [p] is false if [not sign]. If [p] is an equation and [sign] is [true], this adds a new equation to the congruence relation. *) @@ -42,14 +35,10 @@ module type S = sig (** [check_sat cc] returns [true] if the current state is satisfiable, [false] if it's unsatisfiable. *) - val classes : t -> term Iter.t Iter.t + val classes : t -> Term.t Iter.t Iter.t (** Traverse the set of classes in the congruence closure. This should be called only if {!check} returned [Sat]. *) end -(** Instantiate the congruence closure for the given term structure. *) -module Make (A : ARG) : - S - with type term = A.T.Term.t - and type fun_ = A.T.Fun.t - and type term_store = A.T.Term.store +(** Instantiate the congruence closure for the given Term.t structure. *) +module Make (_ : ARG) : S diff --git a/src/cc/mini/dune b/src/cc/mini/dune index bbcbb9ad..23187086 100644 --- a/src/cc/mini/dune +++ b/src/cc/mini/dune @@ -1,5 +1,5 @@ (library (name Sidekick_mini_cc) (public_name sidekick.mini-cc) - (libraries containers iter sidekick.sigs.cc sidekick.sigs.term sidekick.util) + (libraries containers iter sidekick.cc sidekick.core sidekick.util) (flags :standard -warn-error -a+8 -w -32 -open Sidekick_util)) diff --git a/src/cc/plugin/dune b/src/cc/plugin/dune index 269abd1e..46f79cee 100644 --- a/src/cc/plugin/dune +++ b/src/cc/plugin/dune @@ -1,5 +1,5 @@ (library (name Sidekick_cc_plugin) (public_name sidekick.cc.plugin) - (libraries containers iter sidekick.sigs sidekick.sigs.cc sidekick.util) - (flags :standard -warn-error -a+8 -w -32 -open Sidekick_util)) + (libraries containers iter sidekick.sigs sidekick.cc sidekick.util) + (flags :standard -w +32 -open Sidekick_util)) diff --git a/src/cc/plugin/sidekick_cc_plugin.ml b/src/cc/plugin/sidekick_cc_plugin.ml index 6ee73414..0563977e 100644 --- a/src/cc/plugin/sidekick_cc_plugin.ml +++ b/src/cc/plugin/sidekick_cc_plugin.ml @@ -1,4 +1,4 @@ -open Sidekick_sigs_cc +open Sidekick_cc module type EXTENDED_PLUGIN_BUILDER = sig include MONOID_PLUGIN_BUILDER diff --git a/src/cc/plugin/sidekick_cc_plugin.mli b/src/cc/plugin/sidekick_cc_plugin.mli index 71ccdbc5..413d8408 100644 --- a/src/cc/plugin/sidekick_cc_plugin.mli +++ b/src/cc/plugin/sidekick_cc_plugin.mli @@ -1,6 +1,6 @@ -(** Congruence Closure Implementation *) +(** Congruence Closure Plugin *) -open Sidekick_sigs_cc +open Sidekick_cc module type EXTENDED_PLUGIN_BUILDER = sig include MONOID_PLUGIN_BUILDER diff --git a/src/cc/sigs.ml b/src/cc/sigs.ml new file mode 100644 index 00000000..20541c45 --- /dev/null +++ b/src/cc/sigs.ml @@ -0,0 +1,506 @@ +(** Main types for congruence closure *) + +open Sidekick_core +module View = View + +(** Arguments to a congruence closure's implementation *) +module type ARG = sig + val view_as_cc : Term.t -> (Const.t, Term.t, Term.t Iter.t) View.t + (** View the Term.t through the lens of the congruence closure *) +end + +(** Collection of input types, and types defined by the congruence closure *) +module type ARGS_CLASSES_EXPL_EVENT = sig + (** E-node. + + An e-node is a node in the congruence closure that is contained + in some equivalence classe). + An equivalence class is a set of terms that are currently equal + in the partial model built by the solver. + The class is represented by a collection of nodes, one of which is + distinguished and is called the "representative". + + All information pertaining to the whole equivalence class is stored + in its representative's {!E_node.t}. + + When two classes become equal (are "merged"), one of the two + representatives is picked as the representative of the new class. + The new class contains the union of the two old classes' nodes. + + We also allow theories to store additional information in the + representative. This information can be used when two classes are + merged, to detect conflicts and solve equations à la Shostak. + *) + module E_node : sig + type t + (** An E-node. + + A value of type [t] points to a particular Term.t, but see + {!find} to get the representative of the class. *) + + include Sidekick_sigs.PRINT with type t := t + + val term : t -> Term.t + (** Term contained in this equivalence class. + If [is_root n], then [Term.t n] is the class' representative Term.t. *) + + val equal : t -> t -> bool + (** Are two classes {b physically} equal? To check for + logical equality, use [CC.E_node.equal (CC.find cc n1) (CC.find cc n2)] + which checks for equality of representatives. *) + + val hash : t -> int + (** An opaque hash of this E_node.t. *) + + val is_root : t -> bool + (** Is the E_node.t a root (ie the representative of its class)? + See {!find} to get the root. *) + + val iter_class : t -> t Iter.t + (** Traverse the congruence class. + Precondition: [is_root n] (see {!find} below) *) + + val iter_parents : t -> t Iter.t + (** Traverse the parents of the class. + Precondition: [is_root n] (see {!find} below) *) + + (* FIXME: + [@@alert refactor "this should be replaced with a Per_class concept"] + *) + + type bitfield + (** A field in the bitfield of this node. This should only be + allocated when a theory is initialized. + + Bitfields are accessed using preallocated keys. + See {!CC_S.allocate_bitfield}. + + All fields are initially 0, are backtracked automatically, + and are merged automatically when classes are merged. *) + end + + (** Explanations + + Explanations are specialized proofs, created by the congruence closure + when asked to justify why two terms are equal. *) + module Expl : sig + type t + + include Sidekick_sigs.PRINT with type t := t + + val mk_merge : E_node.t -> E_node.t -> t + (** Explanation: the nodes were explicitly merged *) + + val mk_merge_t : Term.t -> Term.t -> t + (** Explanation: the terms were explicitly merged *) + + val mk_lit : Lit.t -> t + (** Explanation: we merged [t] and [u] because of literal [t=u], + or we merged [t] and [true] because of literal [t], + or [t] and [false] because of literal [¬t] *) + + val mk_list : t list -> t + (** Conjunction of explanations *) + + val mk_theory : + Term.t -> + Term.t -> + (Term.t * Term.t * t list) list -> + Proof_term.step_id -> + t + (** [mk_theory t u expl_sets pr] builds a theory explanation for + why [|- t=u]. It depends on sub-explanations [expl_sets] which + are tuples [ (t_i, u_i, expls_i) ] where [expls_i] are + explanations that justify [t_i = u_i] in the current congruence closure. + + The proof [pr] is the theory lemma, of the form + [ (t_i = u_i)_i |- t=u ]. + It is resolved against each [expls_i |- t_i=u_i] obtained from + [expl_sets], on pivot [t_i=u_i], to obtain a proof of [Gamma |- t=u] + where [Gamma] is a subset of the literals asserted into the congruence + closure. + + For example for the lemma [a=b] deduced by injectivity + from [Some a=Some b] in the theory of datatypes, + the arguments would be + [a, b, [Some a, Some b, mk_merge_t (Some a)(Some b)], pr] + where [pr] is the injectivity lemma [Some a=Some b |- a=b]. + *) + end + + (** Resolved explanations. + + The congruence closure keeps explanations for why terms are in the same + class. However these are represented in a compact, cheap form. + To use these explanations we need to {b resolve} them into a + resolved explanation, typically a list of + literals that are true in the current trail and are responsible for + merges. + + However, we can also have merged classes because they have the same value + in the current model. *) + module Resolved_expl : sig + type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id } + + include Sidekick_sigs.PRINT with type t := t + end + + (** Per-node data *) + + type e_node = E_node.t + (** A node of the congruence closure *) + + type repr = E_node.t + (** Node that is currently a representative. *) + + type explanation = Expl.t +end + +(** Main congruence closure signature. + + The congruence closure handles the theory QF_UF (uninterpreted + function symbols). + It is also responsible for {i theory combination}, and provides + a general framework for equality reasoning that other + theories piggyback on. + + For example, the theory of datatypes relies on the congruence closure + to do most of the work, and "only" adds injectivity/disjointness/acyclicity + lemmas when needed. + + Similarly, a theory of arrays would hook into the congruence closure and + assert (dis)equalities as needed. +*) +module type S = sig + include ARGS_CLASSES_EXPL_EVENT + + type t + (** The congruence closure object. + It contains a fair amount of state and is mutable + and backtrackable. *) + + (** {3 Accessors} *) + + val term_store : t -> Term.store + val proof : t -> Proof_trace.t + + val find : t -> e_node -> repr + (** Current representative *) + + val add_term : t -> Term.t -> e_node + (** Add the Term.t to the congruence closure, if not present already. + Will be backtracked. *) + + val mem_term : t -> Term.t -> bool + (** Returns [true] if the Term.t is explicitly present in the congruence closure *) + + val allocate_bitfield : t -> descr:string -> E_node.bitfield + (** Allocate a new e_node field (see {!E_node.bitfield}). + + This field descriptor is henceforth reserved for all nodes + in this congruence closure, and can be set using {!set_bitfield} + for each class_ individually. + This can be used to efficiently store some metadata on nodes + (e.g. "is there a numeric value in the class" + or "is there a constructor Term.t in the class"). + + There may be restrictions on how many distinct fields are allocated + for a given congruence closure (e.g. at most {!Sys.int_size} fields). + *) + + val get_bitfield : t -> E_node.bitfield -> E_node.t -> bool + (** Access the bit field of the given e_node *) + + val set_bitfield : t -> E_node.bitfield -> bool -> E_node.t -> unit + (** Set the bitfield for the e_node. This will be backtracked. + See {!E_node.bitfield}. *) + + type propagation_reason = unit -> Lit.t list * Proof_term.step_id + + (** Handler Actions + + Actions that can be scheduled by event handlers. *) + module Handler_action : sig + type t = + | Act_merge of E_node.t * E_node.t * Expl.t + | Act_propagate of Lit.t * propagation_reason + + (* TODO: + - an action to modify data associated with a class + *) + + type conflict = Conflict of Expl.t [@@unboxed] + + type or_conflict = (t list, conflict) result + (** Actions or conflict scheduled by an event handler. + + - [Ok acts] is a list of merges and propagations + - [Error confl] is a conflict to resolve. + *) + end + + (** Result Actions. + + + Actions returned by the congruence closure after calling {!check}. *) + module Result_action : sig + type t = + | Act_propagate of { lit: Lit.t; reason: propagation_reason } + (** [propagate (Lit.t, reason)] declares that [reason() => Lit.t] + is a tautology. + + - [reason()] should return a list of literals that are currently true, + as well as a proof. + - [Lit.t] should be a literal of interest (see {!S.set_as_lit}). + + This function might never be called, a congruence closure has the right + to not propagate and only trigger conflicts. *) + + type conflict = + | Conflict of Lit.t list * Proof_term.step_id + (** [raise_conflict (c,pr)] declares that [c] is a tautology of + the theory of congruence. + @param pr the proof of [c] being a tautology *) + + type or_conflict = (t list, conflict) result + end + + (** {3 Events} + + Events triggered by the congruence closure, to which + other plugins can subscribe. *) + + (** Events emitted by the congruence closure when something changes. *) + val on_pre_merge : + t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t + (** [Ev_on_pre_merge acts n1 n2 expl] is emitted right before [n1] + and [n2] are merged with explanation [expl]. *) + + val on_pre_merge2 : + t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t + (** Second phase of "on pre merge". This runs after {!on_pre_merge} + and is used by Plugins. {b NOTE}: Plugin state might be observed as already + changed in these handlers. *) + + val on_post_merge : + t -> (t * E_node.t * E_node.t, Handler_action.t list) Event.t + (** [ev_on_post_merge acts n1 n2] is emitted right after [n1] + and [n2] were merged. [find cc n1] and [find cc n2] will return + the same E_node.t. *) + + val on_new_term : t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t + (** [ev_on_new_term n t] is emitted whenever a new Term.t [t] + is added to the congruence closure. Its E_node.t is [n]. *) + + type ev_on_conflict = { cc: t; th: bool; c: Lit.t list } + (** Event emitted when a conflict occurs in the CC. + + [th] is true if the explanation for this conflict involves + at least one "theory" explanation; i.e. some of the equations + participating in the conflict are purely syntactic theories + like injectivity of constructors. *) + + val on_conflict : t -> (ev_on_conflict, unit) Event.t + (** [ev_on_conflict {th; c}] is emitted when the congruence + closure triggers a conflict by asserting the tautology [c]. *) + + val on_propagate : + t -> + ( t * Lit.t * (unit -> Lit.t list * Proof_term.step_id), + Handler_action.t list ) + Event.t + (** [ev_on_propagate Lit.t reason] is emitted whenever [reason() => Lit.t] + is a propagated lemma. See {!CC_ACTIONS.propagate}. *) + + val on_is_subterm : + t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t + (** [ev_on_is_subterm n t] is emitted when [n] is a subterm of + another E_node.t for the first time. [t] is the Term.t corresponding to + the E_node.t [n]. This can be useful for theory combination. *) + + (** {3 Misc} *) + + val n_true : t -> E_node.t + (** Node for [true] *) + + val n_false : t -> E_node.t + (** Node for [false] *) + + val n_bool : t -> bool -> E_node.t + (** Node for either true or false *) + + val set_as_lit : t -> E_node.t -> Lit.t -> unit + (** map the given e_node to a literal. *) + + val find_t : t -> Term.t -> repr + (** Current representative of the Term.t. + @raise E_node.t_found if the Term.t is not already {!add}-ed. *) + + val add_iter : t -> Term.t Iter.t -> unit + (** Add a sequence of terms to the congruence closure *) + + val all_classes : t -> repr Iter.t + (** All current classes. This is costly, only use if there is no other solution *) + + val explain_eq : t -> E_node.t -> E_node.t -> Resolved_expl.t + (** Explain why the two nodes are equal. + Fails if they are not, in an unspecified way. *) + + val explain_expl : t -> Expl.t -> Resolved_expl.t + (** Transform explanation into an actionable conflict clause *) + + (* FIXME: remove + val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a + (** Raise a conflict with the given explanation. + It must be a theory tautology that [expl ==> absurd]. + To be used in theories. + + This fails in an unspecified way if the explanation, once resolved, + satisfies {!Resolved_expl.is_semantic}. *) + *) + + val merge : t -> E_node.t -> E_node.t -> Expl.t -> unit + (** Merge these two nodes given this explanation. + It must be a theory tautology that [expl ==> n1 = n2]. + To be used in theories. *) + + val merge_t : t -> Term.t -> Term.t -> Expl.t -> unit + (** Shortcut for adding + merging *) + + (** {3 Main API *) + + val assert_eq : t -> Term.t -> Term.t -> Expl.t -> unit + (** Assert that two terms are equal, using the given explanation. *) + + val assert_lit : t -> Lit.t -> unit + (** Given a literal, assume it in the congruence closure and propagate + its consequences. Will be backtracked. + + Useful for the theory combination or the SAT solver's functor *) + + val assert_lits : t -> Lit.t Iter.t -> unit + (** Addition of many literals *) + + val check : t -> Result_action.or_conflict + (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. + Will use the {!actions} to propagate literals, declare conflicts, etc. *) + + val push_level : t -> unit + (** Push backtracking level *) + + val pop_levels : t -> int -> unit + (** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *) + + val get_model : t -> E_node.t Iter.t Iter.t + (** get all the equivalence classes so they can be merged in the model *) + + val create : + ?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t + (** Create a new congruence closure. + + @param term_store used to be able to create new terms. All terms + interacting with this congruence closure must belong in this term state + as well. + *) + + (**/**) + + module Debug_ : sig + val pp : t Fmt.printer + (** Print the whole CC *) + end + + (**/**) +end + +(* TODO: full EGG, also have a function to update the value when + the subterms (produced in [of_term]) are updated *) + +(** Data attached to the congruence closure classes. + + This helps theories keeping track of some state for each class. + The state of a class is the monoidal combination of the state for each + Term.t in the class (for example, the set of terms in the + class whose head symbol is a datatype constructor). *) +module type MONOID_PLUGIN_ARG = sig + module CC : S + + type t + (** Some type with a monoid structure *) + + include Sidekick_sigs.PRINT with type t := t + + val name : string + (** name of the monoid structure (short) *) + + (* FIXME: for subs, return list of e_nodes, and assume of_term already + returned data for them. *) + val of_term : + CC.t -> CC.E_node.t -> Term.t -> t option * (CC.E_node.t * t) list + (** [of_term n t], where [t] is the Term.t annotating node [n], + must return [maybe_m, l], where: + + - [maybe_m = Some m] if [t] has monoid value [m]; + otherwise [maybe_m=None] + - [l] is a list of [(u, m_u)] where each [u]'s Term.t + is a direct subterm of [t] + and [m_u] is the monoid value attached to [u]. + + *) + + val merge : + CC.t -> + CC.E_node.t -> + t -> + CC.E_node.t -> + t -> + CC.Expl.t -> + (t * CC.Handler_action.t list, CC.Handler_action.conflict) result + (** Monoidal combination of two values. + + [merge cc n1 mon1 n2 mon2 expl] returns the result of merging + monoid values [mon1] (for class [n1]) and [mon2] (for class [n2]) + when [n1] and [n2] are merged with explanation [expl]. + + @return [Ok mon] if the merge is acceptable, annotating the class of [n1 ∪ n2]; + or [Error expl'] if the merge is unsatisfiable. [expl'] can then be + used to trigger a conflict and undo the merge. + *) +end + +(** Stateful plugin holding a per-equivalence-class monoid. + + Helps keep track of monoid state per equivalence class. + A theory might use one or more instance(s) of this to + aggregate some theory-specific state over all terms, with + the information of what terms are already known to be equal + potentially saving work for the theory. *) +module type DYN_MONOID_PLUGIN = sig + module M : MONOID_PLUGIN_ARG + include Sidekick_sigs.DYN_BACKTRACKABLE + + val pp : unit Fmt.printer + + val mem : M.CC.E_node.t -> bool + (** Does the CC E_node.t have a monoid value? *) + + val get : M.CC.E_node.t -> M.t option + (** Get monoid value for this CC E_node.t, if any *) + + val iter_all : (M.CC.repr * M.t) Iter.t +end + +(** Builder for a plugin. + + The builder takes a congruence closure, and instantiate the + plugin on it. *) +module type MONOID_PLUGIN_BUILDER = sig + module M : MONOID_PLUGIN_ARG + + module type DYN_PL_FOR_M = DYN_MONOID_PLUGIN with module M = M + + type t = (module DYN_PL_FOR_M) + + val create_and_setup : ?size:int -> M.CC.t -> t + (** Create a new monoid state *) +end diff --git a/src/cc/view.ml b/src/cc/view.ml new file mode 100644 index 00000000..e319f5ef --- /dev/null +++ b/src/cc/view.ml @@ -0,0 +1,38 @@ +type ('f, 't, 'ts) t = + | Bool of bool + | App_fun of 'f * 'ts + | App_ho of 't * 't + | If of 't * 't * 't + | Eq of 't * 't + | Not of 't + | Opaque of 't +(* do not enter *) + +let map_view ~f_f ~f_t ~f_ts (v : _ t) : _ t = + match v with + | Bool b -> Bool b + | App_fun (f, args) -> App_fun (f_f f, f_ts args) + | App_ho (f, a) -> App_ho (f_t f, f_t a) + | Not t -> Not (f_t t) + | If (a, b, c) -> If (f_t a, f_t b, f_t c) + | Eq (a, b) -> Eq (f_t a, f_t b) + | Opaque t -> Opaque (f_t t) + +let iter_view ~f_f ~f_t ~f_ts (v : _ t) : unit = + match v with + | Bool _ -> () + | App_fun (f, args) -> + f_f f; + f_ts args + | App_ho (f, a) -> + f_t f; + f_t a + | Not t -> f_t t + | If (a, b, c) -> + f_t a; + f_t b; + f_t c + | Eq (a, b) -> + f_t a; + f_t b + | Opaque t -> f_t t diff --git a/src/cc/view.mli b/src/cc/view.mli new file mode 100644 index 00000000..038ea1a6 --- /dev/null +++ b/src/cc/view.mli @@ -0,0 +1,33 @@ +(** View terms through the lens of the Congruence Closure *) + +(** A view of a term fron the point of view of the congruence closure. + + - ['f] is the type of function symbols + - ['t] is the type of terms + - ['ts] is the type of sequences of terms (arguments of function application) + *) +type ('f, 't, 'ts) t = + | Bool of bool + | App_fun of 'f * 'ts + | App_ho of 't * 't + | If of 't * 't * 't + | Eq of 't * 't + | Not of 't + | Opaque of 't (** do not enter *) + +val map_view : + f_f:('a -> 'b) -> + f_t:('c -> 'd) -> + f_ts:('e -> 'f) -> + ('a, 'c, 'e) t -> + ('b, 'd, 'f) t +(** Map function over a view, one level deep. + Each function maps over a different type, e.g. [f_t] maps over terms *) + +val iter_view : + f_f:('a -> unit) -> + f_t:('b -> unit) -> + f_ts:('c -> unit) -> + ('a, 'b, 'c) t -> + unit +(** Iterate over a view, one level deep. *)