mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
feat(cc): boolean propagation of literals in CC
This commit is contained in:
parent
0e467e058c
commit
73c7db2b4e
8 changed files with 63 additions and 0 deletions
|
|
@ -863,10 +863,14 @@ module Make (Th : Theory_intf.S) = struct
|
||||||
To be called only from [cancel_until] *)
|
To be called only from [cancel_until] *)
|
||||||
let backtrack_down_to (st:t) (lvl:int): unit =
|
let backtrack_down_to (st:t) (lvl:int): unit =
|
||||||
Log.debugf 2 (fun k->k "(@[@{<Yellow>sat.backtrack@} now at stack depth %d@])" lvl);
|
Log.debugf 2 (fun k->k "(@[@{<Yellow>sat.backtrack@} now at stack depth %d@])" lvl);
|
||||||
|
let done_sth = Vec.size st.backtrack > lvl in
|
||||||
while Vec.size st.backtrack > lvl do
|
while Vec.size st.backtrack > lvl do
|
||||||
let f = Vec.pop_last st.backtrack in
|
let f = Vec.pop_last st.backtrack in
|
||||||
f()
|
f()
|
||||||
done;
|
done;
|
||||||
|
if done_sth then (
|
||||||
|
Th.post_backtrack (theory st);
|
||||||
|
);
|
||||||
(* now re-do permanent actions that were backtracked *)
|
(* now re-do permanent actions that were backtracked *)
|
||||||
while not (Vec.is_empty st.to_redo_after_backtrack) do
|
while not (Vec.is_empty st.to_redo_after_backtrack) do
|
||||||
let f = Vec.pop_last st.to_redo_after_backtrack in
|
let f = Vec.pop_last st.to_redo_after_backtrack in
|
||||||
|
|
|
||||||
|
|
@ -131,6 +131,10 @@ module type S = sig
|
||||||
(** Called at the end of the search in case a model has been found. If no new clause is
|
(** Called at the end of the search in case a model has been found. If no new clause is
|
||||||
pushed, then 'sat' is returned, else search is resumed. *)
|
pushed, then 'sat' is returned, else search is resumed. *)
|
||||||
|
|
||||||
|
val post_backtrack : t -> unit
|
||||||
|
(** After backtracking, this is called (can be used to invalidate
|
||||||
|
caches, reset task lists, etc.) *)
|
||||||
|
|
||||||
(**/**)
|
(**/**)
|
||||||
val check_invariants : t -> unit
|
val check_invariants : t -> unit
|
||||||
(**/**)
|
(**/**)
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,10 @@ let[@inline] size_ (r:repr) = r.n_size
|
||||||
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
||||||
let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t
|
let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t
|
||||||
|
|
||||||
|
let[@inline] post_backtrack cc =
|
||||||
|
Vec.clear cc.pending;
|
||||||
|
Vec.clear cc.combine
|
||||||
|
|
||||||
(* find representative, recursively *)
|
(* find representative, recursively *)
|
||||||
let rec find_rec cc (n:node) : repr =
|
let rec find_rec cc (n:node) : repr =
|
||||||
if n==n.n_root then (
|
if n==n.n_root then (
|
||||||
|
|
@ -412,6 +416,13 @@ and task_combine_ cc (a,b,e_ab) : unit =
|
||||||
raise_conflict cc @@ Lit.Set.elements lits)
|
raise_conflict cc @@ Lit.Set.elements lits)
|
||||||
ra.n_tags rb.n_tags
|
ra.n_tags rb.n_tags
|
||||||
in
|
in
|
||||||
|
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
|
||||||
|
let merge_bool r1 t1 r2 t2 =
|
||||||
|
if N.equal r1 cc.true_ then propagate_bools cc r2 t2 r1 t1 e_ab true
|
||||||
|
else if N.equal r1 cc.false_ then propagate_bools cc r2 t2 r1 t1 e_ab false
|
||||||
|
in
|
||||||
|
merge_bool ra a rb b;
|
||||||
|
merge_bool rb b ra a;
|
||||||
(* perform [union r_from r_into] *)
|
(* perform [union r_from r_into] *)
|
||||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||||
begin
|
begin
|
||||||
|
|
@ -460,6 +471,32 @@ and task_combine_ cc (a,b,e_ab) : unit =
|
||||||
notify_merge cc r_from ~into:r_into e_ab;
|
notify_merge cc r_from ~into:r_into e_ab;
|
||||||
)
|
)
|
||||||
|
|
||||||
|
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
|
||||||
|
in the equiv class of [r1] that is a known literal back to the SAT solver
|
||||||
|
and which is not the one initially merged.
|
||||||
|
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
|
||||||
|
and propagate_bools cc r1 t1 r2 t2 (e_12:explanation) sign : unit =
|
||||||
|
let (module A) = cc.acts in
|
||||||
|
(* explanation for [t1 =e= t2 = r2] *)
|
||||||
|
let half_expl = lazy (
|
||||||
|
let expl = explain_unfold cc e_12 in
|
||||||
|
explain_eq_n ~init:expl cc r2 t2
|
||||||
|
) in
|
||||||
|
iter_class_ r1
|
||||||
|
(fun u1 ->
|
||||||
|
(* propagate if:
|
||||||
|
- [u1] is a proper literal
|
||||||
|
- [t2 != r2], because that can only happen
|
||||||
|
after an explicit merge (no way to obtain that by propagation)
|
||||||
|
*)
|
||||||
|
if N.get_field N.field_is_literal u1 && not (N.equal r2 t2) then (
|
||||||
|
let lit = Lit.atom ~sign u1.n_term in
|
||||||
|
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
|
||||||
|
(* complete explanation with the [u1=t1] chunk *)
|
||||||
|
let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
|
||||||
|
A.propagate lit (Lit.Set.to_list expl)
|
||||||
|
))
|
||||||
|
|
||||||
(* Checks if [ra] and [~into] have compatible normal forms and can
|
(* Checks if [ra] and [~into] have compatible normal forms and can
|
||||||
be merged w.r.t. the theories.
|
be merged w.r.t. the theories.
|
||||||
Side effect: also pushes sub-tasks *)
|
Side effect: also pushes sub-tasks *)
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,8 @@ val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit
|
||||||
|
|
||||||
val final_check : t -> unit
|
val final_check : t -> unit
|
||||||
|
|
||||||
|
val post_backtrack : t -> unit
|
||||||
|
|
||||||
val mk_model : t -> Model.t -> Model.t
|
val mk_model : t -> Model.t -> Model.t
|
||||||
(** Enrich a model by mapping terms to their representative's value,
|
(** Enrich a model by mapping terms to their representative's value,
|
||||||
if any. Otherwise map the representative to a fresh value *)
|
if any. Otherwise map the representative to a fresh value *)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ type payload = equiv_class_payload = ..
|
||||||
|
|
||||||
let field_is_active = Node_bits.mk_field()
|
let field_is_active = Node_bits.mk_field()
|
||||||
let field_is_pending = Node_bits.mk_field()
|
let field_is_pending = Node_bits.mk_field()
|
||||||
|
let field_is_literal = Node_bits.mk_field()
|
||||||
let () = Node_bits.freeze()
|
let () = Node_bits.freeze()
|
||||||
|
|
||||||
let[@inline] equal (n1:t) n2 = n1==n2
|
let[@inline] equal (n1:t) n2 = n1==n2
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,9 @@ val field_is_active : Node_bits.field
|
||||||
val field_is_pending : Node_bits.field
|
val field_is_pending : Node_bits.field
|
||||||
(** true iff the node is in the [cc.pending] queue *)
|
(** true iff the node is in the [cc.pending] queue *)
|
||||||
|
|
||||||
|
val field_is_literal : Node_bits.field
|
||||||
|
(** This term is a boolean literal, subject to propagations *)
|
||||||
|
|
||||||
(** {2 basics} *)
|
(** {2 basics} *)
|
||||||
|
|
||||||
val term : t -> term
|
val term : t -> term
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,12 @@ module type STATE = sig
|
||||||
|
|
||||||
val mk_model : t -> Lit.t Sequence.t -> Model.t
|
val mk_model : t -> Lit.t Sequence.t -> Model.t
|
||||||
(** Make a model for this theory's terms *)
|
(** Make a model for this theory's terms *)
|
||||||
|
|
||||||
|
val post_backtrack : t -> unit
|
||||||
|
|
||||||
|
(**/**)
|
||||||
|
val check_invariants : t -> unit
|
||||||
|
(**/**)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -90,6 +96,7 @@ let make_st
|
||||||
?(on_merge=fun _ _ _ _ -> ())
|
?(on_merge=fun _ _ _ _ -> ())
|
||||||
?(on_assert=fun _ _ -> ())
|
?(on_assert=fun _ _ -> ())
|
||||||
?(mk_model=fun _ _ -> Model.empty)
|
?(mk_model=fun _ _ -> Model.empty)
|
||||||
|
?(post_backtrack=fun _ -> ())
|
||||||
~final_check
|
~final_check
|
||||||
~st
|
~st
|
||||||
() : state =
|
() : state =
|
||||||
|
|
@ -100,6 +107,7 @@ let make_st
|
||||||
let on_assert = on_assert
|
let on_assert = on_assert
|
||||||
let final_check = final_check
|
let final_check = final_check
|
||||||
let mk_model = mk_model
|
let mk_model = mk_model
|
||||||
|
let post_backtrack = post_backtrack
|
||||||
let check_invariants = check_invariants
|
let check_invariants = check_invariants
|
||||||
end in
|
end in
|
||||||
(module A : STATE)
|
(module A : STATE)
|
||||||
|
|
|
||||||
|
|
@ -204,3 +204,7 @@ let add_theory (self:t) (th:Theory.t) : unit =
|
||||||
self.theories <- th_s :: self.theories
|
self.theories <- th_s :: self.theories
|
||||||
|
|
||||||
let add_theory_l self = List.iter (add_theory self)
|
let add_theory_l self = List.iter (add_theory self)
|
||||||
|
|
||||||
|
let post_backtrack self =
|
||||||
|
C_clos.post_backtrack (cc self);
|
||||||
|
theories self (fun (module Th) -> Th.post_backtrack Th.state)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue