feat(cc): boolean propagation of literals in CC

This commit is contained in:
Simon Cruanes 2018-08-18 19:56:22 -05:00
parent 0e467e058c
commit 73c7db2b4e
8 changed files with 63 additions and 0 deletions

View file

@ -863,10 +863,14 @@ module Make (Th : Theory_intf.S) = struct
To be called only from [cancel_until] *) To be called only from [cancel_until] *)
let backtrack_down_to (st:t) (lvl:int): unit = let backtrack_down_to (st:t) (lvl:int): unit =
Log.debugf 2 (fun k->k "(@[@{<Yellow>sat.backtrack@} now at stack depth %d@])" lvl); Log.debugf 2 (fun k->k "(@[@{<Yellow>sat.backtrack@} now at stack depth %d@])" lvl);
let done_sth = Vec.size st.backtrack > lvl in
while Vec.size st.backtrack > lvl do while Vec.size st.backtrack > lvl do
let f = Vec.pop_last st.backtrack in let f = Vec.pop_last st.backtrack in
f() f()
done; done;
if done_sth then (
Th.post_backtrack (theory st);
);
(* now re-do permanent actions that were backtracked *) (* now re-do permanent actions that were backtracked *)
while not (Vec.is_empty st.to_redo_after_backtrack) do while not (Vec.is_empty st.to_redo_after_backtrack) do
let f = Vec.pop_last st.to_redo_after_backtrack in let f = Vec.pop_last st.to_redo_after_backtrack in

View file

@ -131,6 +131,10 @@ module type S = sig
(** Called at the end of the search in case a model has been found. If no new clause is (** Called at the end of the search in case a model has been found. If no new clause is
pushed, then 'sat' is returned, else search is resumed. *) pushed, then 'sat' is returned, else search is resumed. *)
val post_backtrack : t -> unit
(** After backtracking, this is called (can be used to invalidate
caches, reset task lists, etc.) *)
(**/**) (**/**)
val check_invariants : t -> unit val check_invariants : t -> unit
(**/**) (**/**)

View file

@ -73,6 +73,10 @@ let[@inline] size_ (r:repr) = r.n_size
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *) Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t
let[@inline] post_backtrack cc =
Vec.clear cc.pending;
Vec.clear cc.combine
(* find representative, recursively *) (* find representative, recursively *)
let rec find_rec cc (n:node) : repr = let rec find_rec cc (n:node) : repr =
if n==n.n_root then ( if n==n.n_root then (
@ -412,6 +416,13 @@ and task_combine_ cc (a,b,e_ab) : unit =
raise_conflict cc @@ Lit.Set.elements lits) raise_conflict cc @@ Lit.Set.elements lits)
ra.n_tags rb.n_tags ra.n_tags rb.n_tags
in in
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
let merge_bool r1 t1 r2 t2 =
if N.equal r1 cc.true_ then propagate_bools cc r2 t2 r1 t1 e_ab true
else if N.equal r1 cc.false_ then propagate_bools cc r2 t2 r1 t1 e_ab false
in
merge_bool ra a rb b;
merge_bool rb b ra a;
(* perform [union r_from r_into] *) (* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
begin begin
@ -460,6 +471,32 @@ and task_combine_ cc (a,b,e_ab) : unit =
notify_merge cc r_from ~into:r_into e_ab; notify_merge cc r_from ~into:r_into e_ab;
) )
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
in the equiv class of [r1] that is a known literal back to the SAT solver
and which is not the one initially merged.
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
and propagate_bools cc r1 t1 r2 t2 (e_12:explanation) sign : unit =
let (module A) = cc.acts in
(* explanation for [t1 =e= t2 = r2] *)
let half_expl = lazy (
let expl = explain_unfold cc e_12 in
explain_eq_n ~init:expl cc r2 t2
) in
iter_class_ r1
(fun u1 ->
(* propagate if:
- [u1] is a proper literal
- [t2 != r2], because that can only happen
after an explicit merge (no way to obtain that by propagation)
*)
if N.get_field N.field_is_literal u1 && not (N.equal r2 t2) then (
let lit = Lit.atom ~sign u1.n_term in
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *)
let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
A.propagate lit (Lit.Set.to_list expl)
))
(* Checks if [ra] and [~into] have compatible normal forms and can (* Checks if [ra] and [~into] have compatible normal forms and can
be merged w.r.t. the theories. be merged w.r.t. the theories.
Side effect: also pushes sub-tasks *) Side effect: also pushes sub-tasks *)

View file

@ -64,6 +64,8 @@ val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit
val final_check : t -> unit val final_check : t -> unit
val post_backtrack : t -> unit
val mk_model : t -> Model.t -> Model.t val mk_model : t -> Model.t -> Model.t
(** Enrich a model by mapping terms to their representative's value, (** Enrich a model by mapping terms to their representative's value,
if any. Otherwise map the representative to a fresh value *) if any. Otherwise map the representative to a fresh value *)

View file

@ -6,6 +6,7 @@ type payload = equiv_class_payload = ..
let field_is_active = Node_bits.mk_field() let field_is_active = Node_bits.mk_field()
let field_is_pending = Node_bits.mk_field() let field_is_pending = Node_bits.mk_field()
let field_is_literal = Node_bits.mk_field()
let () = Node_bits.freeze() let () = Node_bits.freeze()
let[@inline] equal (n1:t) n2 = n1==n2 let[@inline] equal (n1:t) n2 = n1==n2

View file

@ -30,6 +30,9 @@ val field_is_active : Node_bits.field
val field_is_pending : Node_bits.field val field_is_pending : Node_bits.field
(** true iff the node is in the [cc.pending] queue *) (** true iff the node is in the [cc.pending] queue *)
val field_is_literal : Node_bits.field
(** This term is a boolean literal, subject to propagations *)
(** {2 basics} *) (** {2 basics} *)
val term : t -> term val term : t -> term

View file

@ -31,6 +31,12 @@ module type STATE = sig
val mk_model : t -> Lit.t Sequence.t -> Model.t val mk_model : t -> Lit.t Sequence.t -> Model.t
(** Make a model for this theory's terms *) (** Make a model for this theory's terms *)
val post_backtrack : t -> unit
(**/**)
val check_invariants : t -> unit
(**/**)
end end
@ -90,6 +96,7 @@ let make_st
?(on_merge=fun _ _ _ _ -> ()) ?(on_merge=fun _ _ _ _ -> ())
?(on_assert=fun _ _ -> ()) ?(on_assert=fun _ _ -> ())
?(mk_model=fun _ _ -> Model.empty) ?(mk_model=fun _ _ -> Model.empty)
?(post_backtrack=fun _ -> ())
~final_check ~final_check
~st ~st
() : state = () : state =
@ -100,6 +107,7 @@ let make_st
let on_assert = on_assert let on_assert = on_assert
let final_check = final_check let final_check = final_check
let mk_model = mk_model let mk_model = mk_model
let post_backtrack = post_backtrack
let check_invariants = check_invariants let check_invariants = check_invariants
end in end in
(module A : STATE) (module A : STATE)

View file

@ -204,3 +204,7 @@ let add_theory (self:t) (th:Theory.t) : unit =
self.theories <- th_s :: self.theories self.theories <- th_s :: self.theories
let add_theory_l self = List.iter (add_theory self) let add_theory_l self = List.iter (add_theory self)
let post_backtrack self =
C_clos.post_backtrack (cc self);
theories self (fun (module Th) -> Th.post_backtrack Th.state)