feat(check): use mini-cc to check CC conflicts on the fly

This commit is contained in:
Simon Cruanes 2019-06-07 13:57:12 -05:00
parent 2000114ab4
commit 357dc73426
9 changed files with 108 additions and 72 deletions

View file

@ -66,7 +66,6 @@ module Make(CC_A: ARG) = struct
| E_and of explanation * explanation
type repr = node
type conflict = lit list
module N = struct
type t = node
@ -213,6 +212,7 @@ module Make(CC_A: ARG) = struct
undo: (unit -> unit) Backtrack_stack.t;
mutable on_merge: ev_on_merge list;
mutable on_new_term: ev_on_new_term list;
mutable on_conflict: ev_on_conflict list;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
@ -232,6 +232,7 @@ module Make(CC_A: ARG) = struct
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
and ev_on_new_term = t -> N.t -> term -> unit
and ev_on_conflict = t -> lit list -> unit
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
@ -337,11 +338,12 @@ module Make(CC_A: ARG) = struct
n.n_expl <- FL_none;
end
let raise_conflict (cc:t) (acts:actions) (e:conflict) : _ =
let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ =
(* clear tasks queue *)
Vec.iter (N.set_field field_is_pending false) cc.pending;
Vec.clear cc.pending;
Vec.clear cc.combine;
List.iter (fun f -> f cc e) cc.on_conflict;
Stat.incr cc.count_conflict;
CC_A.Actions.raise_conflict acts e A.Proof.default
@ -813,9 +815,10 @@ module Make(CC_A: ARG) = struct
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
let create ?(stat=Stat.global)
?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big)
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(size=`Big)
(tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = {
@ -824,6 +827,7 @@ module Make(CC_A: ARG) = struct
signatures_tbl = Sig_tbl.create size;
on_merge;
on_new_term;
on_conflict;
pending=Vec.create();
combine=Vec.create();
ps_lits=[];

View file

@ -207,8 +207,6 @@ module type CC_S = sig
type explanation = Expl.t
type conflict = lit list
(** Accessors *)
val term_state : t -> term_state
@ -222,11 +220,13 @@ module type CC_S = sig
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
type ev_on_new_term = t -> N.t -> term -> unit
type ev_on_conflict = t -> lit list -> unit
val create :
?stat:Stat.t ->
?on_merge:ev_on_merge list ->
?on_new_term:ev_on_new_term list ->
?on_conflict:ev_on_conflict list ->
?size:[`Small | `Big] ->
term_state ->
t
@ -239,6 +239,9 @@ module type CC_S = sig
val on_new_term : t -> ev_on_new_term -> unit
(** Add a function to be called when a new node is created *)
val on_conflict : t -> ev_on_conflict -> unit
(** Called when the congruence closure finds a conflict *)
val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *)
@ -311,10 +314,6 @@ module type SOLVER_INTERNAL = sig
module Expl = CC.Expl
module N = CC.N
(** Unsatisfiable conjunction.
Its negation will become a conflict clause *)
type conflict = lit list
val tst : t -> term_state
val cc : t -> CC.t
@ -363,7 +362,6 @@ module type SOLVER_INTERNAL = sig
(** Propagate a boolean using a unit clause.
[expl => lit] must be a theory lemma, that is, a T-tautology *)
val add_clause_temp : t -> actions -> lit list -> unit
(** Add local clause to the SAT solver. This clause will be
removed when the solver backtracks. *)
@ -411,6 +409,9 @@ module type SOLVER_INTERNAL = sig
(** Callback to add data on terms when they are added to the congruence
closure *)
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit
(** Callback called on every CC conflict *)
val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit
(** Register callbacked to be called with the slice of literals
newly added on the trail.

View file

@ -110,6 +110,10 @@ let main () =
in
Process.Solver.create ~store_proof:!check ~theories tst ()
in
if !check then (
(* might have to check conflicts *)
Solver.add_theory solver Process.Check_cc.theory;
);
let dot_proof = if !p_dot_proof = "" then None else Some !p_dot_proof in
Sidekick_smtlib.parse !file >>= fun input ->
(* process statements *)

View file

@ -1,8 +1,3 @@
type res =
| Sat
| Unsat
module CC_view = Sidekick_core.CC_view
module type ARG = sig
@ -23,7 +18,7 @@ module type S = sig
val add_lit : t -> term -> bool -> unit
val distinct : t -> term list -> unit
val check : t -> res
val check_sat : t -> bool
val classes : t -> term Iter.t Iter.t
end
@ -42,9 +37,9 @@ module Make(A: ARG) = struct
type node = {
n_t: term;
mutable n_next: node; (* next in class *)
mutable n_size: int; (* size of parent list *)
mutable n_size: int; (* size of class *)
mutable n_parents: node list;
mutable n_root: node;
mutable n_root: node; (* root of the class *)
}
type signature = (fun_, node, node list) CC_view.t
@ -55,17 +50,16 @@ module Make(A: ARG) = struct
let[@inline] hash (n:t) = T.hash n.n_t
let[@inline] size (n:t) = n.n_size
let[@inline] is_root n = n == n.n_root
let[@inline] root n = n.n_root
let[@inline] term n = n.n_t
let pp out n = T.pp out n.n_t
let add_parent (self:t) ~p : unit =
self.n_parents <- p :: self.n_parents;
self.n_size <- 1 + self.n_size;
()
self.n_parents <- p :: self.n_parents
let make (t:T.t) : t =
let rec n = {
n_t=t; n_size=0; n_next=n;
n_t=t; n_size=1; n_next=n;
n_parents=[]; n_root=n;
} in
n
@ -167,28 +161,19 @@ module Make(A: ARG) = struct
| n -> n
| exception Not_found ->
let node = Node.make t in
T_tbl.add self.tbl t node;
(* add sub-terms, and add [t] to their parent list *)
sub_ t
(fun u ->
let n_u = add_t self u in
let n_u = Node.root @@ add_t self u in
Node.add_parent n_u ~p:node);
T_tbl.add self.tbl t node;
(* need to compute signature *)
Vec.push self.pending node;
node
(* find representative *)
let[@inline] find_ (n:node) : node =
let r = n.n_root in
assert (Node.is_root r);
r
let find_t_ (self:t) (t:term): node =
let n =
try T_tbl.find self.tbl t
with Not_found -> Error.errorf "minicc.find_t: no node for %a" T.pp t
in
find_ n
try T_tbl.find self.tbl t |> Node.root
with Not_found -> Error.errorf "mini-cc.find_t: no node for %a" T.pp t
(* does this list contain a duplicate? *)
let has_dups (l:node list) : bool =
@ -200,7 +185,7 @@ module Make(A: ARG) = struct
let check_distinct_ self : unit =
Vec.iter
(fun r ->
r := List.map find_ !r;
r := List.rev_map Node.root !r;
if has_dups !r then raise_notrace E_unsat)
self.distinct
@ -232,17 +217,17 @@ module Make(A: ARG) = struct
(* reduce to [true] *)
let n2 = self.true_ in
Log.debugf 5
(fun k->k "(@[minicc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
(fun k->k "(@[mini-cc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
Vec.push self.combine (n,n2)
)
| Some s ->
Log.debugf 5 (fun k->k "(@[minicc.update-sig@ %a@])" Signature.pp s);
Log.debugf 5 (fun k->k "(@[mini-cc.update-sig@ %a@])" Signature.pp s);
match Sig_tbl.find self.sig_tbl s with
| n2 when Node.equal n n2 -> ()
| n2 ->
(* collision, merge *)
Log.debugf 5
(fun k->k "(@[minicc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2);
(fun k->k "(@[mini-cc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2);
Vec.push self.combine (n,n2)
| exception Not_found ->
Sig_tbl.add self.sig_tbl s n
@ -251,8 +236,8 @@ module Make(A: ARG) = struct
(* merge the two classes *)
let merge_ self (n1,n2) : unit =
let n1 = find_ n1 in
let n2 = find_ n2 in
let n1 = Node.root n1 in
let n2 = Node.root n2 in
if not @@ Node.equal n1 n2 then (
(* merge into largest class, or into a boolean *)
let n1, n2 =
@ -260,10 +245,10 @@ module Make(A: ARG) = struct
else if is_bool self n2 then n2, n1
else if Node.size n1 > Node.size n2 then n1, n2
else n2, n1 in
Log.debugf 5 (fun k->k "(@[minicc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);
Log.debugf 5 (fun k->k "(@[mini-cc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);
if is_bool self n1 && is_bool self n2 then (
Log.debugf 5 (fun k->k "(minicc.conflict.merge-true-false)");
Log.debugf 5 (fun k->k "(mini-cc.conflict.merge-true-false)");
self.ok <- false;
raise E_unsat
);
@ -276,9 +261,14 @@ module Make(A: ARG) = struct
(* update root pointer in [n2.class] *)
Node.iter_cls n2 (fun n -> n.n_root <- n1);
(* merge classes [next] pointers *)
let n1_next = n1.n_next in
n1.n_next <- n2.n_next;
n2.n_next <- n1_next;
)
let check_ok_ self =
let[@inline] check_ok_ self =
if not self.ok then raise_notrace E_unsat
(* fixpoint of the congruence closure *)
@ -309,18 +299,17 @@ module Make(A: ARG) = struct
Vec.push self.combine (n,n2)
let distinct (self:t) l =
begin match l with
| [] | [_] -> invalid_arg "distinct: need at least 2 terms"
| _ -> ()
end;
let l = List.map (add_t self) l in
match l with
| [] | [_] -> () (* trivial *)
| _ ->
let l = List.rev_map (add_t self) l in
Vec.push self.distinct (ref l)
let check (self:t) : res =
try fixpoint self; Sat
let check_sat (self:t) : bool =
try fixpoint self; true
with E_unsat ->
self.ok <- false;
Unsat
false
let classes self : _ Iter.t =
T_tbl.values self.tbl

View file

@ -1,4 +1,3 @@
(** {1 Mini congruence closure}
This implementation is as simple as possible, and doesn't provide
@ -6,10 +5,6 @@
It just decides the satisfiability of a set of (dis)equations.
*)
type res =
| Sat
| Unsat
module CC_view = Sidekick_core.CC_view
module type ARG = sig
@ -33,7 +28,9 @@ module type S = sig
val distinct : t -> term list -> unit
(** [distinct cc l] asserts that all terms in [l] are distinct *)
val check : t -> res
val check_sat : t -> bool
(** [check_sat cc] returns [true] if the current state is satisfiable, [false]
if it's unsatisfiable *)
val classes : t -> term Iter.t Iter.t
(** Traverse the set of classes in the congruence closure.

View file

@ -133,7 +133,6 @@ module Make(A : ARG)
module Expl = CC.Expl
type proof = A.Proof.t
type conflict = lit list
let[@inline] cc (t:t) = Lazy.force t.cc
let[@inline] tst t = t.tst
@ -210,6 +209,7 @@ module Make(A : ARG)
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
let on_cc_new_term self f = CC.on_new_term (cc self) f
let on_cc_merge self f = CC.on_merge (cc self) f
let on_cc_conflict self f = CC.on_conflict (cc self) f
let cc_add_term self t = CC.add_term (cc self) t
let cc_find self n = CC.find (cc self) n

View file

@ -2,6 +2,8 @@
open Sidekick_base_term
[@@@ocaml.warning "-32"]
type 'a or_error = ('a, string) CCResult.t
module E = CCResult
@ -391,7 +393,7 @@ let conv_ty = Conv.conv_ty
let conv_term = Conv.conv_term
(* instantiate solver here *)
module Solver = Sidekick_msat_solver.Make(struct
module Solver_arg = struct
include Sidekick_base_term
let cc_view = Term.cc_view
@ -400,7 +402,41 @@ module Solver = Sidekick_msat_solver.Make(struct
let default=Default
let pp out _ = Fmt.string out "default"
end
end)
end
module Solver = Sidekick_msat_solver.Make(Solver_arg)
module Check_cc = struct
module SI = Solver.Solver_internal
module CC = Solver.Solver_internal.CC
module MCC = Sidekick_mini_cc.Make(Solver_arg)
let pp_c out c = Fmt.fprintf out "(@[%a@])" (Util.pp_list ~sep:" " Lit.pp) c
let add_cc_lit (cc:MCC.t) (lit:Lit.t) : unit =
let t = Lit.term lit in
MCC.add_lit cc t (Lit.sign lit)
(* check that this is a proper CC conflict *)
let check_conflict si _cc (confl:Lit.t list) : unit =
Log.debugf 15 (fun k->k "(@[check-cc-conflict@ %a@])" pp_c confl);
let tst = SI.tst si in
let cc = MCC.create tst in
(* add [¬confl] and check it's unsat *)
List.iter (fun lit -> add_cc_lit cc @@ Lit.neg lit) confl;
if MCC.check_sat cc then (
Error.errorf "@[<2>check-cc-conflict:@ @[clause %a@]@ \
is not a UF conflict (negation is sat)@]" pp_c confl
) else (
Log.debugf 15 (fun k->k "(@[check-cc-conflict.ok@ %a@])" pp_c confl);
)
let theory =
Solver.mk_theory ~name:"cc-check"
~create_and_setup:(fun si ->
Solver.Solver_internal.on_cc_conflict si (check_conflict si))
()
end
(* TODO
(* check SMT model *)
@ -501,7 +537,7 @@ let solve
(* process a single statement *)
let process_stmt
?hyps
?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?check
?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?(check=false)
?time ?memory ?progress
(solver:Solver.t)
(stmt:Ast.statement) : unit or_error =
@ -532,7 +568,7 @@ let process_stmt
raise Exit
| A.CheckSat ->
solve
?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress
?gc ?restarts ?dot_proof ~check ?pp_model ?time ?memory ?progress
~assumptions:[] ?hyps
solver;
E.return()

View file

@ -18,6 +18,11 @@ val conv_ty : Ast.Ty.t -> Ty.t
val conv_term : Term.state -> Ast.term -> Term.t
module Check_cc : sig
(** theory that check validity of conflicts *)
val theory : Solver.theory
end
val process_stmt :
?hyps:Lit.t list Vec.t ->
?gc:bool ->

View file

@ -3,7 +3,7 @@
(public_name sidekick.smtlib)
(libraries containers zarith msat sidekick.core sidekick.util
sidekick.msat-solver sidekick.base-term sidekick.th-bool-static
msat.backend)
sidekick.mini-cc msat.backend)
(flags :standard -open Sidekick_util))
(menhir