diff --git a/src/cc/CC.ml b/src/cc/CC.ml new file mode 100644 index 00000000..ba2214be --- /dev/null +++ b/src/cc/CC.ml @@ -0,0 +1,971 @@ +open Types_ + +type view_as_cc = Term.t -> (Const.t, Term.t, Term.t Iter.t) View.t + +open struct + (* proof rules *) + module Rules_ = Proof_core + module P = Proof_trace +end + +type e_node = E_node.t +(** A node of the congruence closure *) + +type repr = E_node.t +(** Node that is currently a representative. *) + +type explanation = Expl.t +type bitfield = Bits.field + +(* non-recursive, inlinable function for [find] *) +let[@inline] find_ (n : e_node) : repr = + let n2 = n.n_root in + assert (E_node.is_root n2); + n2 + +let[@inline] same_class (n1 : e_node) (n2 : e_node) : bool = + E_node.equal (find_ n1) (find_ n2) + +let[@inline] find _ n = find_ n + +module Sig_tbl = CCHashtbl.Make (Signature) +module T_tbl = Term.Tbl + +type propagation_reason = unit -> Lit.t list * Proof_term.step_id + +module Handler_action = struct + type t = + | Act_merge of E_node.t * E_node.t * Expl.t + | Act_propagate of Lit.t * propagation_reason + + type conflict = Conflict of Expl.t [@@unboxed] + type or_conflict = (t list, conflict) result +end + +module Result_action = struct + type t = Act_propagate of { lit: Lit.t; reason: propagation_reason } + type conflict = Conflict of Lit.t list * Proof_term.step_id + type or_conflict = (t list, conflict) result +end + +type combine_task = + | CT_merge of e_node * e_node * explanation + | CT_act of Handler_action.t + +type t = { + view_as_cc: view_as_cc; + tst: Term.store; + proof: Proof_trace.t; + tbl: e_node T_tbl.t; (* internalization [term -> e_node] *) + signatures_tbl: e_node Sig_tbl.t; + (* map a signature to the corresponding e_node in some equivalence class. + A signature is a [term_cell] in which every immediate subterm + that participates in the congruence/evaluation relation + is normalized (i.e. is its own representative). + The critical property is that all members of an equivalence class + that have the same "shape" (including head symbol) + have the same signature *) + pending: e_node Vec.t; + combine: combine_task Vec.t; + undo: (unit -> unit) Backtrack_stack.t; + bitgen: Bits.bitfield_gen; + field_marked_explain: Bits.field; + (* used to mark traversed nodes when looking for a common ancestor *) + true_: e_node lazy_t; + false_: e_node lazy_t; + mutable in_loop: bool; (* currently being modified? *) + res_acts: Result_action.t Vec.t; (* to return *) + on_pre_merge: + ( t * E_node.t * E_node.t * Expl.t, + Handler_action.or_conflict ) + Event.Emitter.t; + on_pre_merge2: + ( t * E_node.t * E_node.t * Expl.t, + Handler_action.or_conflict ) + Event.Emitter.t; + on_post_merge: + (t * E_node.t * E_node.t, Handler_action.t list) Event.Emitter.t; + on_new_term: (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t; + on_conflict: (ev_on_conflict, unit) Event.Emitter.t; + on_propagate: + (t * Lit.t * propagation_reason, Handler_action.t list) Event.Emitter.t; + on_is_subterm: (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t; + count_conflict: int Stat.counter; + count_props: int Stat.counter; + count_merge: int Stat.counter; +} +(* TODO: an additional union-find to keep track, for each term, + of the terms they are known to be equal to, according + to the current explanation. That allows not to prove some equality + several times. + See "fast congruence closure and extensions", Nieuwenhuis&al, page 14 *) + +and ev_on_conflict = { cc: t; th: bool; c: Lit.t list } + +let[@inline] size_ (r : repr) = r.n_size +let[@inline] n_true self = Lazy.force self.true_ +let[@inline] n_false self = Lazy.force self.false_ + +let n_bool self b = + if b then + n_true self + else + n_false self + +let[@inline] term_store self = self.tst +let[@inline] proof self = self.proof + +let allocate_bitfield self ~descr : bitfield = + Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr); + Bits.mk_field self.bitgen + +let[@inline] on_backtrack self f : unit = + Backtrack_stack.push_if_nonzero_level self.undo f + +let[@inline] set_bitfield_ f b t = t.n_bits <- Bits.set f b t.n_bits +let[@inline] get_bitfield_ field n = Bits.get field n.n_bits +let[@inline] get_bitfield _cc field n = get_bitfield_ field n + +let set_bitfield self field b n = + let old = get_bitfield self field n in + if old <> b then ( + on_backtrack self (fun () -> set_bitfield_ field old n); + set_bitfield_ field b n + ) + +(* check if [t] is in the congruence closure. + Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) +let[@inline] mem (self : t) (t : Term.t) : bool = T_tbl.mem self.tbl t + +module Debug_ = struct + (* print full state *) + let pp out (self : t) : unit = + let pp_next out n = Fmt.fprintf out "@ :next %a" E_node.pp n.n_next in + let pp_root out n = + if E_node.is_root n then + Fmt.string out " :is-root" + else + Fmt.fprintf out "@ :root %a" E_node.pp n.n_root + in + let pp_expl out n = + match n.n_expl with + | FL_none -> () + | FL_some e -> + Fmt.fprintf out " (@[:forest %a :expl %a@])" E_node.pp e.next Expl.pp + e.expl + in + let pp_n out n = + Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp_debug n.n_term pp_root n pp_next + n pp_expl n + and pp_sig_e out (s, n) = + Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s E_node.pp n pp_root + n + in + Fmt.fprintf out + "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ %a@])@])" + (Util.pp_iter ~sep:" " pp_n) + (T_tbl.values self.tbl) + (Util.pp_iter ~sep:" " pp_sig_e) + (Sig_tbl.to_iter self.signatures_tbl) +end + +(* compute up-to-date signature *) +let update_sig (s : signature) : Signature.t = + View.map_view s ~f_f:(fun x -> x) ~f_t:find_ ~f_ts:(List.map find_) + +(* find whether the given (parent) term corresponds to some signature + in [signatures_] *) +let[@inline] find_signature cc (s : signature) : repr option = + Sig_tbl.get cc.signatures_tbl s + +(* add to signature table. Assume it's not present already *) +let add_signature self (s : signature) (n : e_node) : unit = + assert (not @@ Sig_tbl.mem self.signatures_tbl s); + Log.debugf 50 (fun k -> + k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s E_node.pp n); + on_backtrack self (fun () -> Sig_tbl.remove self.signatures_tbl s); + Sig_tbl.add self.signatures_tbl s n + +let push_pending self t : unit = + Log.debugf 50 (fun k -> k "(@[cc.push-pending@ %a@])" E_node.pp t); + Vec.push self.pending t + +let push_action self (a : Handler_action.t) : unit = + Vec.push self.combine (CT_act a) + +let push_action_l self (l : _ list) : unit = List.iter (push_action self) l + +let merge_classes self t u e : unit = + if t != u && not (same_class t u) then ( + Log.debugf 50 (fun k -> + k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" E_node.pp t E_node.pp + u Expl.pp e); + Vec.push self.combine @@ CT_merge (t, u, e) + ) + +(* re-root the explanation tree of the equivalence class of [n] + so that it points to [n]. + postcondition: [n.n_expl = None] *) +let[@unroll 2] rec reroot_expl (self : t) (n : e_node) : unit = + match n.n_expl with + | FL_none -> () (* already root *) + | FL_some { next = u; expl = e_n_u } -> + (* reroot to [u], then invert link between [u] and [n] *) + reroot_expl self u; + u.n_expl <- FL_some { next = n; expl = e_n_u }; + n.n_expl <- FL_none + +exception E_confl of Result_action.conflict + +let raise_conflict_ (cc : t) ~th (e : Lit.t list) (p : Proof_term.step_id) : _ = + Profile.instant "cc.conflict"; + (* clear tasks queue *) + Vec.clear cc.pending; + Vec.clear cc.combine; + Event.emit cc.on_conflict { cc; th; c = e }; + Stat.incr cc.count_conflict; + raise (E_confl (Conflict (e, p))) + +let[@inline] all_classes self : repr Iter.t = + T_tbl.values self.tbl |> Iter.filter E_node.is_root + +(* find the closest common ancestor of [a] and [b] in the proof forest. + + Precond: + - [a] and [b] are in the same class + - no e_node has the flag [field_marked_explain] on + Invariants: + - if [n] is marked, then all the predecessors of [n] + from [a] or [b] are marked too. +*) +let find_common_ancestor self (a : e_node) (b : e_node) : e_node = + (* catch up to the other e_node *) + let rec find1 a = + if get_bitfield_ self.field_marked_explain a then + a + else ( + match a.n_expl with + | FL_none -> assert false + | FL_some r -> find1 r.next + ) + in + let rec find2 a b = + if E_node.equal a b then + a + else if get_bitfield_ self.field_marked_explain a then + a + else if get_bitfield_ self.field_marked_explain b then + b + else ( + set_bitfield_ self.field_marked_explain true a; + set_bitfield_ self.field_marked_explain true b; + match a.n_expl, b.n_expl with + | FL_some r1, FL_some r2 -> find2 r1.next r2.next + | FL_some r, FL_none -> find1 r.next + | FL_none, FL_some r -> find1 r.next + | FL_none, FL_none -> assert false + (* no common ancestor *) + ) + in + + (* cleanup tags on nodes traversed in [find2] *) + let rec cleanup_ n = + if get_bitfield_ self.field_marked_explain n then ( + set_bitfield_ self.field_marked_explain false n; + match n.n_expl with + | FL_none -> () + | FL_some { next; _ } -> cleanup_ next + ) + in + let n = find2 a b in + cleanup_ a; + cleanup_ b; + n + +module Expl_state = struct + type t = { + mutable lits: Lit.t list; + mutable th_lemmas: + (Lit.t * (Lit.t * Lit.t list) list * Proof_term.step_id) list; + } + + let create () : t = { lits = []; th_lemmas = [] } + let[@inline] copy self : t = { self with lits = self.lits } + let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits + + let[@inline] add_th (self : t) lit hyps pr : unit = + self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas + + let merge self other = + let { lits = o_lits; th_lemmas = o_lemmas } = other in + self.lits <- List.rev_append o_lits self.lits; + self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas; + () + + (* proof of [\/_i ¬lits[i]] *) + let proof_of_th_lemmas (self : t) (proof : Proof_trace.t) : Proof_term.step_id + = + let p_lits1 = Iter.of_list self.lits |> Iter.map Lit.neg in + let p_lits2 = + Iter.of_list self.th_lemmas + |> Iter.map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u) + in + let p_cc = + P.add_step proof @@ Rules_.lemma_cc (Iter.append p_lits1 p_lits2) + in + let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) = + (* pr_th: [sub_proofs |- t=u]. + now resolve away [sub_proofs] to get literals that were + asserted in the congruence closure *) + let pr_th = + List.fold_left + (fun pr_th (lit_i, hyps_i) -> + (* [hyps_i |- lit_i] *) + let lemma_i = + P.add_step proof + @@ Rules_.lemma_cc + Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) + in + (* resolve [lit_i] away. *) + P.add_step proof + @@ Rules_.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th) + pr_th sub_proofs + in + P.add_step proof @@ Rules_.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr + in + (* resolve with theory proofs responsible for some merges, if any. *) + List.fold_left resolve_with_th_proof p_cc self.th_lemmas + + let to_resolved_expl (self : t) : Resolved_expl.t = + (* FIXME: package the th lemmas too *) + let { lits; th_lemmas = _ } = self in + let s2 = copy self in + let pr proof = proof_of_th_lemmas s2 proof in + { Resolved_expl.lits; pr } +end + +(* decompose explanation [e] into a list of literals added to [acc] *) +let rec explain_decompose_expl self (st : Expl_state.t) (e : explanation) : unit + = + Log.debugf 5 (fun k -> k "(@[cc.decompose_expl@ %a@])" Expl.pp e); + match e with + | E_trivial -> () + | E_congruence (n1, n2) -> + (match n1.n_sig0, n2.n_sig0 with + | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + assert (Const.equal f1 f2); + assert (List.length a1 = List.length a2); + List.iter2 (explain_equal_rec_ self st) a1 a2 + | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + explain_equal_rec_ self st f1 f2; + explain_equal_rec_ self st a1 a2 + | Some (If (a1, b1, c1)), Some (If (a2, b2, c2)) -> + explain_equal_rec_ self st a1 a2; + explain_equal_rec_ self st b1 b2; + explain_equal_rec_ self st c1 c2 + | _ -> assert false) + | E_lit lit -> Expl_state.add_lit st lit + | E_theory (t, u, expl_sets, pr) -> + let sub_proofs = + List.map + (fun (t_i, u_i, expls_i) -> + let lit_i = Lit.make_eq self.tst t_i u_i in + (* use a separate call to [explain_expls] for each set *) + let sub = explain_expls self expls_i in + Expl_state.merge st sub; + lit_i, sub.lits) + expl_sets + in + let lit_t_u = Lit.make_eq self.tst t u in + Expl_state.add_th st lit_t_u sub_proofs pr + | E_merge (a, b) -> explain_equal_rec_ self st a b + | E_merge_t (a, b) -> + (* find nodes for [a] and [b] on the fly *) + (match T_tbl.find self.tbl a, T_tbl.find self.tbl b with + | a, b -> explain_equal_rec_ self st a b + | exception Not_found -> + Error.errorf "expl: cannot find e_node(s) for %a, %a" Term.pp_debug a + Term.pp_debug b) + | E_and (a, b) -> + explain_decompose_expl self st a; + explain_decompose_expl self st b + +and explain_expls self (es : explanation list) : Expl_state.t = + let st = Expl_state.create () in + List.iter (explain_decompose_expl self st) es; + st + +and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) : + unit = + Log.debugf 5 (fun k -> + k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b); + assert (E_node.equal (find_ a) (find_ b)); + let ancestor = find_common_ancestor cc a b in + explain_along_path cc st a ancestor; + explain_along_path cc st b ancestor + +(* explain why [a = parent_a], where [a -> ... -> target] in the + proof forest *) +and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) : + unit = + let rec aux n = + if n == target then + () + else ( + match n.n_expl with + | FL_none -> assert false + | FL_some { next = next_n; expl } -> + explain_decompose_expl self st expl; + (* now prove [next_n = target] *) + aux next_n + ) + in + aux a + +(* add a term *) +let[@inline] rec add_term_rec_ self t : e_node = + match T_tbl.find self.tbl t with + | n -> n + | exception Not_found -> add_new_term_ self t + +(* add [t] when not present already *) +and add_new_term_ self (t : Term.t) : e_node = + assert (not @@ mem self t); + Log.debugf 15 (fun k -> k "(@[cc.add-term@ %a@])" Term.pp_debug t); + let n = E_node.Internal_.make t in + (* register sub-terms, add [t] to their parent list, and return the + corresponding initial signature *) + let sig0 = compute_sig0 self n in + n.n_sig0 <- sig0; + (* remove term when we backtrack *) + on_backtrack self (fun () -> + Log.debugf 30 (fun k -> k "(@[cc.remove-term@ %a@])" Term.pp_debug t); + T_tbl.remove self.tbl t); + (* add term to the table *) + T_tbl.add self.tbl t n; + if Option.is_some sig0 then + (* [n] might be merged with other equiv classes *) + push_pending self n; + Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self); + n + +(* compute the initial signature of the given e_node *) +and compute_sig0 (self : t) (n : e_node) : Signature.t option = + (* add sub-term to [cc], and register [n] to its parents. + Note that we return the exact sub-term, to get proper + explanations, but we add to the sub-term's root's parent list. *) + let deref_sub (u : Term.t) : e_node = + let sub = add_term_rec_ self u in + (* add [n] to [sub.root]'s parent list *) + (let sub_r = find_ sub in + let old_parents = sub_r.n_parents in + if Bag.is_empty old_parents then + (* first time it has parents: tell watchers that this is a subterm *) + Event.emit_iter self.on_is_subterm (self, sub, u) ~f:(push_action_l self); + on_backtrack self (fun () -> sub_r.n_parents <- old_parents); + sub_r.n_parents <- Bag.cons n sub_r.n_parents); + sub + in + let[@inline] return x = Some x in + match self.view_as_cc n.n_term with + | Bool _ | Opaque _ -> None + | Eq (a, b) -> + let a = deref_sub a in + let b = deref_sub b in + return @@ View.Eq (a, b) + | Not u -> return @@ View.Not (deref_sub u) + | App_fun (f, args) -> + let args = args |> Iter.map deref_sub |> Iter.to_list in + if args <> [] then + return @@ View.App_fun (f, args) + else + None + | App_ho (f, a) -> + let f = deref_sub f in + let a = deref_sub a in + return @@ View.App_ho (f, a) + | If (a, b, c) -> return @@ View.If (deref_sub a, deref_sub b, deref_sub c) + +let[@inline] add_term self t : e_node = add_term_rec_ self t +let mem_term = mem + +let set_as_lit self (n : e_node) (lit : Lit.t) : unit = + match n.n_as_lit with + | Some _ -> () + | None -> + Log.debugf 15 (fun k -> + k "(@[cc.set-as-lit@ %a@ %a@])" E_node.pp n Lit.pp lit); + on_backtrack self (fun () -> n.n_as_lit <- None); + n.n_as_lit <- Some lit + +(* is [n] true or false? *) +let n_is_bool_value (self : t) n : bool = + E_node.equal n (n_true self) || E_node.equal n (n_false self) + +(* gather a pair [lits, pr], where [lits] is the set of + asserted literals needed in the explanation (which is useful for + the SAT solver), and [pr] is a proof, including sub-proofs for theory + merges. *) +let lits_and_proof_of_expl (self : t) (st : Expl_state.t) : + Lit.t list * Proof_term.step_id = + let { Expl_state.lits; th_lemmas = _ } = st in + let pr = Expl_state.proof_of_th_lemmas st self.proof in + lits, pr + +(* main CC algo: add terms from [pending] to the signature table, + check for collisions *) +let rec update_tasks (self : t) : unit = + while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do + while not @@ Vec.is_empty self.pending do + task_pending_ self (Vec.pop_exn self.pending) + done; + while not @@ Vec.is_empty self.combine do + task_combine_ self (Vec.pop_exn self.combine) + done + done + +and task_pending_ self (n : e_node) : unit = + (* check if some parent collided *) + match n.n_sig0 with + | None -> () (* no-op *) + | Some (Eq (a, b)) -> + (* if [a=b] is now true, merge [(a=b)] and [true] *) + if same_class a b then ( + let expl = Expl.mk_merge a b in + Log.debugf 5 (fun k -> + k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" E_node.pp n E_node.pp a + E_node.pp b); + merge_classes self n (n_true self) expl + ) + | Some (Not u) -> + (* [u = bool ==> not u = not bool] *) + let r_u = find_ u in + if E_node.equal r_u (n_true self) then ( + let expl = Expl.mk_merge u (n_true self) in + merge_classes self n (n_false self) expl + ) else if E_node.equal r_u (n_false self) then ( + let expl = Expl.mk_merge u (n_false self) in + merge_classes self n (n_true self) expl + ) + | Some s0 -> + (* update the signature by using [find] on each sub-e_node *) + let s = update_sig s0 in + (match find_signature self s with + | None -> + (* add to the signature table [sig(n) --> n] *) + add_signature self s n + | Some u when E_node.equal n u -> () + | Some u -> + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) + assert (n != u); + let expl = Expl.mk_congruence n u in + merge_classes self n u expl) + +and task_combine_ self = function + | CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab + | CT_act (Handler_action.Act_merge (t, u, e)) -> task_merge_ self t u e + | CT_act (Handler_action.Act_propagate (lit, reason)) -> + (* will return this propagation to the caller *) + Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }) + +(* main CC algo: merge equivalence classes in [st.combine]. + @raise Exn_unsat if merge fails *) +and task_merge_ self a b e_ab : unit = + let ra = find_ a in + let rb = find_ b in + if not @@ E_node.equal ra rb then ( + assert (E_node.is_root ra); + assert (E_node.is_root rb); + Stat.incr self.count_merge; + (* check we're not merging [true] and [false] *) + if + (E_node.equal ra (n_true self) && E_node.equal rb (n_false self)) + || (E_node.equal rb (n_true self) && E_node.equal ra (n_false self)) + then ( + Log.debugf 5 (fun k -> + k + "(@[cc.merge.true_false_conflict@ @[:r1 %a@ :t1 %a@]@ @[:r2 \ + %a@ :t2 %a@]@ :e_ab %a@])" + E_node.pp ra E_node.pp a E_node.pp rb E_node.pp b Expl.pp e_ab); + let th = ref false in + (* TODO: + C1: P.true_neq_false + C2: lemma [lits |- true=false] (and resolve on theory proofs) + C3: r1 C1 C2 + *) + let expl_st = Expl_state.create () in + explain_decompose_expl self expl_st e_ab; + explain_equal_rec_ self expl_st a ra; + explain_equal_rec_ self expl_st b rb; + + (* regular conflict *) + let lits, pr = lits_and_proof_of_expl self expl_st in + raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr + ); + (* We will merge [r_from] into [r_into]. + we try to ensure that [size ra <= size rb] in general, but always + keep values as representative *) + let r_from, r_into = + if n_is_bool_value self ra then + rb, ra + else if n_is_bool_value self rb then + ra, rb + else if size_ ra > size_ rb then + rb, ra + else + ra, rb + in + (* when merging terms with [true] or [false], possibly propagate them to SAT *) + let merge_bool r1 t1 r2 t2 = + if E_node.equal r1 (n_true self) then + propagate_bools self r2 t2 r1 t1 e_ab true + else if E_node.equal r1 (n_false self) then + propagate_bools self r2 t2 r1 t1 e_ab false + in + + merge_bool ra a rb b; + merge_bool rb b ra a; + + (* perform [union r_from r_into] *) + Log.debugf 15 (fun k -> + k "(@[cc.merge@ :from %a@ :into %a@])" E_node.pp r_from E_node.pp r_into); + + (* call [on_pre_merge] functions, and merge theory data items *) + (* explanation is [a=ra & e_ab & b=rb] *) + (let expl = Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ] in + + let handle_act = function + | Ok l -> push_action_l self l + | Error (Handler_action.Conflict expl) -> + raise_conflict_from_expl self expl + in + + Event.emit_iter self.on_pre_merge + (self, r_into, r_from, expl) + ~f:handle_act; + Event.emit_iter self.on_pre_merge2 + (self, r_into, r_from, expl) + ~f:handle_act); + + (* TODO: merge plugin data here, _after_ the pre-merge hooks are called, + so they have a chance of observing pre-merge plugin data *) + ((* parents might have a different signature, check for collisions *) + E_node.iter_parents r_from (fun parent -> push_pending self parent); + (* for each e_node in [r_from]'s class, make it point to [r_into] *) + E_node.iter_class r_from (fun u -> + assert (u.n_root == r_from); + u.n_root <- r_into); + (* capture current state *) + let r_into_old_next = r_into.n_next in + let r_from_old_next = r_from.n_next in + let r_into_old_parents = r_into.n_parents in + let r_into_old_bits = r_into.n_bits in + (* swap [into.next] and [from.next], merging the classes *) + r_into.n_next <- r_from_old_next; + r_from.n_next <- r_into_old_next; + r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; + r_into.n_size <- r_into.n_size + r_from.n_size; + r_into.n_bits <- Bits.merge r_into.n_bits r_from.n_bits; + (* on backtrack, unmerge classes and restore the pointers to [r_from] *) + on_backtrack self (fun () -> + Log.debugf 30 (fun k -> + k "(@[cc.undo_merge@ :from %a@ :into %a@])" E_node.pp r_from + E_node.pp r_into); + r_into.n_bits <- r_into_old_bits; + r_into.n_next <- r_into_old_next; + r_from.n_next <- r_from_old_next; + r_into.n_parents <- r_into_old_parents; + (* NOTE: this must come after the restoration of [next] pointers, + otherwise we'd iterate on too big a class *) + E_node.Internal_.iter_class_ r_from (fun u -> u.n_root <- r_from); + r_into.n_size <- r_into.n_size - r_from.n_size)); + + (* update explanations (a -> b), arbitrarily. + Note that here we merge the classes by adding a bridge between [a] + and [b], not their roots. *) + reroot_expl self a; + assert (a.n_expl = FL_none); + (* on backtracking, link may be inverted, but we delete the one + that bridges between [a] and [b] *) + on_backtrack self (fun () -> + match a.n_expl, b.n_expl with + | FL_some e, _ when E_node.equal e.next b -> a.n_expl <- FL_none + | _, FL_some e when E_node.equal e.next a -> b.n_expl <- FL_none + | _ -> assert false); + a.n_expl <- FL_some { next = b; expl = e_ab }; + (* call [on_post_merge] *) + Event.emit_iter self.on_post_merge (self, r_into, r_from) + ~f:(push_action_l self) + ) + +(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] + in the equiv class of [r1] that is a known literal back to the SAT solver + and which is not the one initially merged. + We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) +and propagate_bools self r1 t1 r2 t2 (e_12 : explanation) sign : unit = + (* explanation for [t1 =e= t2 = r2] *) + let half_expl_and_pr = + lazy + (let st = Expl_state.create () in + explain_decompose_expl self st e_12; + explain_equal_rec_ self st r2 t2; + st) + in + (* TODO: flag per class, `or`-ed on merge, to indicate if the class + contains at least one lit *) + E_node.iter_class r1 (fun u1 -> + (* propagate if: + - [u1] is a proper literal + - [t2 != r2], because that can only happen + after an explicit merge (no way to obtain that by propagation) + *) + match E_node.as_lit u1 with + | Some lit when not (E_node.equal r2 t2) -> + let lit = + if sign then + lit + else + Lit.neg lit + in + (* apply sign *) + Log.debugf 5 (fun k -> k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); + (* complete explanation with the [u1=t1] chunk *) + let (lazy st) = half_expl_and_pr in + let st = Expl_state.copy st in + (* do not modify shared st *) + explain_equal_rec_ self st u1 t1; + + (* propagate only if this doesn't depend on some semantic values *) + let reason () = + (* true literals explaining why t1=t2 *) + let guard = st.lits in + (* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *) + Expl_state.add_lit st (Lit.neg lit); + let _, pr = lits_and_proof_of_expl self st in + guard, pr + in + Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }); + Event.emit_iter self.on_propagate (self, lit, reason) + ~f:(push_action_l self); + Stat.incr self.count_props + | _ -> ()) + +(* raise a conflict from an explanation, typically from an event handler. + Raises E_confl with a result conflict. *) +and raise_conflict_from_expl self (expl : Expl.t) : 'a = + Log.debugf 5 (fun k -> + k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); + let st = Expl_state.create () in + explain_decompose_expl self st expl; + let lits, pr = lits_and_proof_of_expl self st in + let c = List.rev_map Lit.neg lits in + let th = st.th_lemmas <> [] in + raise_conflict_ self ~th c pr + +let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t) + +let push_level (self : t) : unit = + assert (not self.in_loop); + Backtrack_stack.push_level self.undo + +let pop_levels (self : t) n : unit = + assert (not self.in_loop); + Vec.clear self.pending; + Vec.clear self.combine; + Log.debugf 15 (fun k -> + k "(@[cc.pop-levels %d@ :n-lvls %d@])" n + (Backtrack_stack.n_levels self.undo)); + Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f ()); + () + +let assert_eq self t u expl : unit = + assert (not self.in_loop); + let t = add_term self t in + let u = add_term self u in + (* merge [a] and [b] *) + merge_classes self t u expl + +(* assert that this boolean literal holds. + if a lit is [= a b], merge [a] and [b]; + otherwise merge the atom with true/false *) +let assert_lit self lit : unit = + assert (not self.in_loop); + let t = Lit.term lit in + Log.debugf 15 (fun k -> k "(@[cc.assert-lit@ %a@])" Lit.pp lit); + let sign = Lit.sign lit in + match self.view_as_cc t with + | Eq (a, b) when sign -> assert_eq self a b (Expl.mk_lit lit) + | _ -> + (* equate t and true/false *) + let rhs = n_bool self sign in + let n = add_term self t in + (* TODO: ensure that this is O(1). + basically, just have [n] point to true/false and thus acquire + the corresponding value, so its superterms (like [ite]) can evaluate + properly *) + (* TODO: use oriented merge (force direction [n -> rhs]) *) + merge_classes self n rhs (Expl.mk_lit lit) + +let[@inline] assert_lits self lits : unit = + assert (not self.in_loop); + Iter.iter (assert_lit self) lits + +let merge self n1 n2 expl = + assert (not self.in_loop); + Log.debugf 5 (fun k -> + k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" E_node.pp n1 E_node.pp + n2 Expl.pp expl); + assert (Term.equal (Term.ty n1.n_term) (Term.ty n2.n_term)); + merge_classes self n1 n2 expl + +let merge_t self t1 t2 expl = + merge self (add_term self t1) (add_term self t2) expl + +let explain_eq self n1 n2 : Resolved_expl.t = + let st = Expl_state.create () in + explain_equal_rec_ self st n1 n2; + (* FIXME: also need to return the proof? *) + Expl_state.to_resolved_expl st + +let explain_expl (self : t) expl : Resolved_expl.t = + let expl_st = Expl_state.create () in + explain_decompose_expl self expl_st expl; + Expl_state.to_resolved_expl expl_st + +let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge +let[@inline] on_pre_merge2 self = Event.of_emitter self.on_pre_merge2 +let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge +let[@inline] on_new_term self = Event.of_emitter self.on_new_term +let[@inline] on_conflict self = Event.of_emitter self.on_conflict +let[@inline] on_propagate self = Event.of_emitter self.on_propagate +let[@inline] on_is_subterm self = Event.of_emitter self.on_is_subterm + +let create_ ?(stat = Stat.global) ?(size = `Big) (tst : Term.store) + (proof : Proof_trace.t) ~view_as_cc : t = + let size = + match size with + | `Small -> 128 + | `Big -> 2048 + in + let bitgen = Bits.mk_gen () in + let field_marked_explain = Bits.mk_field bitgen in + let rec cc = + { + view_as_cc; + tst; + proof; + tbl = T_tbl.create size; + signatures_tbl = Sig_tbl.create size; + bitgen; + on_pre_merge = Event.Emitter.create (); + on_pre_merge2 = Event.Emitter.create (); + on_post_merge = Event.Emitter.create (); + on_new_term = Event.Emitter.create (); + on_conflict = Event.Emitter.create (); + on_propagate = Event.Emitter.create (); + on_is_subterm = Event.Emitter.create (); + pending = Vec.create (); + combine = Vec.create (); + undo = Backtrack_stack.create (); + true_; + false_; + in_loop = false; + res_acts = Vec.create (); + field_marked_explain; + count_conflict = Stat.mk_int stat "cc.conflicts"; + count_props = Stat.mk_int stat "cc.propagations"; + count_merge = Stat.mk_int stat "cc.merges"; + } + and true_ = lazy (add_term cc (Term.true_ tst)) + and false_ = lazy (add_term cc (Term.false_ tst)) in + ignore (Lazy.force true_ : e_node); + ignore (Lazy.force false_ : e_node); + cc + +let[@inline] find_t self t : repr = + let n = T_tbl.find self.tbl t in + find_ n + +let pop_acts_ self = + let rec loop acc = + match Vec.pop self.res_acts with + | None -> acc + | Some x -> loop (x :: acc) + in + loop [] + +let check self : Result_action.or_conflict = + Log.debug 5 "(cc.check)"; + self.in_loop <- true; + let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in + try + update_tasks self; + let l = pop_acts_ self in + Ok l + with E_confl c -> Error c + +let check_inv_enabled_ = true (* XXX NUDGE *) + +(* check some internal invariants *) +let check_inv_ (self : t) : unit = + if check_inv_enabled_ then ( + Log.debug 2 "(cc.check-invariants)"; + all_classes self + |> Iter.flat_map E_node.iter_class + |> Iter.iter (fun n -> + match n.n_sig0 with + | None -> () + | Some s -> + let s' = update_sig s in + let ok = + match find_signature self s' with + | None -> false + | Some r -> E_node.equal r n.n_root + in + if not ok then + Log.debugf 0 (fun k -> + k "(@[cc.check.fail@ :n %a@ :sig %a@ :actual-sig %a@])" + E_node.pp n Signature.pp s Signature.pp s')) + ) + +(* model: return all the classes *) +let get_model (self : t) : repr Iter.t Iter.t = + check_inv_ self; + all_classes self |> Iter.map E_node.iter_class + +(** Arguments to a congruence closure's implementation *) +module type ARG = sig + val view_as_cc : view_as_cc + (** View the Term.t through the lens of the congruence closure *) +end + +module type BUILD = sig + val create : + ?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t + (** Create a new congruence closure. + + @param term_store used to be able to create new terms. All terms + interacting with this congruence closure must belong in this term state + as well. + *) +end + +module Make (A : ARG) : BUILD = struct + let create ?stat ?size tst proof : t = + create_ ?stat ?size tst proof ~view_as_cc:A.view_as_cc +end + +module Default = struct + include Make (struct + let view_as_cc (t : Term.t) : _ View.t = + let f, args = Term.unfold_app t in + match Term.view f, args with + | _, [ _; t; u ] when Term.is_eq f -> View.Eq (t, u) + | _ -> + (match Term.view t with + | Term.E_app (f, a) -> View.App_ho (f, a) + | Term.E_const c -> View.App_fun (c, Iter.empty) + | _ -> View.Opaque t) + end) +end diff --git a/src/cc/CC.mli b/src/cc/CC.mli new file mode 100644 index 00000000..117a56e2 --- /dev/null +++ b/src/cc/CC.mli @@ -0,0 +1,285 @@ +open Sidekick_core + +type e_node = E_node.t +(** A node of the congruence closure *) + +type repr = E_node.t +(** Node that is currently a representative. *) + +type explanation = Expl.t + +type bitfield = Bits.field +(** A field in the bitfield of this node. This should only be + allocated when a theory is initialized. + + Bitfields are accessed using preallocated keys. + See {!allocate_bitfield}. + + All fields are initially 0, are backtracked automatically, + and are merged automatically when classes are merged. *) + +(** Main congruence closure signature. + + The congruence closure handles the theory QF_UF (uninterpreted + function symbols). + It is also responsible for {i theory combination}, and provides + a general framework for equality reasoning that other + theories piggyback on. + + For example, the theory of datatypes relies on the congruence closure + to do most of the work, and "only" adds injectivity/disjointness/acyclicity + lemmas when needed. + + Similarly, a theory of arrays would hook into the congruence closure and + assert (dis)equalities as needed. +*) + +type t +(** The congruence closure object. + It contains a fair amount of state and is mutable + and backtrackable. *) + +(** {3 Accessors} *) + +val term_store : t -> Term.store +val proof : t -> Proof_trace.t + +val find : t -> e_node -> repr +(** Current representative *) + +val add_term : t -> Term.t -> e_node +(** Add the Term.t to the congruence closure, if not present already. + Will be backtracked. *) + +val mem_term : t -> Term.t -> bool +(** Returns [true] if the Term.t is explicitly present in the congruence closure *) + +val allocate_bitfield : t -> descr:string -> bitfield +(** Allocate a new e_node field (see {!E_node.bitfield}). + + This field descriptor is henceforth reserved for all nodes + in this congruence closure, and can be set using {!set_bitfield} + for each class_ individually. + This can be used to efficiently store some metadata on nodes + (e.g. "is there a numeric value in the class" + or "is there a constructor Term.t in the class"). + + There may be restrictions on how many distinct fields are allocated + for a given congruence closure (e.g. at most {!Sys.int_size} fields). + *) + +val get_bitfield : t -> bitfield -> E_node.t -> bool +(** Access the bit field of the given e_node *) + +val set_bitfield : t -> bitfield -> bool -> E_node.t -> unit +(** Set the bitfield for the e_node. This will be backtracked. + See {!E_node.bitfield}. *) + +type propagation_reason = unit -> Lit.t list * Proof_term.step_id + +(** Handler Actions + + Actions that can be scheduled by event handlers. *) +module Handler_action : sig + type t = + | Act_merge of E_node.t * E_node.t * Expl.t + | Act_propagate of Lit.t * propagation_reason + + (* TODO: + - an action to modify data associated with a class + *) + + type conflict = Conflict of Expl.t [@@unboxed] + + type or_conflict = (t list, conflict) result + (** Actions or conflict scheduled by an event handler. + + - [Ok acts] is a list of merges and propagations + - [Error confl] is a conflict to resolve. + *) +end + +(** Result Actions. + + + Actions returned by the congruence closure after calling {!check}. *) +module Result_action : sig + type t = + | Act_propagate of { lit: Lit.t; reason: propagation_reason } + (** [propagate (Lit.t, reason)] declares that [reason() => Lit.t] + is a tautology. + + - [reason()] should return a list of literals that are currently true, + as well as a proof. + - [Lit.t] should be a literal of interest (see {!S.set_as_lit}). + + This function might never be called, a congruence closure has the right + to not propagate and only trigger conflicts. *) + + type conflict = + | Conflict of Lit.t list * Proof_term.step_id + (** [raise_conflict (c,pr)] declares that [c] is a tautology of + the theory of congruence. + @param pr the proof of [c] being a tautology *) + + type or_conflict = (t list, conflict) result +end + +(** {3 Events} + + Events triggered by the congruence closure, to which + other plugins can subscribe. *) + +(** Events emitted by the congruence closure when something changes. *) +val on_pre_merge : + t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t +(** [Ev_on_pre_merge acts n1 n2 expl] is emitted right before [n1] + and [n2] are merged with explanation [expl]. *) + +val on_pre_merge2 : + t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t +(** Second phase of "on pre merge". This runs after {!on_pre_merge} + and is used by Plugins. {b NOTE}: Plugin state might be observed as already + changed in these handlers. *) + +val on_post_merge : + t -> (t * E_node.t * E_node.t, Handler_action.t list) Event.t +(** [ev_on_post_merge acts n1 n2] is emitted right after [n1] + and [n2] were merged. [find cc n1] and [find cc n2] will return + the same E_node.t. *) + +val on_new_term : t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t +(** [ev_on_new_term n t] is emitted whenever a new Term.t [t] + is added to the congruence closure. Its E_node.t is [n]. *) + +type ev_on_conflict = { cc: t; th: bool; c: Lit.t list } +(** Event emitted when a conflict occurs in the CC. + + [th] is true if the explanation for this conflict involves + at least one "theory" explanation; i.e. some of the equations + participating in the conflict are purely syntactic theories + like injectivity of constructors. *) + +val on_conflict : t -> (ev_on_conflict, unit) Event.t +(** [ev_on_conflict {th; c}] is emitted when the congruence + closure triggers a conflict by asserting the tautology [c]. *) + +val on_propagate : + t -> + ( t * Lit.t * (unit -> Lit.t list * Proof_term.step_id), + Handler_action.t list ) + Event.t +(** [ev_on_propagate Lit.t reason] is emitted whenever [reason() => Lit.t] + is a propagated lemma. See {!CC_ACTIONS.propagate}. *) + +val on_is_subterm : t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t +(** [ev_on_is_subterm n t] is emitted when [n] is a subterm of + another E_node.t for the first time. [t] is the Term.t corresponding to + the E_node.t [n]. This can be useful for theory combination. *) + +(** {3 Misc} *) + +val n_true : t -> E_node.t +(** Node for [true] *) + +val n_false : t -> E_node.t +(** Node for [false] *) + +val n_bool : t -> bool -> E_node.t +(** Node for either true or false *) + +val set_as_lit : t -> E_node.t -> Lit.t -> unit +(** map the given e_node to a literal. *) + +val find_t : t -> Term.t -> repr +(** Current representative of the Term.t. + @raise E_node.t_found if the Term.t is not already {!add}-ed. *) + +val add_iter : t -> Term.t Iter.t -> unit +(** Add a sequence of terms to the congruence closure *) + +val all_classes : t -> repr Iter.t +(** All current classes. This is costly, only use if there is no other solution *) + +val explain_eq : t -> E_node.t -> E_node.t -> Resolved_expl.t +(** Explain why the two nodes are equal. + Fails if they are not, in an unspecified way. *) + +val explain_expl : t -> Expl.t -> Resolved_expl.t +(** Transform explanation into an actionable conflict clause *) + +(* FIXME: remove + val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a + (** Raise a conflict with the given explanation. + It must be a theory tautology that [expl ==> absurd]. + To be used in theories. + + This fails in an unspecified way if the explanation, once resolved, + satisfies {!Resolved_expl.is_semantic}. *) +*) + +val merge : t -> E_node.t -> E_node.t -> Expl.t -> unit +(** Merge these two nodes given this explanation. + It must be a theory tautology that [expl ==> n1 = n2]. + To be used in theories. *) + +val merge_t : t -> Term.t -> Term.t -> Expl.t -> unit +(** Shortcut for adding + merging *) + +(** {3 Main API *) + +val assert_eq : t -> Term.t -> Term.t -> Expl.t -> unit +(** Assert that two terms are equal, using the given explanation. *) + +val assert_lit : t -> Lit.t -> unit +(** Given a literal, assume it in the congruence closure and propagate + its consequences. Will be backtracked. + + Useful for the theory combination or the SAT solver's functor *) + +val assert_lits : t -> Lit.t Iter.t -> unit +(** Addition of many literals *) + +val check : t -> Result_action.or_conflict +(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. + Will use the {!actions} to propagate literals, declare conflicts, etc. *) + +val push_level : t -> unit +(** Push backtracking level *) + +val pop_levels : t -> int -> unit +(** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *) + +val get_model : t -> E_node.t Iter.t Iter.t +(** get all the equivalence classes so they can be merged in the model *) + +type view_as_cc = Term.t -> (Const.t, Term.t, Term.t Iter.t) View.t + +(** Arguments to a congruence closure's implementation *) +module type ARG = sig + val view_as_cc : view_as_cc + (** View the Term.t through the lens of the congruence closure *) +end + +module type BUILD = sig + val create : + ?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t + (** Create a new congruence closure. + + @param term_store used to be able to create new terms. All terms + interacting with this congruence closure must belong in this term state + as well. + *) +end + +module Make (_ : ARG) : BUILD +module Default : BUILD + +(**/**) + +module Debug_ : sig + val pp : t Fmt.printer + (** Print the whole CC *) +end + +(**/**) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 9b272403..ad7dc973 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -1,22 +1,14 @@ open Sidekick_core module View = View +module E_node = E_node +module Expl = Expl +module Signature = Signature +module Resolved_expl = Resolved_expl +module Plugin = Plugin +module CC = CC -module type ARG = Sigs.ARG -module type S = Sigs.S -module type DYN_MONOID_PLUGIN = Sigs.DYN_MONOID_PLUGIN -module type MONOID_PLUGIN_ARG = Sigs.MONOID_PLUGIN_ARG -module type MONOID_PLUGIN_BUILDER = Sigs.MONOID_PLUGIN_BUILDER +module type DYN_MONOID_PLUGIN = Sigs_plugin.DYN_MONOID_PLUGIN +module type MONOID_PLUGIN_ARG = Sigs_plugin.MONOID_PLUGIN_ARG +module type MONOID_PLUGIN_BUILDER = Sigs_plugin.MONOID_PLUGIN_BUILDER -module Make (A : ARG) : S = Core_cc.Make (A) - -module Base : S = Make (struct - let view_as_cc (t : Term.t) : _ View.t = - let f, args = Term.unfold_app t in - match Term.view f, args with - | _, [ _; t; u ] when Term.is_eq f -> View.Eq (t, u) - | _ -> - (match Term.view t with - | Term.E_app (f, a) -> View.App_ho (f, a) - | Term.E_const c -> View.App_fun (c, Iter.empty) - | _ -> View.Opaque t) -end) +include CC diff --git a/src/cc/Sidekick_cc.mli b/src/cc/Sidekick_cc.mli index 9c7da989..0facab95 100644 --- a/src/cc/Sidekick_cc.mli +++ b/src/cc/Sidekick_cc.mli @@ -1,15 +1,19 @@ (** Congruence Closure Implementation *) open Sidekick_core + +module type DYN_MONOID_PLUGIN = Sigs_plugin.DYN_MONOID_PLUGIN +module type MONOID_PLUGIN_ARG = Sigs_plugin.MONOID_PLUGIN_ARG +module type MONOID_PLUGIN_BUILDER = Sigs_plugin.MONOID_PLUGIN_BUILDER + module View = View +module E_node = E_node +module Expl = Expl +module Signature = Signature +module Resolved_expl = Resolved_expl +module Plugin = Plugin +module CC = CC -module type ARG = Sigs.ARG -module type S = Sigs.S -module type DYN_MONOID_PLUGIN = Sigs.DYN_MONOID_PLUGIN -module type MONOID_PLUGIN_ARG = Sigs.MONOID_PLUGIN_ARG -module type MONOID_PLUGIN_BUILDER = Sigs.MONOID_PLUGIN_BUILDER - -module Make (_ : ARG) : S - -module Base : S -(** Basic implementation following terms' shape *) +include module type of struct + include CC +end diff --git a/src/cc/core_cc.ml b/src/cc/core_cc.ml deleted file mode 100644 index 0df1b40a..00000000 --- a/src/cc/core_cc.ml +++ /dev/null @@ -1,1136 +0,0 @@ -(* actual implementation *) - -open Sidekick_core -open View - -module type ARG = Sigs.ARG - -module Make (A : ARG) : Sigs.S = struct - open struct - (* proof rules *) - module Rules_ = Proof_core - module P = Proof_trace - end - - type e_node = { - n_term: Term.t; - mutable n_sig0: signature option; (* initial signature *) - mutable n_bits: Bits.t; (* bitfield for various properties *) - mutable n_parents: e_node Bag.t; (* parent terms of this node *) - mutable n_root: e_node; - (* representative of congruence class (itself if a representative) *) - mutable n_next: e_node; (* pointer to next element of congruence class *) - mutable n_size: int; (* size of the class *) - mutable n_as_lit: Lit.t option; - (* TODO: put into payload? and only in root? *) - mutable n_expl: explanation_forest_link; - (* the rooted forest for explanations *) - } - (** A node of the congruence closure. - An equivalence class is represented by its "root" element, - the representative. *) - - and signature = (Const.t, e_node, e_node list) View.t - - and explanation_forest_link = - | FL_none - | FL_some of { next: e_node; expl: explanation } - - (* atomic explanation in the congruence closure *) - and explanation = - | E_trivial (* by pure reduction, tautologically equal *) - | E_lit of Lit.t (* because of this literal *) - | E_merge of e_node * e_node - | E_merge_t of Term.t * Term.t - | E_congruence of e_node * e_node (* caused by normal congruence *) - | E_and of explanation * explanation - | E_theory of - Term.t - * Term.t - * (Term.t * Term.t * explanation list) list - * Proof_term.step_id - - type repr = e_node - - module E_node = struct - type t = e_node - - let[@inline] equal (n1 : t) n2 = n1 == n2 - let[@inline] hash n = Term.hash n.n_term - let[@inline] term n = n.n_term - let[@inline] pp out n = Term.pp_debug out n.n_term - let[@inline] as_lit n = n.n_as_lit - - let make (t : Term.t) : t = - let rec n = - { - n_term = t; - n_sig0 = None; - n_bits = Bits.empty; - n_parents = Bag.empty; - n_as_lit = None; - (* TODO: provide a method to do it *) - n_root = n; - n_expl = FL_none; - n_next = n; - n_size = 1; - } - in - n - - let[@inline] is_root (n : e_node) : bool = n.n_root == n - - (* traverse the equivalence class of [n] *) - let iter_class_ (n : e_node) : e_node Iter.t = - fun yield -> - let rec aux u = - yield u; - if u.n_next != n then aux u.n_next - in - aux n - - let[@inline] iter_class n = - assert (is_root n); - iter_class_ n - - let[@inline] iter_parents (n : e_node) : e_node Iter.t = - assert (is_root n); - Bag.to_iter n.n_parents - - type bitfield = Bits.field - - let[@inline] get_field f t = Bits.get f t.n_bits - let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits - end - - (* non-recursive, inlinable function for [find] *) - let[@inline] find_ (n : e_node) : repr = - let n2 = n.n_root in - assert (E_node.is_root n2); - n2 - - let[@inline] same_class (n1 : e_node) (n2 : e_node) : bool = - E_node.equal (find_ n1) (find_ n2) - - let[@inline] find _ n = find_ n - - module Expl = struct - type t = explanation - - let rec pp out (e : explanation) = - match e with - | E_trivial -> Fmt.string out "reduction" - | E_lit lit -> Lit.pp out lit - | E_congruence (n1, n2) -> - Fmt.fprintf out "(@[congruence@ %a@ %a@])" E_node.pp n1 E_node.pp n2 - | E_merge (a, b) -> - Fmt.fprintf out "(@[merge@ %a@ %a@])" E_node.pp a E_node.pp b - | E_merge_t (a, b) -> - Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp_debug - a Term.pp_debug b - | E_theory (t, u, es, _) -> - Fmt.fprintf out "(@[th@ :t `%a`@ :u `%a`@ :expl_sets %a@])" - Term.pp_debug t Term.pp_debug u - (Util.pp_list - @@ Fmt.Dump.triple Term.pp_debug Term.pp_debug (Fmt.Dump.list pp)) - es - | E_and (a, b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b - - let mk_trivial : t = E_trivial - let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2) - - let[@inline] mk_merge a b : t = - if E_node.equal a b then - mk_trivial - else - E_merge (a, b) - - let[@inline] mk_merge_t a b : t = - if Term.equal a b then - mk_trivial - else - E_merge_t (a, b) - - let[@inline] mk_lit l : t = E_lit l - let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr) - - let rec mk_list l = - match l with - | [] -> mk_trivial - | [ x ] -> x - | E_trivial :: tl -> mk_list tl - | x :: y -> - (match mk_list y with - | E_trivial -> x - | y' -> E_and (x, y')) - end - - module Resolved_expl = struct - type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id } - - let pp out (self : t) = - Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits - end - - (** A signature is a shallow term shape where immediate subterms - are representative *) - module Signature = struct - type t = signature - - let equal (s1 : t) s2 : bool = - match s1, s2 with - | Bool b1, Bool b2 -> b1 = b2 - | App_fun (f1, []), App_fun (f2, []) -> Const.equal f1 f2 - | App_fun (f1, l1), App_fun (f2, l2) -> - Const.equal f1 f2 && CCList.equal E_node.equal l1 l2 - | App_ho (f1, a1), App_ho (f2, a2) -> - E_node.equal f1 f2 && E_node.equal a1 a2 - | Not a, Not b -> E_node.equal a b - | If (a1, b1, c1), If (a2, b2, c2) -> - E_node.equal a1 a2 && E_node.equal b1 b2 && E_node.equal c1 c2 - | Eq (a1, b1), Eq (a2, b2) -> E_node.equal a1 a2 && E_node.equal b1 b2 - | Opaque u1, Opaque u2 -> E_node.equal u1 u2 - | Bool _, _ - | App_fun _, _ - | App_ho _, _ - | If _, _ - | Eq _, _ - | Opaque _, _ - | Not _, _ -> - false - - let hash (s : t) : int = - let module H = CCHash in - match s with - | Bool b -> H.combine2 10 (H.bool b) - | App_fun (f, l) -> H.combine3 20 (Const.hash f) (H.list E_node.hash l) - | App_ho (f, a) -> H.combine3 30 (E_node.hash f) (E_node.hash a) - | Eq (a, b) -> H.combine3 40 (E_node.hash a) (E_node.hash b) - | Opaque u -> H.combine2 50 (E_node.hash u) - | If (a, b, c) -> - H.combine4 60 (E_node.hash a) (E_node.hash b) (E_node.hash c) - | Not u -> H.combine2 70 (E_node.hash u) - - let pp out = function - | Bool b -> Fmt.bool out b - | App_fun (f, []) -> Const.pp out f - | App_fun (f, l) -> - Fmt.fprintf out "(@[%a@ %a@])" Const.pp f (Util.pp_list E_node.pp) l - | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" E_node.pp f E_node.pp a - | Opaque t -> E_node.pp out t - | Not u -> Fmt.fprintf out "(@[not@ %a@])" E_node.pp u - | Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" E_node.pp a E_node.pp b - | If (a, b, c) -> - Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" E_node.pp a E_node.pp b - E_node.pp c - end - - module Sig_tbl = CCHashtbl.Make (Signature) - module T_tbl = CCHashtbl.Make (Term) - - type propagation_reason = unit -> Lit.t list * Proof_term.step_id - - module Handler_action = struct - type t = - | Act_merge of E_node.t * E_node.t * Expl.t - | Act_propagate of Lit.t * propagation_reason - - type conflict = Conflict of Expl.t [@@unboxed] - type or_conflict = (t list, conflict) result - end - - module Result_action = struct - type t = Act_propagate of { lit: Lit.t; reason: propagation_reason } - type conflict = Conflict of Lit.t list * Proof_term.step_id - type or_conflict = (t list, conflict) result - end - - type combine_task = - | CT_merge of e_node * e_node * explanation - | CT_act of Handler_action.t - - type t = { - tst: Term.store; - proof: Proof_trace.t; - tbl: e_node T_tbl.t; (* internalization [term -> e_node] *) - signatures_tbl: e_node Sig_tbl.t; - (* map a signature to the corresponding e_node in some equivalence class. - A signature is a [term_cell] in which every immediate subterm - that participates in the congruence/evaluation relation - is normalized (i.e. is its own representative). - The critical property is that all members of an equivalence class - that have the same "shape" (including head symbol) - have the same signature *) - pending: e_node Vec.t; - combine: combine_task Vec.t; - undo: (unit -> unit) Backtrack_stack.t; - bitgen: Bits.bitfield_gen; - field_marked_explain: Bits.field; - (* used to mark traversed nodes when looking for a common ancestor *) - true_: e_node lazy_t; - false_: e_node lazy_t; - mutable in_loop: bool; (* currently being modified? *) - res_acts: Result_action.t Vec.t; (* to return *) - on_pre_merge: - ( t * E_node.t * E_node.t * Expl.t, - Handler_action.or_conflict ) - Event.Emitter.t; - on_pre_merge2: - ( t * E_node.t * E_node.t * Expl.t, - Handler_action.or_conflict ) - Event.Emitter.t; - on_post_merge: - (t * E_node.t * E_node.t, Handler_action.t list) Event.Emitter.t; - on_new_term: (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t; - on_conflict: (ev_on_conflict, unit) Event.Emitter.t; - on_propagate: - (t * Lit.t * propagation_reason, Handler_action.t list) Event.Emitter.t; - on_is_subterm: - (t * E_node.t * Term.t, Handler_action.t list) Event.Emitter.t; - count_conflict: int Stat.counter; - count_props: int Stat.counter; - count_merge: int Stat.counter; - } - (* TODO: an additional union-find to keep track, for each term, - of the terms they are known to be equal to, according - to the current explanation. That allows not to prove some equality - several times. - See "fast congruence closure and extensions", Nieuwenhuis&al, page 14 *) - - and ev_on_conflict = { cc: t; th: bool; c: Lit.t list } - - let[@inline] size_ (r : repr) = r.n_size - let[@inline] n_true self = Lazy.force self.true_ - let[@inline] n_false self = Lazy.force self.false_ - - let n_bool self b = - if b then - n_true self - else - n_false self - - let[@inline] term_store self = self.tst - let[@inline] proof self = self.proof - - let allocate_bitfield self ~descr = - Log.debugf 5 (fun k -> k "(@[cc.allocate-bit-field@ :descr %s@])" descr); - Bits.mk_field self.bitgen - - let[@inline] on_backtrack self f : unit = - Backtrack_stack.push_if_nonzero_level self.undo f - - let[@inline] get_bitfield _cc field n = E_node.get_field field n - - let set_bitfield self field b n = - let old = E_node.get_field field n in - if old <> b then ( - on_backtrack self (fun () -> E_node.set_field field old n); - E_node.set_field field b n - ) - - (* check if [t] is in the congruence closure. - Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) - let[@inline] mem (self : t) (t : Term.t) : bool = T_tbl.mem self.tbl t - - module Debug_ = struct - (* print full state *) - let pp out (self : t) : unit = - let pp_next out n = Fmt.fprintf out "@ :next %a" E_node.pp n.n_next in - let pp_root out n = - if E_node.is_root n then - Fmt.string out " :is-root" - else - Fmt.fprintf out "@ :root %a" E_node.pp n.n_root - in - let pp_expl out n = - match n.n_expl with - | FL_none -> () - | FL_some e -> - Fmt.fprintf out " (@[:forest %a :expl %a@])" E_node.pp e.next Expl.pp - e.expl - in - let pp_n out n = - Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp_debug n.n_term pp_root n - pp_next n pp_expl n - and pp_sig_e out (s, n) = - Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s E_node.pp n - pp_root n - in - Fmt.fprintf out - "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ \ - %a@])@])" - (Util.pp_iter ~sep:" " pp_n) - (T_tbl.values self.tbl) - (Util.pp_iter ~sep:" " pp_sig_e) - (Sig_tbl.to_iter self.signatures_tbl) - end - - (* compute up-to-date signature *) - let update_sig (s : signature) : Signature.t = - View.map_view s ~f_f:(fun x -> x) ~f_t:find_ ~f_ts:(List.map find_) - - (* find whether the given (parent) term corresponds to some signature - in [signatures_] *) - let[@inline] find_signature cc (s : signature) : repr option = - Sig_tbl.get cc.signatures_tbl s - - (* add to signature table. Assume it's not present already *) - let add_signature self (s : signature) (n : e_node) : unit = - assert (not @@ Sig_tbl.mem self.signatures_tbl s); - Log.debugf 50 (fun k -> - k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s E_node.pp n); - on_backtrack self (fun () -> Sig_tbl.remove self.signatures_tbl s); - Sig_tbl.add self.signatures_tbl s n - - let push_pending self t : unit = - Log.debugf 50 (fun k -> k "(@[cc.push-pending@ %a@])" E_node.pp t); - Vec.push self.pending t - - let push_action self (a : Handler_action.t) : unit = - Vec.push self.combine (CT_act a) - - let push_action_l self (l : _ list) : unit = List.iter (push_action self) l - - let merge_classes self t u e : unit = - if t != u && not (same_class t u) then ( - Log.debugf 50 (fun k -> - k "(@[cc.push-combine@ %a ~@ %a@ :expl %a@])" E_node.pp t - E_node.pp u Expl.pp e); - Vec.push self.combine @@ CT_merge (t, u, e) - ) - - (* re-root the explanation tree of the equivalence class of [n] - so that it points to [n]. - postcondition: [n.n_expl = None] *) - let[@unroll 2] rec reroot_expl (self : t) (n : e_node) : unit = - match n.n_expl with - | FL_none -> () (* already root *) - | FL_some { next = u; expl = e_n_u } -> - (* reroot to [u], then invert link between [u] and [n] *) - reroot_expl self u; - u.n_expl <- FL_some { next = n; expl = e_n_u }; - n.n_expl <- FL_none - - exception E_confl of Result_action.conflict - - let raise_conflict_ (cc : t) ~th (e : Lit.t list) (p : Proof_term.step_id) : _ - = - Profile.instant "cc.conflict"; - (* clear tasks queue *) - Vec.clear cc.pending; - Vec.clear cc.combine; - Event.emit cc.on_conflict { cc; th; c = e }; - Stat.incr cc.count_conflict; - raise (E_confl (Conflict (e, p))) - - let[@inline] all_classes self : repr Iter.t = - T_tbl.values self.tbl |> Iter.filter E_node.is_root - - (* find the closest common ancestor of [a] and [b] in the proof forest. - - Precond: - - [a] and [b] are in the same class - - no e_node has the flag [field_marked_explain] on - Invariants: - - if [n] is marked, then all the predecessors of [n] - from [a] or [b] are marked too. - *) - let find_common_ancestor self (a : e_node) (b : e_node) : e_node = - (* catch up to the other e_node *) - let rec find1 a = - if E_node.get_field self.field_marked_explain a then - a - else ( - match a.n_expl with - | FL_none -> assert false - | FL_some r -> find1 r.next - ) - in - let rec find2 a b = - if E_node.equal a b then - a - else if E_node.get_field self.field_marked_explain a then - a - else if E_node.get_field self.field_marked_explain b then - b - else ( - E_node.set_field self.field_marked_explain true a; - E_node.set_field self.field_marked_explain true b; - match a.n_expl, b.n_expl with - | FL_some r1, FL_some r2 -> find2 r1.next r2.next - | FL_some r, FL_none -> find1 r.next - | FL_none, FL_some r -> find1 r.next - | FL_none, FL_none -> assert false - (* no common ancestor *) - ) - in - - (* cleanup tags on nodes traversed in [find2] *) - let rec cleanup_ n = - if E_node.get_field self.field_marked_explain n then ( - E_node.set_field self.field_marked_explain false n; - match n.n_expl with - | FL_none -> () - | FL_some { next; _ } -> cleanup_ next - ) - in - let n = find2 a b in - cleanup_ a; - cleanup_ b; - n - - module Expl_state = struct - type t = { - mutable lits: Lit.t list; - mutable th_lemmas: - (Lit.t * (Lit.t * Lit.t list) list * Proof_term.step_id) list; - } - - let create () : t = { lits = []; th_lemmas = [] } - let[@inline] copy self : t = { self with lits = self.lits } - let[@inline] add_lit (self : t) lit = self.lits <- lit :: self.lits - - let[@inline] add_th (self : t) lit hyps pr : unit = - self.th_lemmas <- (lit, hyps, pr) :: self.th_lemmas - - let merge self other = - let { lits = o_lits; th_lemmas = o_lemmas } = other in - self.lits <- List.rev_append o_lits self.lits; - self.th_lemmas <- List.rev_append o_lemmas self.th_lemmas; - () - - (* proof of [\/_i ¬lits[i]] *) - let proof_of_th_lemmas (self : t) (proof : Proof_trace.t) : - Proof_term.step_id = - let p_lits1 = Iter.of_list self.lits |> Iter.map Lit.neg in - let p_lits2 = - Iter.of_list self.th_lemmas - |> Iter.map (fun (lit_t_u, _, _) -> Lit.neg lit_t_u) - in - let p_cc = - P.add_step proof @@ Rules_.lemma_cc (Iter.append p_lits1 p_lits2) - in - let resolve_with_th_proof pr (lit_t_u, sub_proofs, pr_th) = - (* pr_th: [sub_proofs |- t=u]. - now resolve away [sub_proofs] to get literals that were - asserted in the congruence closure *) - let pr_th = - List.fold_left - (fun pr_th (lit_i, hyps_i) -> - (* [hyps_i |- lit_i] *) - let lemma_i = - P.add_step proof - @@ Rules_.lemma_cc - Iter.(cons lit_i (of_list hyps_i |> map Lit.neg)) - in - (* resolve [lit_i] away. *) - P.add_step proof - @@ Rules_.proof_res ~pivot:(Lit.term lit_i) lemma_i pr_th) - pr_th sub_proofs - in - P.add_step proof @@ Rules_.proof_res ~pivot:(Lit.term lit_t_u) pr_th pr - in - (* resolve with theory proofs responsible for some merges, if any. *) - List.fold_left resolve_with_th_proof p_cc self.th_lemmas - - let to_resolved_expl (self : t) : Resolved_expl.t = - (* FIXME: package the th lemmas too *) - let { lits; th_lemmas = _ } = self in - let s2 = copy self in - let pr proof = proof_of_th_lemmas s2 proof in - { Resolved_expl.lits; pr } - end - - (* decompose explanation [e] into a list of literals added to [acc] *) - let rec explain_decompose_expl self (st : Expl_state.t) (e : explanation) : - unit = - Log.debugf 5 (fun k -> k "(@[cc.decompose_expl@ %a@])" Expl.pp e); - match e with - | E_trivial -> () - | E_congruence (n1, n2) -> - (match n1.n_sig0, n2.n_sig0 with - | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> - assert (Const.equal f1 f2); - assert (List.length a1 = List.length a2); - List.iter2 (explain_equal_rec_ self st) a1 a2 - | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> - explain_equal_rec_ self st f1 f2; - explain_equal_rec_ self st a1 a2 - | Some (If (a1, b1, c1)), Some (If (a2, b2, c2)) -> - explain_equal_rec_ self st a1 a2; - explain_equal_rec_ self st b1 b2; - explain_equal_rec_ self st c1 c2 - | _ -> assert false) - | E_lit lit -> Expl_state.add_lit st lit - | E_theory (t, u, expl_sets, pr) -> - let sub_proofs = - List.map - (fun (t_i, u_i, expls_i) -> - let lit_i = Lit.make_eq self.tst t_i u_i in - (* use a separate call to [explain_expls] for each set *) - let sub = explain_expls self expls_i in - Expl_state.merge st sub; - lit_i, sub.lits) - expl_sets - in - let lit_t_u = Lit.make_eq self.tst t u in - Expl_state.add_th st lit_t_u sub_proofs pr - | E_merge (a, b) -> explain_equal_rec_ self st a b - | E_merge_t (a, b) -> - (* find nodes for [a] and [b] on the fly *) - (match T_tbl.find self.tbl a, T_tbl.find self.tbl b with - | a, b -> explain_equal_rec_ self st a b - | exception Not_found -> - Error.errorf "expl: cannot find e_node(s) for %a, %a" Term.pp_debug a - Term.pp_debug b) - | E_and (a, b) -> - explain_decompose_expl self st a; - explain_decompose_expl self st b - - and explain_expls self (es : explanation list) : Expl_state.t = - let st = Expl_state.create () in - List.iter (explain_decompose_expl self st) es; - st - - and explain_equal_rec_ (cc : t) (st : Expl_state.t) (a : e_node) (b : e_node) - : unit = - Log.debugf 5 (fun k -> - k "(@[cc.explain_loop.at@ %a@ =?= %a@])" E_node.pp a E_node.pp b); - assert (E_node.equal (find_ a) (find_ b)); - let ancestor = find_common_ancestor cc a b in - explain_along_path cc st a ancestor; - explain_along_path cc st b ancestor - - (* explain why [a = parent_a], where [a -> ... -> target] in the - proof forest *) - and explain_along_path self (st : Expl_state.t) (a : e_node) (target : e_node) - : unit = - let rec aux n = - if n == target then - () - else ( - match n.n_expl with - | FL_none -> assert false - | FL_some { next = next_n; expl } -> - explain_decompose_expl self st expl; - (* now prove [next_n = target] *) - aux next_n - ) - in - aux a - - (* add a term *) - let[@inline] rec add_term_rec_ self t : e_node = - match T_tbl.find self.tbl t with - | n -> n - | exception Not_found -> add_new_term_ self t - - (* add [t] when not present already *) - and add_new_term_ self (t : Term.t) : e_node = - assert (not @@ mem self t); - Log.debugf 15 (fun k -> k "(@[cc.add-term@ %a@])" Term.pp_debug t); - let n = E_node.make t in - (* register sub-terms, add [t] to their parent list, and return the - corresponding initial signature *) - let sig0 = compute_sig0 self n in - n.n_sig0 <- sig0; - (* remove term when we backtrack *) - on_backtrack self (fun () -> - Log.debugf 30 (fun k -> k "(@[cc.remove-term@ %a@])" Term.pp_debug t); - T_tbl.remove self.tbl t); - (* add term to the table *) - T_tbl.add self.tbl t n; - if Option.is_some sig0 then - (* [n] might be merged with other equiv classes *) - push_pending self n; - Event.emit_iter self.on_new_term (self, n, t) ~f:(push_action_l self); - n - - (* compute the initial signature of the given e_node *) - and compute_sig0 (self : t) (n : e_node) : Signature.t option = - (* add sub-term to [cc], and register [n] to its parents. - Note that we return the exact sub-term, to get proper - explanations, but we add to the sub-term's root's parent list. *) - let deref_sub (u : Term.t) : e_node = - let sub = add_term_rec_ self u in - (* add [n] to [sub.root]'s parent list *) - (let sub_r = find_ sub in - let old_parents = sub_r.n_parents in - if Bag.is_empty old_parents then - (* first time it has parents: tell watchers that this is a subterm *) - Event.emit_iter self.on_is_subterm (self, sub, u) - ~f:(push_action_l self); - on_backtrack self (fun () -> sub_r.n_parents <- old_parents); - sub_r.n_parents <- Bag.cons n sub_r.n_parents); - sub - in - let[@inline] return x = Some x in - match A.view_as_cc n.n_term with - | Bool _ | Opaque _ -> None - | Eq (a, b) -> - let a = deref_sub a in - let b = deref_sub b in - return @@ Eq (a, b) - | Not u -> return @@ Not (deref_sub u) - | App_fun (f, args) -> - let args = args |> Iter.map deref_sub |> Iter.to_list in - if args <> [] then - return @@ App_fun (f, args) - else - None - | App_ho (f, a) -> - let f = deref_sub f in - let a = deref_sub a in - return @@ App_ho (f, a) - | If (a, b, c) -> return @@ If (deref_sub a, deref_sub b, deref_sub c) - - let[@inline] add_term self t : e_node = add_term_rec_ self t - let mem_term = mem - - let set_as_lit self (n : e_node) (lit : Lit.t) : unit = - match n.n_as_lit with - | Some _ -> () - | None -> - Log.debugf 15 (fun k -> - k "(@[cc.set-as-lit@ %a@ %a@])" E_node.pp n Lit.pp lit); - on_backtrack self (fun () -> n.n_as_lit <- None); - n.n_as_lit <- Some lit - - (* is [n] true or false? *) - let n_is_bool_value (self : t) n : bool = - E_node.equal n (n_true self) || E_node.equal n (n_false self) - - (* gather a pair [lits, pr], where [lits] is the set of - asserted literals needed in the explanation (which is useful for - the SAT solver), and [pr] is a proof, including sub-proofs for theory - merges. *) - let lits_and_proof_of_expl (self : t) (st : Expl_state.t) : - Lit.t list * Proof_term.step_id = - let { Expl_state.lits; th_lemmas = _ } = st in - let pr = Expl_state.proof_of_th_lemmas st self.proof in - lits, pr - - (* main CC algo: add terms from [pending] to the signature table, - check for collisions *) - let rec update_tasks (self : t) : unit = - while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do - while not @@ Vec.is_empty self.pending do - task_pending_ self (Vec.pop_exn self.pending) - done; - while not @@ Vec.is_empty self.combine do - task_combine_ self (Vec.pop_exn self.combine) - done - done - - and task_pending_ self (n : e_node) : unit = - (* check if some parent collided *) - match n.n_sig0 with - | None -> () (* no-op *) - | Some (Eq (a, b)) -> - (* if [a=b] is now true, merge [(a=b)] and [true] *) - if same_class a b then ( - let expl = Expl.mk_merge a b in - Log.debugf 5 (fun k -> - k "(@[cc.pending.eq@ %a@ :r1 %a@ :r2 %a@])" E_node.pp n E_node.pp a - E_node.pp b); - merge_classes self n (n_true self) expl - ) - | Some (Not u) -> - (* [u = bool ==> not u = not bool] *) - let r_u = find_ u in - if E_node.equal r_u (n_true self) then ( - let expl = Expl.mk_merge u (n_true self) in - merge_classes self n (n_false self) expl - ) else if E_node.equal r_u (n_false self) then ( - let expl = Expl.mk_merge u (n_false self) in - merge_classes self n (n_true self) expl - ) - | Some s0 -> - (* update the signature by using [find] on each sub-e_node *) - let s = update_sig s0 in - (match find_signature self s with - | None -> - (* add to the signature table [sig(n) --> n] *) - add_signature self s n - | Some u when E_node.equal n u -> () - | Some u -> - (* [t1] and [t2] must be applications of the same symbol to - arguments that are pairwise equal *) - assert (n != u); - let expl = Expl.mk_congruence n u in - merge_classes self n u expl) - - and task_combine_ self = function - | CT_merge (a, b, e_ab) -> task_merge_ self a b e_ab - | CT_act (Handler_action.Act_merge (t, u, e)) -> task_merge_ self t u e - | CT_act (Handler_action.Act_propagate (lit, reason)) -> - (* will return this propagation to the caller *) - Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }) - - (* main CC algo: merge equivalence classes in [st.combine]. - @raise Exn_unsat if merge fails *) - and task_merge_ self a b e_ab : unit = - let ra = find_ a in - let rb = find_ b in - if not @@ E_node.equal ra rb then ( - assert (E_node.is_root ra); - assert (E_node.is_root rb); - Stat.incr self.count_merge; - (* check we're not merging [true] and [false] *) - if - (E_node.equal ra (n_true self) && E_node.equal rb (n_false self)) - || (E_node.equal rb (n_true self) && E_node.equal ra (n_false self)) - then ( - Log.debugf 5 (fun k -> - k - "(@[cc.merge.true_false_conflict@ @[:r1 %a@ :t1 %a@]@ @[:r2 \ - %a@ :t2 %a@]@ :e_ab %a@])" - E_node.pp ra E_node.pp a E_node.pp rb E_node.pp b Expl.pp e_ab); - let th = ref false in - (* TODO: - C1: P.true_neq_false - C2: lemma [lits |- true=false] (and resolve on theory proofs) - C3: r1 C1 C2 - *) - let expl_st = Expl_state.create () in - explain_decompose_expl self expl_st e_ab; - explain_equal_rec_ self expl_st a ra; - explain_equal_rec_ self expl_st b rb; - - (* regular conflict *) - let lits, pr = lits_and_proof_of_expl self expl_st in - raise_conflict_ self ~th:!th (List.rev_map Lit.neg lits) pr - ); - (* We will merge [r_from] into [r_into]. - we try to ensure that [size ra <= size rb] in general, but always - keep values as representative *) - let r_from, r_into = - if n_is_bool_value self ra then - rb, ra - else if n_is_bool_value self rb then - ra, rb - else if size_ ra > size_ rb then - rb, ra - else - ra, rb - in - (* when merging terms with [true] or [false], possibly propagate them to SAT *) - let merge_bool r1 t1 r2 t2 = - if E_node.equal r1 (n_true self) then - propagate_bools self r2 t2 r1 t1 e_ab true - else if E_node.equal r1 (n_false self) then - propagate_bools self r2 t2 r1 t1 e_ab false - in - - merge_bool ra a rb b; - merge_bool rb b ra a; - - (* perform [union r_from r_into] *) - Log.debugf 15 (fun k -> - k "(@[cc.merge@ :from %a@ :into %a@])" E_node.pp r_from E_node.pp - r_into); - - (* call [on_pre_merge] functions, and merge theory data items *) - (* explanation is [a=ra & e_ab & b=rb] *) - (let expl = - Expl.mk_list [ e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb ] - in - - let handle_act = function - | Ok l -> push_action_l self l - | Error (Handler_action.Conflict expl) -> - raise_conflict_from_expl self expl - in - - Event.emit_iter self.on_pre_merge - (self, r_into, r_from, expl) - ~f:handle_act; - Event.emit_iter self.on_pre_merge2 - (self, r_into, r_from, expl) - ~f:handle_act); - - (* TODO: merge plugin data here, _after_ the pre-merge hooks are called, - so they have a chance of observing pre-merge plugin data *) - ((* parents might have a different signature, check for collisions *) - E_node.iter_parents r_from (fun parent -> push_pending self parent); - (* for each e_node in [r_from]'s class, make it point to [r_into] *) - E_node.iter_class r_from (fun u -> - assert (u.n_root == r_from); - u.n_root <- r_into); - (* capture current state *) - let r_into_old_next = r_into.n_next in - let r_from_old_next = r_from.n_next in - let r_into_old_parents = r_into.n_parents in - let r_into_old_bits = r_into.n_bits in - (* swap [into.next] and [from.next], merging the classes *) - r_into.n_next <- r_from_old_next; - r_from.n_next <- r_into_old_next; - r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; - r_into.n_size <- r_into.n_size + r_from.n_size; - r_into.n_bits <- Bits.merge r_into.n_bits r_from.n_bits; - (* on backtrack, unmerge classes and restore the pointers to [r_from] *) - on_backtrack self (fun () -> - Log.debugf 30 (fun k -> - k "(@[cc.undo_merge@ :from %a@ :into %a@])" E_node.pp r_from - E_node.pp r_into); - r_into.n_bits <- r_into_old_bits; - r_into.n_next <- r_into_old_next; - r_from.n_next <- r_from_old_next; - r_into.n_parents <- r_into_old_parents; - (* NOTE: this must come after the restoration of [next] pointers, - otherwise we'd iterate on too big a class *) - E_node.iter_class_ r_from (fun u -> u.n_root <- r_from); - r_into.n_size <- r_into.n_size - r_from.n_size)); - - (* update explanations (a -> b), arbitrarily. - Note that here we merge the classes by adding a bridge between [a] - and [b], not their roots. *) - reroot_expl self a; - assert (a.n_expl = FL_none); - (* on backtracking, link may be inverted, but we delete the one - that bridges between [a] and [b] *) - on_backtrack self (fun () -> - match a.n_expl, b.n_expl with - | FL_some e, _ when E_node.equal e.next b -> a.n_expl <- FL_none - | _, FL_some e when E_node.equal e.next a -> b.n_expl <- FL_none - | _ -> assert false); - a.n_expl <- FL_some { next = b; expl = e_ab }; - (* call [on_post_merge] *) - Event.emit_iter self.on_post_merge (self, r_into, r_from) - ~f:(push_action_l self) - ) - - (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] - in the equiv class of [r1] that is a known literal back to the SAT solver - and which is not the one initially merged. - We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) - and propagate_bools self r1 t1 r2 t2 (e_12 : explanation) sign : unit = - (* explanation for [t1 =e= t2 = r2] *) - let half_expl_and_pr = - lazy - (let st = Expl_state.create () in - explain_decompose_expl self st e_12; - explain_equal_rec_ self st r2 t2; - st) - in - (* TODO: flag per class, `or`-ed on merge, to indicate if the class - contains at least one lit *) - E_node.iter_class r1 (fun u1 -> - (* propagate if: - - [u1] is a proper literal - - [t2 != r2], because that can only happen - after an explicit merge (no way to obtain that by propagation) - *) - match E_node.as_lit u1 with - | Some lit when not (E_node.equal r2 t2) -> - let lit = - if sign then - lit - else - Lit.neg lit - in - (* apply sign *) - Log.debugf 5 (fun k -> k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); - (* complete explanation with the [u1=t1] chunk *) - let (lazy st) = half_expl_and_pr in - let st = Expl_state.copy st in - (* do not modify shared st *) - explain_equal_rec_ self st u1 t1; - - (* propagate only if this doesn't depend on some semantic values *) - let reason () = - (* true literals explaining why t1=t2 *) - let guard = st.lits in - (* get a proof of [guard /\ ¬lit] being absurd, to propagate [lit] *) - Expl_state.add_lit st (Lit.neg lit); - let _, pr = lits_and_proof_of_expl self st in - guard, pr - in - Vec.push self.res_acts (Result_action.Act_propagate { lit; reason }); - Event.emit_iter self.on_propagate (self, lit, reason) - ~f:(push_action_l self); - Stat.incr self.count_props - | _ -> ()) - - (* raise a conflict from an explanation, typically from an event handler. - Raises E_confl with a result conflict. *) - and raise_conflict_from_expl self (expl : Expl.t) : 'a = - Log.debugf 5 (fun k -> - k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); - let st = Expl_state.create () in - explain_decompose_expl self st expl; - let lits, pr = lits_and_proof_of_expl self st in - let c = List.rev_map Lit.neg lits in - let th = st.th_lemmas <> [] in - raise_conflict_ self ~th c pr - - let add_iter self it : unit = it (fun t -> ignore @@ add_term_rec_ self t) - - let push_level (self : t) : unit = - assert (not self.in_loop); - Backtrack_stack.push_level self.undo - - let pop_levels (self : t) n : unit = - assert (not self.in_loop); - Vec.clear self.pending; - Vec.clear self.combine; - Log.debugf 15 (fun k -> - k "(@[cc.pop-levels %d@ :n-lvls %d@])" n - (Backtrack_stack.n_levels self.undo)); - Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f ()); - () - - let assert_eq self t u expl : unit = - assert (not self.in_loop); - let t = add_term self t in - let u = add_term self u in - (* merge [a] and [b] *) - merge_classes self t u expl - - (* assert that this boolean literal holds. - if a lit is [= a b], merge [a] and [b]; - otherwise merge the atom with true/false *) - let assert_lit self lit : unit = - assert (not self.in_loop); - let t = Lit.term lit in - Log.debugf 15 (fun k -> k "(@[cc.assert-lit@ %a@])" Lit.pp lit); - let sign = Lit.sign lit in - match A.view_as_cc t with - | Eq (a, b) when sign -> assert_eq self a b (Expl.mk_lit lit) - | _ -> - (* equate t and true/false *) - let rhs = n_bool self sign in - let n = add_term self t in - (* TODO: ensure that this is O(1). - basically, just have [n] point to true/false and thus acquire - the corresponding value, so its superterms (like [ite]) can evaluate - properly *) - (* TODO: use oriented merge (force direction [n -> rhs]) *) - merge_classes self n rhs (Expl.mk_lit lit) - - let[@inline] assert_lits self lits : unit = - assert (not self.in_loop); - Iter.iter (assert_lit self) lits - - let merge self n1 n2 expl = - assert (not self.in_loop); - Log.debugf 5 (fun k -> - k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" E_node.pp n1 - E_node.pp n2 Expl.pp expl); - assert (Term.equal (Term.ty n1.n_term) (Term.ty n2.n_term)); - merge_classes self n1 n2 expl - - let merge_t self t1 t2 expl = - merge self (add_term self t1) (add_term self t2) expl - - let explain_eq self n1 n2 : Resolved_expl.t = - let st = Expl_state.create () in - explain_equal_rec_ self st n1 n2; - (* FIXME: also need to return the proof? *) - Expl_state.to_resolved_expl st - - let explain_expl (self : t) expl : Resolved_expl.t = - let expl_st = Expl_state.create () in - explain_decompose_expl self expl_st expl; - Expl_state.to_resolved_expl expl_st - - let[@inline] on_pre_merge self = Event.of_emitter self.on_pre_merge - let[@inline] on_pre_merge2 self = Event.of_emitter self.on_pre_merge2 - let[@inline] on_post_merge self = Event.of_emitter self.on_post_merge - let[@inline] on_new_term self = Event.of_emitter self.on_new_term - let[@inline] on_conflict self = Event.of_emitter self.on_conflict - let[@inline] on_propagate self = Event.of_emitter self.on_propagate - let[@inline] on_is_subterm self = Event.of_emitter self.on_is_subterm - - let create ?(stat = Stat.global) ?(size = `Big) (tst : Term.store) - (proof : Proof_trace.t) : t = - let size = - match size with - | `Small -> 128 - | `Big -> 2048 - in - let bitgen = Bits.mk_gen () in - let field_marked_explain = Bits.mk_field bitgen in - let rec cc = - { - tst; - proof; - tbl = T_tbl.create size; - signatures_tbl = Sig_tbl.create size; - bitgen; - on_pre_merge = Event.Emitter.create (); - on_pre_merge2 = Event.Emitter.create (); - on_post_merge = Event.Emitter.create (); - on_new_term = Event.Emitter.create (); - on_conflict = Event.Emitter.create (); - on_propagate = Event.Emitter.create (); - on_is_subterm = Event.Emitter.create (); - pending = Vec.create (); - combine = Vec.create (); - undo = Backtrack_stack.create (); - true_; - false_; - in_loop = false; - res_acts = Vec.create (); - field_marked_explain; - count_conflict = Stat.mk_int stat "cc.conflicts"; - count_props = Stat.mk_int stat "cc.propagations"; - count_merge = Stat.mk_int stat "cc.merges"; - } - and true_ = lazy (add_term cc (Term.true_ tst)) - and false_ = lazy (add_term cc (Term.false_ tst)) in - ignore (Lazy.force true_ : e_node); - ignore (Lazy.force false_ : e_node); - cc - - let[@inline] find_t self t : repr = - let n = T_tbl.find self.tbl t in - find_ n - - let pop_acts_ self = - let rec loop acc = - match Vec.pop self.res_acts with - | None -> acc - | Some x -> loop (x :: acc) - in - loop [] - - let check self : Result_action.or_conflict = - Log.debug 5 "(cc.check)"; - self.in_loop <- true; - let@ () = Stdlib.Fun.protect ~finally:(fun () -> self.in_loop <- false) in - try - update_tasks self; - let l = pop_acts_ self in - Ok l - with E_confl c -> Error c - - let check_inv_enabled_ = true (* XXX NUDGE *) - - (* check some internal invariants *) - let check_inv_ (self : t) : unit = - if check_inv_enabled_ then ( - Log.debug 2 "(cc.check-invariants)"; - all_classes self - |> Iter.flat_map E_node.iter_class - |> Iter.iter (fun n -> - match n.n_sig0 with - | None -> () - | Some s -> - let s' = update_sig s in - let ok = - match find_signature self s' with - | None -> false - | Some r -> E_node.equal r n.n_root - in - if not ok then - Log.debugf 0 (fun k -> - k "(@[cc.check.fail@ :n %a@ :sig %a@ :actual-sig %a@])" - E_node.pp n Signature.pp s Signature.pp s')) - ) - - (* model: return all the classes *) - let get_model (self : t) : repr Iter.t Iter.t = - check_inv_ self; - all_classes self |> Iter.map E_node.iter_class -end diff --git a/src/cc/dune b/src/cc/dune index d249010d..5994f7ba 100644 --- a/src/cc/dune +++ b/src/cc/dune @@ -2,6 +2,6 @@ (name Sidekick_cc) (public_name sidekick.cc) (synopsis "main congruence closure implementation") - (private_modules core_cc) + (private_modules types_ signature) (libraries containers iter sidekick.sigs sidekick.core sidekick.util) (flags :standard -open Sidekick_util)) diff --git a/src/cc/e_node.ml b/src/cc/e_node.ml new file mode 100644 index 00000000..2eb3fb01 --- /dev/null +++ b/src/cc/e_node.ml @@ -0,0 +1,50 @@ +open Types_ + +type t = e_node + +let[@inline] equal (n1 : t) n2 = n1 == n2 +let[@inline] hash n = Term.hash n.n_term +let[@inline] term n = n.n_term +let[@inline] pp out n = Term.pp_debug out n.n_term +let[@inline] as_lit n = n.n_as_lit + +let make (t : Term.t) : t = + let rec n = + { + n_term = t; + n_sig0 = None; + n_bits = Bits.empty; + n_parents = Bag.empty; + n_as_lit = None; + (* TODO: provide a method to do it *) + n_root = n; + n_expl = FL_none; + n_next = n; + n_size = 1; + } + in + n + +let[@inline] is_root (n : e_node) : bool = n.n_root == n + +(* traverse the equivalence class of [n] *) +let iter_class_ (n : e_node) : e_node Iter.t = + fun yield -> + let rec aux u = + yield u; + if u.n_next != n then aux u.n_next + in + aux n + +let[@inline] iter_class n = + assert (is_root n); + iter_class_ n + +let[@inline] iter_parents (n : e_node) : e_node Iter.t = + assert (is_root n); + Bag.to_iter n.n_parents + +module Internal_ = struct + let iter_class_ = iter_class_ + let make = make +end diff --git a/src/cc/e_node.mli b/src/cc/e_node.mli new file mode 100644 index 00000000..6486c74d --- /dev/null +++ b/src/cc/e_node.mli @@ -0,0 +1,61 @@ +(** E-node. + + An e-node is a node in the congruence closure that is contained + in some equivalence classe). + An equivalence class is a set of terms that are currently equal + in the partial model built by the solver. + The class is represented by a collection of nodes, one of which is + distinguished and is called the "representative". + + All information pertaining to the whole equivalence class is stored + in its representative's {!E_node.t}. + + When two classes become equal (are "merged"), one of the two + representatives is picked as the representative of the new class. + The new class contains the union of the two old classes' nodes. + + We also allow theories to store additional information in the + representative. This information can be used when two classes are + merged, to detect conflicts and solve equations à la Shostak. + *) + +open Types_ + +type t = Types_.e_node +(** An E-node. + + A value of type [t] points to a particular Term.t, but see + {!find} to get the representative of the class. *) + +include Sidekick_sigs.PRINT with type t := t + +val term : t -> Term.t +(** Term contained in this equivalence class. + If [is_root n], then [Term.t n] is the class' representative Term.t. *) + +val equal : t -> t -> bool +(** Are two classes {b physically} equal? To check for + logical equality, use [CC.E_node.equal (CC.find cc n1) (CC.find cc n2)] + which checks for equality of representatives. *) + +val hash : t -> int +(** An opaque hash of this E_node.t. *) + +val is_root : t -> bool +(** Is the E_node.t a root (ie the representative of its class)? + See {!find} to get the root. *) + +val iter_class : t -> t Iter.t +(** Traverse the congruence class. + Precondition: [is_root n] (see {!find} below) *) + +val iter_parents : t -> t Iter.t +(** Traverse the parents of the class. + Precondition: [is_root n] (see {!find} below) *) + +val as_lit : t -> Lit.t option + +module Internal_ : sig + val iter_class_ : t -> t Iter.t + val make : Term.t -> t +end diff --git a/src/cc/expl.ml b/src/cc/expl.ml new file mode 100644 index 00000000..534b36a7 --- /dev/null +++ b/src/cc/expl.ml @@ -0,0 +1,50 @@ +open Types_ + +type t = explanation + +let rec pp out (e : explanation) = + match e with + | E_trivial -> Fmt.string out "reduction" + | E_lit lit -> Lit.pp out lit + | E_congruence (n1, n2) -> + Fmt.fprintf out "(@[congruence@ %a@ %a@])" E_node.pp n1 E_node.pp n2 + | E_merge (a, b) -> + Fmt.fprintf out "(@[merge@ %a@ %a@])" E_node.pp a E_node.pp b + | E_merge_t (a, b) -> + Fmt.fprintf out "(@[merge@ @[:n1 %a@]@ @[:n2 %a@]@])" Term.pp_debug a + Term.pp_debug b + | E_theory (t, u, es, _) -> + Fmt.fprintf out "(@[th@ :t `%a`@ :u `%a`@ :expl_sets %a@])" Term.pp_debug t + Term.pp_debug u + (Util.pp_list + @@ Fmt.Dump.triple Term.pp_debug Term.pp_debug (Fmt.Dump.list pp)) + es + | E_and (a, b) -> Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b + +let mk_trivial : t = E_trivial +let[@inline] mk_congruence n1 n2 : t = E_congruence (n1, n2) + +let[@inline] mk_merge a b : t = + if E_node.equal a b then + mk_trivial + else + E_merge (a, b) + +let[@inline] mk_merge_t a b : t = + if Term.equal a b then + mk_trivial + else + E_merge_t (a, b) + +let[@inline] mk_lit l : t = E_lit l +let[@inline] mk_theory t u es pr = E_theory (t, u, es, pr) + +let rec mk_list l = + match l with + | [] -> mk_trivial + | [ x ] -> x + | E_trivial :: tl -> mk_list tl + | x :: y -> + (match mk_list y with + | E_trivial -> x + | y' -> E_and (x, y')) diff --git a/src/cc/expl.mli b/src/cc/expl.mli new file mode 100644 index 00000000..24618076 --- /dev/null +++ b/src/cc/expl.mli @@ -0,0 +1,47 @@ +(** Explanations + + Explanations are specialized proofs, created by the congruence closure + when asked to justify why two terms are equal. *) + +open Types_ + +type t = Types_.explanation + +include Sidekick_sigs.PRINT with type t := t + +val mk_merge : E_node.t -> E_node.t -> t +(** Explanation: the nodes were explicitly merged *) + +val mk_merge_t : Term.t -> Term.t -> t +(** Explanation: the terms were explicitly merged *) + +val mk_lit : Lit.t -> t +(** Explanation: we merged [t] and [u] because of literal [t=u], + or we merged [t] and [true] because of literal [t], + or [t] and [false] because of literal [¬t] *) + +val mk_list : t list -> t +(** Conjunction of explanations *) + +val mk_congruence : E_node.t -> E_node.t -> t + +val mk_theory : + Term.t -> Term.t -> (Term.t * Term.t * t list) list -> Proof_term.step_id -> t +(** [mk_theory t u expl_sets pr] builds a theory explanation for + why [|- t=u]. It depends on sub-explanations [expl_sets] which + are tuples [ (t_i, u_i, expls_i) ] where [expls_i] are + explanations that justify [t_i = u_i] in the current congruence closure. + + The proof [pr] is the theory lemma, of the form + [ (t_i = u_i)_i |- t=u ]. + It is resolved against each [expls_i |- t_i=u_i] obtained from + [expl_sets], on pivot [t_i=u_i], to obtain a proof of [Gamma |- t=u] + where [Gamma] is a subset of the literals asserted into the congruence + closure. + + For example for the lemma [a=b] deduced by injectivity + from [Some a=Some b] in the theory of datatypes, + the arguments would be + [a, b, [Some a, Some b, mk_merge_t (Some a)(Some b)], pr] + where [pr] is the injectivity lemma [Some a=Some b |- a=b]. + *) diff --git a/src/cc/plugin/sidekick_cc_plugin.ml b/src/cc/plugin.ml similarity index 90% rename from src/cc/plugin/sidekick_cc_plugin.ml rename to src/cc/plugin.ml index 39327e1e..8c25c203 100644 --- a/src/cc/plugin/sidekick_cc_plugin.ml +++ b/src/cc/plugin.ml @@ -1,16 +1,16 @@ -open Sidekick_core -open Sidekick_cc +open Types_ +open Sigs_plugin module type EXTENDED_PLUGIN_BUILDER = sig include MONOID_PLUGIN_BUILDER - val mem : t -> M.CC.E_node.t -> bool - (** Does the CC E_node.t have a monoid value? *) + val mem : t -> E_node.t -> bool + (** Does the CC.E_node.t have a monoid value? *) - val get : t -> M.CC.E_node.t -> M.t option - (** Get monoid value for this CC E_node.t, if any *) + val get : t -> E_node.t -> M.t option + (** Get monoid value for this CC.E_node.t, if any *) - val iter_all : t -> (M.CC.repr * M.t) Iter.t + val iter_all : t -> (CC.repr * M.t) Iter.t include Sidekick_sigs.BACKTRACKABLE0 with type t := t include Sidekick_sigs.PRINT with type t := t @@ -19,10 +19,7 @@ end module Make (M : MONOID_PLUGIN_ARG) : EXTENDED_PLUGIN_BUILDER with module M = M = struct module M = M - module CC = M.CC - module E_node = CC.E_node module Cls_tbl = Backtrackable_tbl.Make (E_node) - module Expl = CC.Expl module type DYN_PL_FOR_M = DYN_MONOID_PLUGIN with module M = M @@ -40,7 +37,7 @@ module Make (M : MONOID_PLUGIN_ARG) : let values : M.t Cls_tbl.t = Cls_tbl.create ?size () (* bit in CC to filter out quickly classes without value *) - let field_has_value : CC.E_node.bitfield = + let field_has_value : CC.bitfield = CC.allocate_bitfield ~descr:("monoid." ^ M.name ^ ".has-value") cc let push_level () = Cls_tbl.push_level values @@ -91,7 +88,7 @@ module Make (M : MONOID_PLUGIN_ARG) : | Error (CC.Handler_action.Conflict expl) -> Error.errorf "when merging@ @[for node %a@],@ values %a and %a:@ conflict %a" - E_node.pp n_u M.pp m_u M.pp m_u' CC.Expl.pp expl + E_node.pp n_u M.pp m_u M.pp m_u' Expl.pp expl | Ok (m_u_merged, merge_acts) -> acts := List.rev_append merge_acts !acts; Log.debugf 20 (fun k -> @@ -111,7 +108,7 @@ module Make (M : MONOID_PLUGIN_ARG) : let iter_all : _ Iter.t = Cls_tbl.to_iter values let on_pre_merge cc n1 n2 e_n1_n2 : CC.Handler_action.or_conflict = - let exception E of M.CC.Handler_action.conflict in + let exception E of CC.Handler_action.conflict in let acts = ref [] in try (match get n1, get n2 with diff --git a/src/cc/plugin/sidekick_cc_plugin.mli b/src/cc/plugin.mli similarity index 76% rename from src/cc/plugin/sidekick_cc_plugin.mli rename to src/cc/plugin.mli index 413d8408..687e1d26 100644 --- a/src/cc/plugin/sidekick_cc_plugin.mli +++ b/src/cc/plugin.mli @@ -1,17 +1,17 @@ (** Congruence Closure Plugin *) -open Sidekick_cc +open Sigs_plugin module type EXTENDED_PLUGIN_BUILDER = sig include MONOID_PLUGIN_BUILDER - val mem : t -> M.CC.E_node.t -> bool + val mem : t -> E_node.t -> bool (** Does the CC.E_node.t have a monoid value? *) - val get : t -> M.CC.E_node.t -> M.t option + val get : t -> E_node.t -> M.t option (** Get monoid value for this CC.E_node.t, if any *) - val iter_all : t -> (M.CC.repr * M.t) Iter.t + val iter_all : t -> (CC.repr * M.t) Iter.t include Sidekick_sigs.BACKTRACKABLE0 with type t := t include Sidekick_sigs.PRINT with type t := t diff --git a/src/cc/resolved_expl.ml b/src/cc/resolved_expl.ml new file mode 100644 index 00000000..c16c1edd --- /dev/null +++ b/src/cc/resolved_expl.ml @@ -0,0 +1,6 @@ +open Types_ + +type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id } + +let pp out (self : t) = + Fmt.fprintf out "(@[resolved-expl@ %a@])" (Util.pp_list Lit.pp) self.lits diff --git a/src/cc/resolved_expl.mli b/src/cc/resolved_expl.mli new file mode 100644 index 00000000..537a11be --- /dev/null +++ b/src/cc/resolved_expl.mli @@ -0,0 +1,17 @@ +(** Resolved explanations. + + The congruence closure keeps explanations for why terms are in the same + class. However these are represented in a compact, cheap form. + To use these explanations we need to {b resolve} them into a + resolved explanation, typically a list of + literals that are true in the current trail and are responsible for + merges. + + However, we can also have merged classes because they have the same value + in the current model. *) + +open Types_ + +type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id } + +include Sidekick_sigs.PRINT with type t := t diff --git a/src/cc/signature.ml b/src/cc/signature.ml new file mode 100644 index 00000000..8fdc5ee7 --- /dev/null +++ b/src/cc/signature.ml @@ -0,0 +1,53 @@ +(** A signature is a shallow term shape where immediate subterms + are representative *) + +open View +open Types_ + +type t = signature + +let equal (s1 : t) s2 : bool = + let open View in + match s1, s2 with + | Bool b1, Bool b2 -> b1 = b2 + | App_fun (f1, []), App_fun (f2, []) -> Const.equal f1 f2 + | App_fun (f1, l1), App_fun (f2, l2) -> + Const.equal f1 f2 && CCList.equal E_node.equal l1 l2 + | App_ho (f1, a1), App_ho (f2, a2) -> E_node.equal f1 f2 && E_node.equal a1 a2 + | Not a, Not b -> E_node.equal a b + | If (a1, b1, c1), If (a2, b2, c2) -> + E_node.equal a1 a2 && E_node.equal b1 b2 && E_node.equal c1 c2 + | Eq (a1, b1), Eq (a2, b2) -> E_node.equal a1 a2 && E_node.equal b1 b2 + | Opaque u1, Opaque u2 -> E_node.equal u1 u2 + | Bool _, _ + | App_fun _, _ + | App_ho _, _ + | If _, _ + | Eq _, _ + | Opaque _, _ + | Not _, _ -> + false + +let hash (s : t) : int = + let module H = CCHash in + match s with + | Bool b -> H.combine2 10 (H.bool b) + | App_fun (f, l) -> H.combine3 20 (Const.hash f) (H.list E_node.hash l) + | App_ho (f, a) -> H.combine3 30 (E_node.hash f) (E_node.hash a) + | Eq (a, b) -> H.combine3 40 (E_node.hash a) (E_node.hash b) + | Opaque u -> H.combine2 50 (E_node.hash u) + | If (a, b, c) -> + H.combine4 60 (E_node.hash a) (E_node.hash b) (E_node.hash c) + | Not u -> H.combine2 70 (E_node.hash u) + +let pp out = function + | Bool b -> Fmt.bool out b + | App_fun (f, []) -> Const.pp out f + | App_fun (f, l) -> + Fmt.fprintf out "(@[%a@ %a@])" Const.pp f (Util.pp_list E_node.pp) l + | App_ho (f, a) -> Fmt.fprintf out "(@[%a@ %a@])" E_node.pp f E_node.pp a + | Opaque t -> E_node.pp out t + | Not u -> Fmt.fprintf out "(@[not@ %a@])" E_node.pp u + | Eq (a, b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" E_node.pp a E_node.pp b + | If (a, b, c) -> + Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" E_node.pp a E_node.pp b E_node.pp c diff --git a/src/cc/sigs.ml b/src/cc/sigs.ml index 20541c45..fd0fbed3 100644 --- a/src/cc/sigs.ml +++ b/src/cc/sigs.ml @@ -2,505 +2,3 @@ open Sidekick_core module View = View - -(** Arguments to a congruence closure's implementation *) -module type ARG = sig - val view_as_cc : Term.t -> (Const.t, Term.t, Term.t Iter.t) View.t - (** View the Term.t through the lens of the congruence closure *) -end - -(** Collection of input types, and types defined by the congruence closure *) -module type ARGS_CLASSES_EXPL_EVENT = sig - (** E-node. - - An e-node is a node in the congruence closure that is contained - in some equivalence classe). - An equivalence class is a set of terms that are currently equal - in the partial model built by the solver. - The class is represented by a collection of nodes, one of which is - distinguished and is called the "representative". - - All information pertaining to the whole equivalence class is stored - in its representative's {!E_node.t}. - - When two classes become equal (are "merged"), one of the two - representatives is picked as the representative of the new class. - The new class contains the union of the two old classes' nodes. - - We also allow theories to store additional information in the - representative. This information can be used when two classes are - merged, to detect conflicts and solve equations à la Shostak. - *) - module E_node : sig - type t - (** An E-node. - - A value of type [t] points to a particular Term.t, but see - {!find} to get the representative of the class. *) - - include Sidekick_sigs.PRINT with type t := t - - val term : t -> Term.t - (** Term contained in this equivalence class. - If [is_root n], then [Term.t n] is the class' representative Term.t. *) - - val equal : t -> t -> bool - (** Are two classes {b physically} equal? To check for - logical equality, use [CC.E_node.equal (CC.find cc n1) (CC.find cc n2)] - which checks for equality of representatives. *) - - val hash : t -> int - (** An opaque hash of this E_node.t. *) - - val is_root : t -> bool - (** Is the E_node.t a root (ie the representative of its class)? - See {!find} to get the root. *) - - val iter_class : t -> t Iter.t - (** Traverse the congruence class. - Precondition: [is_root n] (see {!find} below) *) - - val iter_parents : t -> t Iter.t - (** Traverse the parents of the class. - Precondition: [is_root n] (see {!find} below) *) - - (* FIXME: - [@@alert refactor "this should be replaced with a Per_class concept"] - *) - - type bitfield - (** A field in the bitfield of this node. This should only be - allocated when a theory is initialized. - - Bitfields are accessed using preallocated keys. - See {!CC_S.allocate_bitfield}. - - All fields are initially 0, are backtracked automatically, - and are merged automatically when classes are merged. *) - end - - (** Explanations - - Explanations are specialized proofs, created by the congruence closure - when asked to justify why two terms are equal. *) - module Expl : sig - type t - - include Sidekick_sigs.PRINT with type t := t - - val mk_merge : E_node.t -> E_node.t -> t - (** Explanation: the nodes were explicitly merged *) - - val mk_merge_t : Term.t -> Term.t -> t - (** Explanation: the terms were explicitly merged *) - - val mk_lit : Lit.t -> t - (** Explanation: we merged [t] and [u] because of literal [t=u], - or we merged [t] and [true] because of literal [t], - or [t] and [false] because of literal [¬t] *) - - val mk_list : t list -> t - (** Conjunction of explanations *) - - val mk_theory : - Term.t -> - Term.t -> - (Term.t * Term.t * t list) list -> - Proof_term.step_id -> - t - (** [mk_theory t u expl_sets pr] builds a theory explanation for - why [|- t=u]. It depends on sub-explanations [expl_sets] which - are tuples [ (t_i, u_i, expls_i) ] where [expls_i] are - explanations that justify [t_i = u_i] in the current congruence closure. - - The proof [pr] is the theory lemma, of the form - [ (t_i = u_i)_i |- t=u ]. - It is resolved against each [expls_i |- t_i=u_i] obtained from - [expl_sets], on pivot [t_i=u_i], to obtain a proof of [Gamma |- t=u] - where [Gamma] is a subset of the literals asserted into the congruence - closure. - - For example for the lemma [a=b] deduced by injectivity - from [Some a=Some b] in the theory of datatypes, - the arguments would be - [a, b, [Some a, Some b, mk_merge_t (Some a)(Some b)], pr] - where [pr] is the injectivity lemma [Some a=Some b |- a=b]. - *) - end - - (** Resolved explanations. - - The congruence closure keeps explanations for why terms are in the same - class. However these are represented in a compact, cheap form. - To use these explanations we need to {b resolve} them into a - resolved explanation, typically a list of - literals that are true in the current trail and are responsible for - merges. - - However, we can also have merged classes because they have the same value - in the current model. *) - module Resolved_expl : sig - type t = { lits: Lit.t list; pr: Proof_trace.t -> Proof_term.step_id } - - include Sidekick_sigs.PRINT with type t := t - end - - (** Per-node data *) - - type e_node = E_node.t - (** A node of the congruence closure *) - - type repr = E_node.t - (** Node that is currently a representative. *) - - type explanation = Expl.t -end - -(** Main congruence closure signature. - - The congruence closure handles the theory QF_UF (uninterpreted - function symbols). - It is also responsible for {i theory combination}, and provides - a general framework for equality reasoning that other - theories piggyback on. - - For example, the theory of datatypes relies on the congruence closure - to do most of the work, and "only" adds injectivity/disjointness/acyclicity - lemmas when needed. - - Similarly, a theory of arrays would hook into the congruence closure and - assert (dis)equalities as needed. -*) -module type S = sig - include ARGS_CLASSES_EXPL_EVENT - - type t - (** The congruence closure object. - It contains a fair amount of state and is mutable - and backtrackable. *) - - (** {3 Accessors} *) - - val term_store : t -> Term.store - val proof : t -> Proof_trace.t - - val find : t -> e_node -> repr - (** Current representative *) - - val add_term : t -> Term.t -> e_node - (** Add the Term.t to the congruence closure, if not present already. - Will be backtracked. *) - - val mem_term : t -> Term.t -> bool - (** Returns [true] if the Term.t is explicitly present in the congruence closure *) - - val allocate_bitfield : t -> descr:string -> E_node.bitfield - (** Allocate a new e_node field (see {!E_node.bitfield}). - - This field descriptor is henceforth reserved for all nodes - in this congruence closure, and can be set using {!set_bitfield} - for each class_ individually. - This can be used to efficiently store some metadata on nodes - (e.g. "is there a numeric value in the class" - or "is there a constructor Term.t in the class"). - - There may be restrictions on how many distinct fields are allocated - for a given congruence closure (e.g. at most {!Sys.int_size} fields). - *) - - val get_bitfield : t -> E_node.bitfield -> E_node.t -> bool - (** Access the bit field of the given e_node *) - - val set_bitfield : t -> E_node.bitfield -> bool -> E_node.t -> unit - (** Set the bitfield for the e_node. This will be backtracked. - See {!E_node.bitfield}. *) - - type propagation_reason = unit -> Lit.t list * Proof_term.step_id - - (** Handler Actions - - Actions that can be scheduled by event handlers. *) - module Handler_action : sig - type t = - | Act_merge of E_node.t * E_node.t * Expl.t - | Act_propagate of Lit.t * propagation_reason - - (* TODO: - - an action to modify data associated with a class - *) - - type conflict = Conflict of Expl.t [@@unboxed] - - type or_conflict = (t list, conflict) result - (** Actions or conflict scheduled by an event handler. - - - [Ok acts] is a list of merges and propagations - - [Error confl] is a conflict to resolve. - *) - end - - (** Result Actions. - - - Actions returned by the congruence closure after calling {!check}. *) - module Result_action : sig - type t = - | Act_propagate of { lit: Lit.t; reason: propagation_reason } - (** [propagate (Lit.t, reason)] declares that [reason() => Lit.t] - is a tautology. - - - [reason()] should return a list of literals that are currently true, - as well as a proof. - - [Lit.t] should be a literal of interest (see {!S.set_as_lit}). - - This function might never be called, a congruence closure has the right - to not propagate and only trigger conflicts. *) - - type conflict = - | Conflict of Lit.t list * Proof_term.step_id - (** [raise_conflict (c,pr)] declares that [c] is a tautology of - the theory of congruence. - @param pr the proof of [c] being a tautology *) - - type or_conflict = (t list, conflict) result - end - - (** {3 Events} - - Events triggered by the congruence closure, to which - other plugins can subscribe. *) - - (** Events emitted by the congruence closure when something changes. *) - val on_pre_merge : - t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t - (** [Ev_on_pre_merge acts n1 n2 expl] is emitted right before [n1] - and [n2] are merged with explanation [expl]. *) - - val on_pre_merge2 : - t -> (t * E_node.t * E_node.t * Expl.t, Handler_action.or_conflict) Event.t - (** Second phase of "on pre merge". This runs after {!on_pre_merge} - and is used by Plugins. {b NOTE}: Plugin state might be observed as already - changed in these handlers. *) - - val on_post_merge : - t -> (t * E_node.t * E_node.t, Handler_action.t list) Event.t - (** [ev_on_post_merge acts n1 n2] is emitted right after [n1] - and [n2] were merged. [find cc n1] and [find cc n2] will return - the same E_node.t. *) - - val on_new_term : t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t - (** [ev_on_new_term n t] is emitted whenever a new Term.t [t] - is added to the congruence closure. Its E_node.t is [n]. *) - - type ev_on_conflict = { cc: t; th: bool; c: Lit.t list } - (** Event emitted when a conflict occurs in the CC. - - [th] is true if the explanation for this conflict involves - at least one "theory" explanation; i.e. some of the equations - participating in the conflict are purely syntactic theories - like injectivity of constructors. *) - - val on_conflict : t -> (ev_on_conflict, unit) Event.t - (** [ev_on_conflict {th; c}] is emitted when the congruence - closure triggers a conflict by asserting the tautology [c]. *) - - val on_propagate : - t -> - ( t * Lit.t * (unit -> Lit.t list * Proof_term.step_id), - Handler_action.t list ) - Event.t - (** [ev_on_propagate Lit.t reason] is emitted whenever [reason() => Lit.t] - is a propagated lemma. See {!CC_ACTIONS.propagate}. *) - - val on_is_subterm : - t -> (t * E_node.t * Term.t, Handler_action.t list) Event.t - (** [ev_on_is_subterm n t] is emitted when [n] is a subterm of - another E_node.t for the first time. [t] is the Term.t corresponding to - the E_node.t [n]. This can be useful for theory combination. *) - - (** {3 Misc} *) - - val n_true : t -> E_node.t - (** Node for [true] *) - - val n_false : t -> E_node.t - (** Node for [false] *) - - val n_bool : t -> bool -> E_node.t - (** Node for either true or false *) - - val set_as_lit : t -> E_node.t -> Lit.t -> unit - (** map the given e_node to a literal. *) - - val find_t : t -> Term.t -> repr - (** Current representative of the Term.t. - @raise E_node.t_found if the Term.t is not already {!add}-ed. *) - - val add_iter : t -> Term.t Iter.t -> unit - (** Add a sequence of terms to the congruence closure *) - - val all_classes : t -> repr Iter.t - (** All current classes. This is costly, only use if there is no other solution *) - - val explain_eq : t -> E_node.t -> E_node.t -> Resolved_expl.t - (** Explain why the two nodes are equal. - Fails if they are not, in an unspecified way. *) - - val explain_expl : t -> Expl.t -> Resolved_expl.t - (** Transform explanation into an actionable conflict clause *) - - (* FIXME: remove - val raise_conflict_from_expl : t -> actions -> Expl.t -> 'a - (** Raise a conflict with the given explanation. - It must be a theory tautology that [expl ==> absurd]. - To be used in theories. - - This fails in an unspecified way if the explanation, once resolved, - satisfies {!Resolved_expl.is_semantic}. *) - *) - - val merge : t -> E_node.t -> E_node.t -> Expl.t -> unit - (** Merge these two nodes given this explanation. - It must be a theory tautology that [expl ==> n1 = n2]. - To be used in theories. *) - - val merge_t : t -> Term.t -> Term.t -> Expl.t -> unit - (** Shortcut for adding + merging *) - - (** {3 Main API *) - - val assert_eq : t -> Term.t -> Term.t -> Expl.t -> unit - (** Assert that two terms are equal, using the given explanation. *) - - val assert_lit : t -> Lit.t -> unit - (** Given a literal, assume it in the congruence closure and propagate - its consequences. Will be backtracked. - - Useful for the theory combination or the SAT solver's functor *) - - val assert_lits : t -> Lit.t Iter.t -> unit - (** Addition of many literals *) - - val check : t -> Result_action.or_conflict - (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. - Will use the {!actions} to propagate literals, declare conflicts, etc. *) - - val push_level : t -> unit - (** Push backtracking level *) - - val pop_levels : t -> int -> unit - (** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *) - - val get_model : t -> E_node.t Iter.t Iter.t - (** get all the equivalence classes so they can be merged in the model *) - - val create : - ?stat:Stat.t -> ?size:[ `Small | `Big ] -> Term.store -> Proof_trace.t -> t - (** Create a new congruence closure. - - @param term_store used to be able to create new terms. All terms - interacting with this congruence closure must belong in this term state - as well. - *) - - (**/**) - - module Debug_ : sig - val pp : t Fmt.printer - (** Print the whole CC *) - end - - (**/**) -end - -(* TODO: full EGG, also have a function to update the value when - the subterms (produced in [of_term]) are updated *) - -(** Data attached to the congruence closure classes. - - This helps theories keeping track of some state for each class. - The state of a class is the monoidal combination of the state for each - Term.t in the class (for example, the set of terms in the - class whose head symbol is a datatype constructor). *) -module type MONOID_PLUGIN_ARG = sig - module CC : S - - type t - (** Some type with a monoid structure *) - - include Sidekick_sigs.PRINT with type t := t - - val name : string - (** name of the monoid structure (short) *) - - (* FIXME: for subs, return list of e_nodes, and assume of_term already - returned data for them. *) - val of_term : - CC.t -> CC.E_node.t -> Term.t -> t option * (CC.E_node.t * t) list - (** [of_term n t], where [t] is the Term.t annotating node [n], - must return [maybe_m, l], where: - - - [maybe_m = Some m] if [t] has monoid value [m]; - otherwise [maybe_m=None] - - [l] is a list of [(u, m_u)] where each [u]'s Term.t - is a direct subterm of [t] - and [m_u] is the monoid value attached to [u]. - - *) - - val merge : - CC.t -> - CC.E_node.t -> - t -> - CC.E_node.t -> - t -> - CC.Expl.t -> - (t * CC.Handler_action.t list, CC.Handler_action.conflict) result - (** Monoidal combination of two values. - - [merge cc n1 mon1 n2 mon2 expl] returns the result of merging - monoid values [mon1] (for class [n1]) and [mon2] (for class [n2]) - when [n1] and [n2] are merged with explanation [expl]. - - @return [Ok mon] if the merge is acceptable, annotating the class of [n1 ∪ n2]; - or [Error expl'] if the merge is unsatisfiable. [expl'] can then be - used to trigger a conflict and undo the merge. - *) -end - -(** Stateful plugin holding a per-equivalence-class monoid. - - Helps keep track of monoid state per equivalence class. - A theory might use one or more instance(s) of this to - aggregate some theory-specific state over all terms, with - the information of what terms are already known to be equal - potentially saving work for the theory. *) -module type DYN_MONOID_PLUGIN = sig - module M : MONOID_PLUGIN_ARG - include Sidekick_sigs.DYN_BACKTRACKABLE - - val pp : unit Fmt.printer - - val mem : M.CC.E_node.t -> bool - (** Does the CC E_node.t have a monoid value? *) - - val get : M.CC.E_node.t -> M.t option - (** Get monoid value for this CC E_node.t, if any *) - - val iter_all : (M.CC.repr * M.t) Iter.t -end - -(** Builder for a plugin. - - The builder takes a congruence closure, and instantiate the - plugin on it. *) -module type MONOID_PLUGIN_BUILDER = sig - module M : MONOID_PLUGIN_ARG - - module type DYN_PL_FOR_M = DYN_MONOID_PLUGIN with module M = M - - type t = (module DYN_PL_FOR_M) - - val create_and_setup : ?size:int -> M.CC.t -> t - (** Create a new monoid state *) -end diff --git a/src/cc/sigs_plugin.ml b/src/cc/sigs_plugin.ml new file mode 100644 index 00000000..a1f0fe18 --- /dev/null +++ b/src/cc/sigs_plugin.ml @@ -0,0 +1,90 @@ +open Types_ + +(* TODO: full EGG, also have a function to update the value when + the subterms (produced in [of_term]) are updated *) + +(** Data attached to the congruence closure classes. + + This helps theories keeping track of some state for each class. + The state of a class is the monoidal combination of the state for each + Term.t in the class (for example, the set of terms in the + class whose head symbol is a datatype constructor). *) +module type MONOID_PLUGIN_ARG = sig + type t + (** Some type with a monoid structure *) + + include Sidekick_sigs.PRINT with type t := t + + val name : string + (** name of the monoid structure (short) *) + + (* FIXME: for subs, return list of e_nodes, and assume of_term already + returned data for them. *) + val of_term : CC.t -> E_node.t -> Term.t -> t option * (E_node.t * t) list + (** [of_term n t], where [t] is the Term.t annotating node [n], + must return [maybe_m, l], where: + + - [maybe_m = Some m] if [t] has monoid value [m]; + otherwise [maybe_m=None] + - [l] is a list of [(u, m_u)] where each [u]'s Term.t + is a direct subterm of [t] + and [m_u] is the monoid value attached to [u]. + + *) + + val merge : + CC.t -> + E_node.t -> + t -> + E_node.t -> + t -> + Expl.t -> + (t * CC.Handler_action.t list, CC.Handler_action.conflict) result + (** Monoidal combination of two values. + + [merge cc n1 mon1 n2 mon2 expl] returns the result of merging + monoid values [mon1] (for class [n1]) and [mon2] (for class [n2]) + when [n1] and [n2] are merged with explanation [expl]. + + @return [Ok mon] if the merge is acceptable, annotating the class of [n1 ∪ n2]; + or [Error expl'] if the merge is unsatisfiable. [expl'] can then be + used to trigger a conflict and undo the merge. + *) +end + +(** Stateful plugin holding a per-equivalence-class monoid. + + Helps keep track of monoid state per equivalence class. + A theory might use one or more instance(s) of this to + aggregate some theory-specific state over all terms, with + the information of what terms are already known to be equal + potentially saving work for the theory. *) +module type DYN_MONOID_PLUGIN = sig + module M : MONOID_PLUGIN_ARG + include Sidekick_sigs.DYN_BACKTRACKABLE + + val pp : unit Fmt.printer + + val mem : E_node.t -> bool + (** Does the CC E_node.t have a monoid value? *) + + val get : E_node.t -> M.t option + (** Get monoid value for this CC E_node.t, if any *) + + val iter_all : (CC.repr * M.t) Iter.t +end + +(** Builder for a plugin. + + The builder takes a congruence closure, and instantiate the + plugin on it. *) +module type MONOID_PLUGIN_BUILDER = sig + module M : MONOID_PLUGIN_ARG + + module type DYN_PL_FOR_M = DYN_MONOID_PLUGIN with module M = M + + type t = (module DYN_PL_FOR_M) + + val create_and_setup : ?size:int -> CC.t -> t + (** Create a new monoid state *) +end diff --git a/src/cc/types_.ml b/src/cc/types_.ml new file mode 100644 index 00000000..c1e1bae1 --- /dev/null +++ b/src/cc/types_.ml @@ -0,0 +1,39 @@ +include Sidekick_core + +type e_node = { + n_term: Term.t; + mutable n_sig0: signature option; (* initial signature *) + mutable n_bits: Bits.t; (* bitfield for various properties *) + mutable n_parents: e_node Bag.t; (* parent terms of this node *) + mutable n_root: e_node; + (* representative of congruence class (itself if a representative) *) + mutable n_next: e_node; (* pointer to next element of congruence class *) + mutable n_size: int; (* size of the class *) + mutable n_as_lit: Lit.t option; + (* TODO: put into payload? and only in root? *) + mutable n_expl: explanation_forest_link; + (* the rooted forest for explanations *) +} +(** A node of the congruence closure. + An equivalence class is represented by its "root" element, + the representative. *) + +and signature = (Const.t, e_node, e_node list) View.t + +and explanation_forest_link = + | FL_none + | FL_some of { next: e_node; expl: explanation } + +(* atomic explanation in the congruence closure *) +and explanation = + | E_trivial (* by pure reduction, tautologically equal *) + | E_lit of Lit.t (* because of this literal *) + | E_merge of e_node * e_node + | E_merge_t of Term.t * Term.t + | E_congruence of e_node * e_node (* caused by normal congruence *) + | E_and of explanation * explanation + | E_theory of + Term.t + * Term.t + * (Term.t * Term.t * explanation list) list + * Proof_term.step_id