diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 37fbbc46..b9cd62f1 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -7,6 +7,15 @@ type node = N.t type repr = N.t type conflict = Theory.conflict +module T_arg = struct + module Fun = Cst + module Term = struct + include Term + let view = cc_view + end +end +module Mini_cc = Mini_cc.Make(T_arg) + (** A signature is a shallow term shape where immediate subterms are representative *) module Signature = struct @@ -38,7 +47,7 @@ type t = { combine: combine_task Vec.t; undo: (unit -> unit) Backtrack_stack.t; on_merge: (repr -> repr -> explanation -> unit) option; - mutable ps_lits: Lit.Set.t; + mutable ps_lits: Lit.Set.t; (* TODO: thread it around instead? *) (* proof state *) ps_queue: (node*node) Vec.t; (* pairs to explain *) @@ -121,7 +130,7 @@ let pp_full out (cc:t) : unit = (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) (* compute signature *) -let signature cc (t:term): node Term.view option = +let signature cc (t:term): Signature.t option = let find = find_tn cc in begin match Term.view t with | App_cst (_, a) when IArray.is_empty a -> None @@ -138,19 +147,19 @@ let find_by_signature cc (t:term) : repr option = | None -> None | Some s -> Sig_tbl.get cc.signatures_tbl s -let add_signature cc (r:node): unit = - match signature cc r.n_term with +let add_signature cc (n:node): unit = + match signature cc n.n_term with | None -> () | Some s -> (* add, but only if not present already *) begin match Sig_tbl.find cc.signatures_tbl s with | exception Not_found -> Log.debugf 15 - (fun k->k "(@[cc.add_sig@ %a@ <--> %a@])" Signature.pp s N.pp r); + (fun k->k "(@[cc.add_sig@ %a@ <--> %a@])" Signature.pp s N.pp n); on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); - Sig_tbl.add cc.signatures_tbl s r; + Sig_tbl.add cc.signatures_tbl s n; | r' -> - assert (same_class cc r r'); + assert (same_class cc n r'); end let push_pending cc t : unit = @@ -191,6 +200,7 @@ let[@inline] all_classes cc : repr Sequence.t = Term.Tbl.values cc.tbl |> Sequence.filter is_root_ +(* TODO: use markers and lockstep iteration instead *) (* distance from [t] to its root in the proof forest *) let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with | E_none -> 0 @@ -267,6 +277,7 @@ let explain_loop (cc : t) : Lit.Set.t = done; cc.ps_lits +(* TODO: do not use ps_lits anymore *) let explain_eq_n ?(init=Lit.Set.empty) cc (n1:node) (n2:node) : Lit.Set.t = ps_clear cc; cc.ps_lits <- init; @@ -340,6 +351,7 @@ and task_pending_ cc n = | App_cst (f1, a1), App_cst (f2, a2) -> assert (Cst.equal f1 f2); assert (IArray.length a1 = IArray.length a2); + (* TODO: just use "congruence" as explanation *) Explanation.mk_merges @@ IArray.map2 (fun u1 u2 -> add_term_rec_ cc u1, add_term_rec_ cc u2) a1 a2 | If _, _ | App_cst _, _ | Bool _, _ -> assert false @@ -361,7 +373,7 @@ and[@inline] task_combine_ cc acts = function and task_merge_ cc acts a b e_ab : unit = let ra = find cc a in let rb = find cc b in - if not (N.equal ra rb) then ( + if not @@ N.equal ra rb then ( assert (is_root_ ra); assert (is_root_ rb); let lazy e_ab = e_ab in @@ -385,6 +397,7 @@ and task_merge_ cc acts a b e_ab : unit = let lits = explain_eq_n ~init:lits cc b rb in raise_conflict cc acts @@ Lit.Set.elements lits ); + (* TODO: isntead call micro theories, including "distinct" *) (* update set of tags the new node cannot be equal to *) let new_tags = Util.Int_map.union @@ -416,6 +429,8 @@ and task_merge_ cc acts a b e_ab : unit = merge_bool rb b ra a; (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); + (* TODO: only iterate on parents of [rb] *) + (* TODO: [ra.parents <- ra.parent ++ rb.parents] *) begin (* for each node in [r_from]'s class: - make it point to [r_into] @@ -440,6 +455,7 @@ and task_merge_ cc acts a b e_ab : unit = r_from.n_next <- r_from_old_next; r_into.n_tags <- r_into_old_tags); r_into.n_tags <- new_tags; + (* swap [into.next] and [from.next], merging the classes *) r_into.n_next <- r_from_old_next; r_from.n_next <- r_into_old_next; end; @@ -471,7 +487,9 @@ and task_distinct_ cc acts (l:node list) tag expl : unit = begin match coll with | Some ((n1,_r1),(n2,_r2)) -> (* two classes are already equal *) - Log.debugf 5 (fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp n2 Explanation.pp expl); + Log.debugf 5 + (fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp + n2 Explanation.pp expl); let lits = explain_unfold cc expl in raise_conflict cc acts (Lit.Set.to_list lits) | None -> @@ -512,6 +530,7 @@ and add_new_term_ cc (t:term) : node = let n = N.make t in (* how to add a subterm *) let add_to_parents_of_sub_node (sub:node) : unit = + let sub = find cc sub in (* update the repr! *) let old_parents = sub.n_parents in on_backtrack cc (fun () -> sub.n_parents <- old_parents); sub.n_parents <- Bag.cons n sub.n_parents; diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index 73c9892d..22215e37 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -68,3 +68,6 @@ val mk_model : t -> Model.t -> Model.t val check_invariants : t -> unit val pp_full : t Fmt.printer (**/**) + +module T_arg : Mini_cc_intf.ARG with type Fun.t = cst and type Term.t = Term.t +module Mini_cc : module type of Mini_cc.Make(T_arg)