From 1d212350efd2c3db6911ce9303657682415612c8 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 18 Aug 2018 14:51:49 -0500 Subject: [PATCH] refactor(cc): internal refactorings --- src/smt/Congruence_closure.ml | 101 ++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 48c499cc..dd327630 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -3,15 +3,17 @@ module Vec = Sidekick_sat.Vec module Log = Sidekick_sat.Log open Solver_types -type node = Equiv_class.t -type repr = Equiv_class.t +module N = Equiv_class + +type node = N.t +type repr = N.t type conflict = Theory.conflict (** A signature is a shallow term shape where immediate subterms are representative *) module Signature = struct type t = node Term.view - include Term_cell.Make_eq(Equiv_class) + include Term_cell.Make_eq(N) end module Sig_tbl = CCHashtbl.Make(Signature) @@ -51,6 +53,7 @@ type t = { have the same signature *) tasks: task Vec.t; (* tasks to perform *) + on_backtrack:(unit->unit)->unit; mutable ps_lits: Lit.Set.t; (* proof state *) ps_queue: (node*node) Vec.t; @@ -64,10 +67,7 @@ type t = { several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) -let[@inline] on_backtrack cc f : unit = - let (module A) = cc.acts in - A.on_backtrack f - +let[@inline] on_backtrack cc f : unit = cc.on_backtrack f let[@inline] is_root_ (n:node) : bool = n.n_root == n let[@inline] size_ (r:repr) = @@ -112,7 +112,7 @@ let[@inline] find st (n:node) : repr = let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc let[@inline] same_class cc (n1:node)(n2:node): bool = - Equiv_class.equal (find cc n1) (find cc n2) + N.equal (find cc n1) (find cc n2) (* compute signature *) let signature cc (t:term): node Term.view option = @@ -120,7 +120,7 @@ let signature cc (t:term): node Term.view option = begin match Term.view t with | App_cst (_, a) when IArray.is_empty a -> None | App_cst (c, _) when not @@ Cst.do_cc c -> None (* no CC *) - | App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return (* FIXME: relevance *) + | App_cst (f, a) -> Some (App_cst (f, IArray.map find a)) (* FIXME: relevance? *) | Bool _ | If _ -> None (* no congruence for these *) end @@ -146,16 +146,16 @@ let add_signature cc (t:term) (r:node): unit = end let push_pending cc t : unit = - if not @@ Equiv_class.get_field Equiv_class.field_is_pending t then ( - Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" Equiv_class.pp t); - Equiv_class.set_field Equiv_class.field_is_pending true t; + if not @@ N.get_field N.field_is_pending t then ( + Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); + N.set_field N.field_is_pending true t; Vec.push cc.tasks (T_pending t) ) let push_combine cc t u e : unit = Log.debugf 5 (fun k->k "(@[cc.push_combine@ :t1 %a@ :t2 %a@ :expl %a@])" - Equiv_class.pp t Equiv_class.pp u Explanation.pp e); + N.pp t N.pp u Explanation.pp e); Vec.push cc.tasks @@ T_merge (t,u,e) (* re-root the explanation tree of the equivalence class of [n] @@ -177,7 +177,7 @@ let raise_conflict (cc:t) (e:conflict): _ = (* clear tasks queue *) Vec.iter (function - | T_pending n -> Equiv_class.set_field Equiv_class.field_is_pending false n + | T_pending n -> N.set_field N.field_is_pending false n | T_merge _ -> ()) cc.tasks; Vec.clear cc.tasks; @@ -259,8 +259,8 @@ let explain_loop (cc : t) : Lit.Set.t = while not (Vec.is_empty cc.ps_queue) do let a, b = Vec.pop_last cc.ps_queue in Log.debugf 5 - (fun k->k "(@[cc.explain_loop at@ %a@ %a@])" Equiv_class.pp a Equiv_class.pp b); - assert (Equiv_class.equal (find cc a) (find cc b)); + (fun k->k "(@[cc.explain_loop at@ %a@ %a@])" N.pp a N.pp b); + assert (N.equal (find cc a) (find cc b)); let c = find_common_ancestor a b in explain_along_path cc a c; explain_along_path cc b c; @@ -296,6 +296,17 @@ let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit = n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags; ) +let relevant_subterms (t:Term.t) : Term.t Sequence.t = + fun yield -> + match t.term_view with + | App_cst (c, a) when Cst.do_cc c -> IArray.iter yield a + | Bool _ | App_cst _ -> () + | If (a,b,c) -> + (* TODO: relevancy? only [a] needs be decided for now *) + yield a; + yield b; + yield c + (* main CC algo: add terms from [pending] to the signature table, check for collisions *) let rec update_tasks (cc:t): unit = @@ -303,13 +314,13 @@ let rec update_tasks (cc:t): unit = might have changed *) while not (Vec.is_empty cc.tasks) do let task = Vec.pop_last cc.tasks in - match task with + match task with | T_pending n -> task_pending_ cc n | T_merge (t,u,expl) -> task_merge_ cc t u expl done and task_pending_ cc n = - Equiv_class.set_field Equiv_class.field_is_pending false n; + N.set_field N.field_is_pending false n; (* check if some parent collided *) begin match find_by_signature cc n.n_term with | None -> @@ -343,7 +354,7 @@ and task_pending_ cc n = and task_merge_ cc a b e_ab : unit = let ra = find cc a in let rb = find cc b in - if not (Equiv_class.equal ra rb) then ( + if not (N.equal ra rb) then ( assert (is_root_ ra); assert (is_root_ rb); (* We will merge [r_from] into [r_into]. @@ -357,8 +368,8 @@ and task_merge_ cc a b e_ab : unit = Log.debugf 5 (fun k->k "(@[cc.merge.distinct_conflict@ :tag %d@ \ @[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])" - _i Equiv_class.pp n1 Explanation.pp e1 - Equiv_class.pp n2 Explanation.pp e2 Explanation.pp e_ab); + _i N.pp n1 Explanation.pp e1 + N.pp n2 Explanation.pp e2 Explanation.pp e_ab); let lits = explain_unfold cc e1 in let lits = explain_unfold ~init:lits cc e2 in let lits = explain_unfold ~init:lits cc e_ab in @@ -373,7 +384,7 @@ and task_merge_ cc a b e_ab : unit = (fun parent -> push_pending cc parent) end; (* perform [union ra rb] *) - Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" Equiv_class.pp r_from Equiv_class.pp r_into); + Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); begin let r_into_old_parents = r_into.n_parents in let r_into_old_tags = r_into.n_tags in @@ -412,7 +423,7 @@ and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = and add_new_term_ cc (t:term) : node = assert (not @@ mem cc t); Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t); - let n = Equiv_class.make t in + let n = N.make t in (* how to add a subterm *) let add_to_parents_of_sub_node (sub:node) : unit = let old_parents = sub.n_parents in @@ -426,15 +437,7 @@ and add_new_term_ cc (t:term) : node = add_to_parents_of_sub_node n_u in (* register sub-terms, add [t] to their parent list *) - begin match t.term_view with - | App_cst (c, a) when Cst.do_cc c -> IArray.iter add_sub_t a - | Bool _ | App_cst _ -> () - | If (a,b,c) -> - (* TODO: relevancy? only [a] needs be decided for now *) - add_sub_t a; - add_sub_t b; - add_sub_t c - end; + relevant_subterms t add_sub_t; (* remove term when we backtrack *) on_backtrack cc (fun () -> @@ -493,7 +496,7 @@ let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit = let l = List.map (fun t -> t, add cc t |> find cc) l in let coll = Sequence.diagonal_l l - |> Sequence.find_pred (fun ((_,n1),(_,n2)) -> Equiv_class.equal n1 n2) + |> Sequence.find_pred (fun ((_,n1),(_,n2)) -> N.equal n1 n2) in begin match coll with | Some ((t1,_n1),(t2,_n2)) -> @@ -508,17 +511,19 @@ let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit = end let create ?(size=2048) ~actions (tst:Term.state) : t = - let nd = Equiv_class.dummy in + let nd = N.dummy in + let (module A : ACTIONS) = actions in let cc = { tst; acts=actions; tbl = Term.Tbl.create size; signatures_tbl = Sig_tbl.create size; - tasks=Vec.make_empty (T_pending Equiv_class.dummy); + tasks=Vec.make_empty (T_pending N.dummy); ps_lits=Lit.Set.empty; + on_backtrack=A.on_backtrack; ps_queue=Vec.make_empty (nd,nd); - true_ = Equiv_class.dummy; - false_ = Equiv_class.dummy; + true_ = N.dummy; + false_ = N.dummy; } in cc.true_ <- add cc (Term.true_ tst); cc.false_ <- add cc (Term.false_ tst); @@ -531,7 +536,7 @@ let final_check cc : unit = (* model: map each uninterpreted equiv class to some ID *) let mk_model (cc:t) (m:Model.t) : Model.t = (* populate [repr -> value] table *) - let t_tbl = Equiv_class.Tbl.create 32 in + let t_tbl = N.Tbl.create 32 in (* type -> default value *) let ty_tbl = Ty.Tbl.create 8 in Term.Tbl.values cc.tbl @@ -552,7 +557,7 @@ let mk_model (cc:t) (m:Model.t) : Model.t = if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then ( Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *) ); - Equiv_class.Tbl.add t_tbl r v + N.Tbl.add t_tbl r v )); (* now map every uninterpreted term to its representative's value, and create function tables *) @@ -568,20 +573,20 @@ let mk_model (cc:t) (m:Model.t) : Model.t = else if Cst.is_undefined c && IArray.length args > 0 then ( (* update signature of [c] *) let ty = Term.ty t in - let v = Equiv_class.Tbl.find t_tbl r in + let v = N.Tbl.find t_tbl r in let args = args - |> IArray.map (fun t -> Equiv_class.Tbl.find t_tbl @@ find_tn cc t) + |> IArray.map (fun t -> N.Tbl.find t_tbl @@ find_tn cc t) |> IArray.to_list in let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in m, Cst.Map.add c (ty, (args,v)::l) funs ) else ( - let v = Equiv_class.Tbl.find t_tbl r in + let v = N.Tbl.find t_tbl r in Model.add t v m, funs ) | _ -> - let v = Equiv_class.Tbl.find t_tbl r in + let v = N.Tbl.find t_tbl r in Model.add t v m, funs) (m,Cst.Map.empty) in @@ -593,7 +598,7 @@ let mk_model (cc:t) (m:Model.t) : Model.t = (* domain element *) Ty.Tbl.get_or_add ty_tbl ~k:ty ~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty) - | Ty_atomic { def = Ty_def d; args; _} -> + | Ty_atomic { def = Ty_def d; args; _} -> (* ask the theory for a default value *) Ty.Tbl.get_or_add ty_tbl ~k:ty ~f:(fun _ty -> @@ -611,12 +616,12 @@ let mk_model (cc:t) (m:Model.t) : Model.t = let pp_full out (cc:t) : unit = let pp_n out n = let pp_next out n = - if n==n.n_root then () else Fmt.fprintf out "@ :next %a" Equiv_class.pp n.n_root in + if n==n.n_root then () else Fmt.fprintf out "@ :next %a" N.pp n.n_root in let pp_root out n = - let u = find cc n in if n==u||n.n_root==u then () else Fmt.fprintf out "@ :root %a" Equiv_class.pp u in + let u = find cc n in if n==u||n.n_root==u then () else Fmt.fprintf out "@ :root %a" N.pp u in Fmt.fprintf out "(@[%a%a%a@])" Term.pp n.n_term pp_next n pp_root n and pp_sig_e out (s,n) = - Fmt.fprintf out "(@[<1>%a@ -> %a@])" Signature.pp s Equiv_class.pp n + Fmt.fprintf out "(@[<1>%a@ -> %a@])" Signature.pp s N.pp n in Fmt.fprintf out "(@[cc.state@ (@[:nodes@ %a@])@ (@[:sig@ %a@])@])" @@ -633,7 +638,7 @@ let check_invariants_ (cc:t) = Term.Tbl.iter (fun t n -> assert (Term.equal t n.n_term); - assert (not @@ Equiv_class.get_field Equiv_class.field_is_pending n); + assert (not @@ N.get_field N.field_is_pending n); relevant_subterms t (fun u -> assert (Term.Tbl.mem cc.tbl u)); (* check proper signature *)