From 73c7db2b4e481511ee796141380b2d595b47447a Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 18 Aug 2018 19:56:22 -0500 Subject: [PATCH] feat(cc): boolean propagation of literals in CC --- src/sat/Internal.ml | 4 ++++ src/sat/Theory_intf.ml | 4 ++++ src/smt/Congruence_closure.ml | 37 ++++++++++++++++++++++++++++++++++ src/smt/Congruence_closure.mli | 2 ++ src/smt/Equiv_class.ml | 1 + src/smt/Equiv_class.mli | 3 +++ src/smt/Theory.ml | 8 ++++++++ src/smt/Theory_combine.ml | 4 ++++ 8 files changed, 63 insertions(+) diff --git a/src/sat/Internal.ml b/src/sat/Internal.ml index 82a519bc..1759e54a 100644 --- a/src/sat/Internal.ml +++ b/src/sat/Internal.ml @@ -863,10 +863,14 @@ module Make (Th : Theory_intf.S) = struct To be called only from [cancel_until] *) let backtrack_down_to (st:t) (lvl:int): unit = Log.debugf 2 (fun k->k "(@[@{sat.backtrack@} now at stack depth %d@])" lvl); + let done_sth = Vec.size st.backtrack > lvl in while Vec.size st.backtrack > lvl do let f = Vec.pop_last st.backtrack in f() done; + if done_sth then ( + Th.post_backtrack (theory st); + ); (* now re-do permanent actions that were backtracked *) while not (Vec.is_empty st.to_redo_after_backtrack) do let f = Vec.pop_last st.to_redo_after_backtrack in diff --git a/src/sat/Theory_intf.ml b/src/sat/Theory_intf.ml index 4616d56a..1155cbbf 100644 --- a/src/sat/Theory_intf.ml +++ b/src/sat/Theory_intf.ml @@ -131,6 +131,10 @@ module type S = sig (** Called at the end of the search in case a model has been found. If no new clause is pushed, then 'sat' is returned, else search is resumed. *) + val post_backtrack : t -> unit + (** After backtracking, this is called (can be used to invalidate + caches, reset task lists, etc.) *) + (**/**) val check_invariants : t -> unit (**/**) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 7447a941..a5a2b709 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -73,6 +73,10 @@ let[@inline] size_ (r:repr) = r.n_size Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t +let[@inline] post_backtrack cc = + Vec.clear cc.pending; + Vec.clear cc.combine + (* find representative, recursively *) let rec find_rec cc (n:node) : repr = if n==n.n_root then ( @@ -412,6 +416,13 @@ and task_combine_ cc (a,b,e_ab) : unit = raise_conflict cc @@ Lit.Set.elements lits) ra.n_tags rb.n_tags in + (* when merging terms with [true] or [false], possibly propagate them to SAT *) + let merge_bool r1 t1 r2 t2 = + if N.equal r1 cc.true_ then propagate_bools cc r2 t2 r1 t1 e_ab true + else if N.equal r1 cc.false_ then propagate_bools cc r2 t2 r1 t1 e_ab false + in + merge_bool ra a rb b; + merge_bool rb b ra a; (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); begin @@ -460,6 +471,32 @@ and task_combine_ cc (a,b,e_ab) : unit = notify_merge cc r_from ~into:r_into e_ab; ) +(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] + in the equiv class of [r1] that is a known literal back to the SAT solver + and which is not the one initially merged. + We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) +and propagate_bools cc r1 t1 r2 t2 (e_12:explanation) sign : unit = + let (module A) = cc.acts in + (* explanation for [t1 =e= t2 = r2] *) + let half_expl = lazy ( + let expl = explain_unfold cc e_12 in + explain_eq_n ~init:expl cc r2 t2 + ) in + iter_class_ r1 + (fun u1 -> + (* propagate if: + - [u1] is a proper literal + - [t2 != r2], because that can only happen + after an explicit merge (no way to obtain that by propagation) + *) + if N.get_field N.field_is_literal u1 && not (N.equal r2 t2) then ( + let lit = Lit.atom ~sign u1.n_term in + Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); + (* complete explanation with the [u1=t1] chunk *) + let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in + A.propagate lit (Lit.Set.to_list expl) + )) + (* Checks if [ra] and [~into] have compatible normal forms and can be merged w.r.t. the theories. Side effect: also pushes sub-tasks *) diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index af57e03b..adec138e 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -64,6 +64,8 @@ val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit val final_check : t -> unit +val post_backtrack : t -> unit + val mk_model : t -> Model.t -> Model.t (** Enrich a model by mapping terms to their representative's value, if any. Otherwise map the representative to a fresh value *) diff --git a/src/smt/Equiv_class.ml b/src/smt/Equiv_class.ml index ae3a523e..6800cba4 100644 --- a/src/smt/Equiv_class.ml +++ b/src/smt/Equiv_class.ml @@ -6,6 +6,7 @@ type payload = equiv_class_payload = .. let field_is_active = Node_bits.mk_field() let field_is_pending = Node_bits.mk_field() +let field_is_literal = Node_bits.mk_field() let () = Node_bits.freeze() let[@inline] equal (n1:t) n2 = n1==n2 diff --git a/src/smt/Equiv_class.mli b/src/smt/Equiv_class.mli index f3a1f434..6d10d982 100644 --- a/src/smt/Equiv_class.mli +++ b/src/smt/Equiv_class.mli @@ -30,6 +30,9 @@ val field_is_active : Node_bits.field val field_is_pending : Node_bits.field (** true iff the node is in the [cc.pending] queue *) +val field_is_literal : Node_bits.field +(** This term is a boolean literal, subject to propagations *) + (** {2 basics} *) val term : t -> term diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index 7e8f3652..9be25be6 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -31,6 +31,12 @@ module type STATE = sig val mk_model : t -> Lit.t Sequence.t -> Model.t (** Make a model for this theory's terms *) + + val post_backtrack : t -> unit + + (**/**) + val check_invariants : t -> unit + (**/**) end @@ -90,6 +96,7 @@ let make_st ?(on_merge=fun _ _ _ _ -> ()) ?(on_assert=fun _ _ -> ()) ?(mk_model=fun _ _ -> Model.empty) + ?(post_backtrack=fun _ -> ()) ~final_check ~st () : state = @@ -100,6 +107,7 @@ let make_st let on_assert = on_assert let final_check = final_check let mk_model = mk_model + let post_backtrack = post_backtrack let check_invariants = check_invariants end in (module A : STATE) diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 60f90acb..090b46d2 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -204,3 +204,7 @@ let add_theory (self:t) (th:Theory.t) : unit = self.theories <- th_s :: self.theories let add_theory_l self = List.iter (add_theory self) + +let post_backtrack self = + C_clos.post_backtrack (cc self); + theories self (fun (module Th) -> Th.post_backtrack Th.state)