mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-10 05:03:59 -05:00
feat(check): use mini-cc to check CC conflicts on the fly
This commit is contained in:
parent
2000114ab4
commit
357dc73426
9 changed files with 108 additions and 72 deletions
|
|
@ -66,7 +66,6 @@ module Make(CC_A: ARG) = struct
|
|||
| E_and of explanation * explanation
|
||||
|
||||
type repr = node
|
||||
type conflict = lit list
|
||||
|
||||
module N = struct
|
||||
type t = node
|
||||
|
|
@ -213,6 +212,7 @@ module Make(CC_A: ARG) = struct
|
|||
undo: (unit -> unit) Backtrack_stack.t;
|
||||
mutable on_merge: ev_on_merge list;
|
||||
mutable on_new_term: ev_on_new_term list;
|
||||
mutable on_conflict: ev_on_conflict list;
|
||||
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
|
||||
(* proof state *)
|
||||
ps_queue: (node*node) Vec.t;
|
||||
|
|
@ -232,6 +232,7 @@ module Make(CC_A: ARG) = struct
|
|||
|
||||
and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
|
||||
and ev_on_new_term = t -> N.t -> term -> unit
|
||||
and ev_on_conflict = t -> lit list -> unit
|
||||
|
||||
let[@inline] size_ (r:repr) = r.n_size
|
||||
let[@inline] true_ cc = Lazy.force cc.true_
|
||||
|
|
@ -337,11 +338,12 @@ module Make(CC_A: ARG) = struct
|
|||
n.n_expl <- FL_none;
|
||||
end
|
||||
|
||||
let raise_conflict (cc:t) (acts:actions) (e:conflict) : _ =
|
||||
let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ =
|
||||
(* clear tasks queue *)
|
||||
Vec.iter (N.set_field field_is_pending false) cc.pending;
|
||||
Vec.clear cc.pending;
|
||||
Vec.clear cc.combine;
|
||||
List.iter (fun f -> f cc e) cc.on_conflict;
|
||||
Stat.incr cc.count_conflict;
|
||||
CC_A.Actions.raise_conflict acts e A.Proof.default
|
||||
|
||||
|
|
@ -813,9 +815,10 @@ module Make(CC_A: ARG) = struct
|
|||
|
||||
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
|
||||
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
|
||||
let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict
|
||||
|
||||
let create ?(stat=Stat.global)
|
||||
?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big)
|
||||
?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(size=`Big)
|
||||
(tst:term_state) : t =
|
||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||
let rec cc = {
|
||||
|
|
@ -824,6 +827,7 @@ module Make(CC_A: ARG) = struct
|
|||
signatures_tbl = Sig_tbl.create size;
|
||||
on_merge;
|
||||
on_new_term;
|
||||
on_conflict;
|
||||
pending=Vec.create();
|
||||
combine=Vec.create();
|
||||
ps_lits=[];
|
||||
|
|
|
|||
|
|
@ -207,8 +207,6 @@ module type CC_S = sig
|
|||
|
||||
type explanation = Expl.t
|
||||
|
||||
type conflict = lit list
|
||||
|
||||
(** Accessors *)
|
||||
|
||||
val term_state : t -> term_state
|
||||
|
|
@ -222,11 +220,13 @@ module type CC_S = sig
|
|||
|
||||
type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit
|
||||
type ev_on_new_term = t -> N.t -> term -> unit
|
||||
type ev_on_conflict = t -> lit list -> unit
|
||||
|
||||
val create :
|
||||
?stat:Stat.t ->
|
||||
?on_merge:ev_on_merge list ->
|
||||
?on_new_term:ev_on_new_term list ->
|
||||
?on_conflict:ev_on_conflict list ->
|
||||
?size:[`Small | `Big] ->
|
||||
term_state ->
|
||||
t
|
||||
|
|
@ -239,6 +239,9 @@ module type CC_S = sig
|
|||
val on_new_term : t -> ev_on_new_term -> unit
|
||||
(** Add a function to be called when a new node is created *)
|
||||
|
||||
val on_conflict : t -> ev_on_conflict -> unit
|
||||
(** Called when the congruence closure finds a conflict *)
|
||||
|
||||
val set_as_lit : t -> N.t -> lit -> unit
|
||||
(** map the given node to a literal. *)
|
||||
|
||||
|
|
@ -311,10 +314,6 @@ module type SOLVER_INTERNAL = sig
|
|||
module Expl = CC.Expl
|
||||
module N = CC.N
|
||||
|
||||
(** Unsatisfiable conjunction.
|
||||
Its negation will become a conflict clause *)
|
||||
type conflict = lit list
|
||||
|
||||
val tst : t -> term_state
|
||||
|
||||
val cc : t -> CC.t
|
||||
|
|
@ -363,7 +362,6 @@ module type SOLVER_INTERNAL = sig
|
|||
(** Propagate a boolean using a unit clause.
|
||||
[expl => lit] must be a theory lemma, that is, a T-tautology *)
|
||||
|
||||
|
||||
val add_clause_temp : t -> actions -> lit list -> unit
|
||||
(** Add local clause to the SAT solver. This clause will be
|
||||
removed when the solver backtracks. *)
|
||||
|
|
@ -411,6 +409,9 @@ module type SOLVER_INTERNAL = sig
|
|||
(** Callback to add data on terms when they are added to the congruence
|
||||
closure *)
|
||||
|
||||
val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit
|
||||
(** Callback called on every CC conflict *)
|
||||
|
||||
val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit
|
||||
(** Register callbacked to be called with the slice of literals
|
||||
newly added on the trail.
|
||||
|
|
|
|||
|
|
@ -110,6 +110,10 @@ let main () =
|
|||
in
|
||||
Process.Solver.create ~store_proof:!check ~theories tst ()
|
||||
in
|
||||
if !check then (
|
||||
(* might have to check conflicts *)
|
||||
Solver.add_theory solver Process.Check_cc.theory;
|
||||
);
|
||||
let dot_proof = if !p_dot_proof = "" then None else Some !p_dot_proof in
|
||||
Sidekick_smtlib.parse !file >>= fun input ->
|
||||
(* process statements *)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,3 @@
|
|||
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module CC_view = Sidekick_core.CC_view
|
||||
|
||||
module type ARG = sig
|
||||
|
|
@ -23,7 +18,7 @@ module type S = sig
|
|||
val add_lit : t -> term -> bool -> unit
|
||||
val distinct : t -> term list -> unit
|
||||
|
||||
val check : t -> res
|
||||
val check_sat : t -> bool
|
||||
|
||||
val classes : t -> term Iter.t Iter.t
|
||||
end
|
||||
|
|
@ -42,9 +37,9 @@ module Make(A: ARG) = struct
|
|||
type node = {
|
||||
n_t: term;
|
||||
mutable n_next: node; (* next in class *)
|
||||
mutable n_size: int; (* size of parent list *)
|
||||
mutable n_size: int; (* size of class *)
|
||||
mutable n_parents: node list;
|
||||
mutable n_root: node;
|
||||
mutable n_root: node; (* root of the class *)
|
||||
}
|
||||
|
||||
type signature = (fun_, node, node list) CC_view.t
|
||||
|
|
@ -55,17 +50,16 @@ module Make(A: ARG) = struct
|
|||
let[@inline] hash (n:t) = T.hash n.n_t
|
||||
let[@inline] size (n:t) = n.n_size
|
||||
let[@inline] is_root n = n == n.n_root
|
||||
let[@inline] root n = n.n_root
|
||||
let[@inline] term n = n.n_t
|
||||
let pp out n = T.pp out n.n_t
|
||||
|
||||
let add_parent (self:t) ~p : unit =
|
||||
self.n_parents <- p :: self.n_parents;
|
||||
self.n_size <- 1 + self.n_size;
|
||||
()
|
||||
self.n_parents <- p :: self.n_parents
|
||||
|
||||
let make (t:T.t) : t =
|
||||
let rec n = {
|
||||
n_t=t; n_size=0; n_next=n;
|
||||
n_t=t; n_size=1; n_next=n;
|
||||
n_parents=[]; n_root=n;
|
||||
} in
|
||||
n
|
||||
|
|
@ -167,28 +161,19 @@ module Make(A: ARG) = struct
|
|||
| n -> n
|
||||
| exception Not_found ->
|
||||
let node = Node.make t in
|
||||
T_tbl.add self.tbl t node;
|
||||
(* add sub-terms, and add [t] to their parent list *)
|
||||
sub_ t
|
||||
(fun u ->
|
||||
let n_u = add_t self u in
|
||||
let n_u = Node.root @@ add_t self u in
|
||||
Node.add_parent n_u ~p:node);
|
||||
T_tbl.add self.tbl t node;
|
||||
(* need to compute signature *)
|
||||
Vec.push self.pending node;
|
||||
node
|
||||
|
||||
(* find representative *)
|
||||
let[@inline] find_ (n:node) : node =
|
||||
let r = n.n_root in
|
||||
assert (Node.is_root r);
|
||||
r
|
||||
|
||||
let find_t_ (self:t) (t:term): node =
|
||||
let n =
|
||||
try T_tbl.find self.tbl t
|
||||
with Not_found -> Error.errorf "minicc.find_t: no node for %a" T.pp t
|
||||
in
|
||||
find_ n
|
||||
try T_tbl.find self.tbl t |> Node.root
|
||||
with Not_found -> Error.errorf "mini-cc.find_t: no node for %a" T.pp t
|
||||
|
||||
(* does this list contain a duplicate? *)
|
||||
let has_dups (l:node list) : bool =
|
||||
|
|
@ -200,7 +185,7 @@ module Make(A: ARG) = struct
|
|||
let check_distinct_ self : unit =
|
||||
Vec.iter
|
||||
(fun r ->
|
||||
r := List.map find_ !r;
|
||||
r := List.rev_map Node.root !r;
|
||||
if has_dups !r then raise_notrace E_unsat)
|
||||
self.distinct
|
||||
|
||||
|
|
@ -232,17 +217,17 @@ module Make(A: ARG) = struct
|
|||
(* reduce to [true] *)
|
||||
let n2 = self.true_ in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[minicc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
|
||||
(fun k->k "(@[mini-cc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
|
||||
Vec.push self.combine (n,n2)
|
||||
)
|
||||
| Some s ->
|
||||
Log.debugf 5 (fun k->k "(@[minicc.update-sig@ %a@])" Signature.pp s);
|
||||
Log.debugf 5 (fun k->k "(@[mini-cc.update-sig@ %a@])" Signature.pp s);
|
||||
match Sig_tbl.find self.sig_tbl s with
|
||||
| n2 when Node.equal n n2 -> ()
|
||||
| n2 ->
|
||||
(* collision, merge *)
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[minicc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2);
|
||||
(fun k->k "(@[mini-cc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2);
|
||||
Vec.push self.combine (n,n2)
|
||||
| exception Not_found ->
|
||||
Sig_tbl.add self.sig_tbl s n
|
||||
|
|
@ -251,8 +236,8 @@ module Make(A: ARG) = struct
|
|||
|
||||
(* merge the two classes *)
|
||||
let merge_ self (n1,n2) : unit =
|
||||
let n1 = find_ n1 in
|
||||
let n2 = find_ n2 in
|
||||
let n1 = Node.root n1 in
|
||||
let n2 = Node.root n2 in
|
||||
if not @@ Node.equal n1 n2 then (
|
||||
(* merge into largest class, or into a boolean *)
|
||||
let n1, n2 =
|
||||
|
|
@ -260,10 +245,10 @@ module Make(A: ARG) = struct
|
|||
else if is_bool self n2 then n2, n1
|
||||
else if Node.size n1 > Node.size n2 then n1, n2
|
||||
else n2, n1 in
|
||||
Log.debugf 5 (fun k->k "(@[minicc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);
|
||||
Log.debugf 5 (fun k->k "(@[mini-cc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);
|
||||
|
||||
if is_bool self n1 && is_bool self n2 then (
|
||||
Log.debugf 5 (fun k->k "(minicc.conflict.merge-true-false)");
|
||||
Log.debugf 5 (fun k->k "(mini-cc.conflict.merge-true-false)");
|
||||
self.ok <- false;
|
||||
raise E_unsat
|
||||
);
|
||||
|
|
@ -276,9 +261,14 @@ module Make(A: ARG) = struct
|
|||
|
||||
(* update root pointer in [n2.class] *)
|
||||
Node.iter_cls n2 (fun n -> n.n_root <- n1);
|
||||
|
||||
(* merge classes [next] pointers *)
|
||||
let n1_next = n1.n_next in
|
||||
n1.n_next <- n2.n_next;
|
||||
n2.n_next <- n1_next;
|
||||
)
|
||||
|
||||
let check_ok_ self =
|
||||
let[@inline] check_ok_ self =
|
||||
if not self.ok then raise_notrace E_unsat
|
||||
|
||||
(* fixpoint of the congruence closure *)
|
||||
|
|
@ -309,18 +299,17 @@ module Make(A: ARG) = struct
|
|||
Vec.push self.combine (n,n2)
|
||||
|
||||
let distinct (self:t) l =
|
||||
begin match l with
|
||||
| [] | [_] -> invalid_arg "distinct: need at least 2 terms"
|
||||
| _ -> ()
|
||||
end;
|
||||
let l = List.map (add_t self) l in
|
||||
Vec.push self.distinct (ref l)
|
||||
match l with
|
||||
| [] | [_] -> () (* trivial *)
|
||||
| _ ->
|
||||
let l = List.rev_map (add_t self) l in
|
||||
Vec.push self.distinct (ref l)
|
||||
|
||||
let check (self:t) : res =
|
||||
try fixpoint self; Sat
|
||||
let check_sat (self:t) : bool =
|
||||
try fixpoint self; true
|
||||
with E_unsat ->
|
||||
self.ok <- false;
|
||||
Unsat
|
||||
false
|
||||
|
||||
let classes self : _ Iter.t =
|
||||
T_tbl.values self.tbl
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
(** {1 Mini congruence closure}
|
||||
|
||||
This implementation is as simple as possible, and doesn't provide
|
||||
|
|
@ -6,10 +5,6 @@
|
|||
It just decides the satisfiability of a set of (dis)equations.
|
||||
*)
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module CC_view = Sidekick_core.CC_view
|
||||
|
||||
module type ARG = sig
|
||||
|
|
@ -33,7 +28,9 @@ module type S = sig
|
|||
val distinct : t -> term list -> unit
|
||||
(** [distinct cc l] asserts that all terms in [l] are distinct *)
|
||||
|
||||
val check : t -> res
|
||||
val check_sat : t -> bool
|
||||
(** [check_sat cc] returns [true] if the current state is satisfiable, [false]
|
||||
if it's unsatisfiable *)
|
||||
|
||||
val classes : t -> term Iter.t Iter.t
|
||||
(** Traverse the set of classes in the congruence closure.
|
||||
|
|
@ -133,7 +133,6 @@ module Make(A : ARG)
|
|||
module Expl = CC.Expl
|
||||
|
||||
type proof = A.Proof.t
|
||||
type conflict = lit list
|
||||
|
||||
let[@inline] cc (t:t) = Lazy.force t.cc
|
||||
let[@inline] tst t = t.tst
|
||||
|
|
@ -210,6 +209,7 @@ module Make(A : ARG)
|
|||
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
|
||||
let on_cc_new_term self f = CC.on_new_term (cc self) f
|
||||
let on_cc_merge self f = CC.on_merge (cc self) f
|
||||
let on_cc_conflict self f = CC.on_conflict (cc self) f
|
||||
|
||||
let cc_add_term self t = CC.add_term (cc self) t
|
||||
let cc_find self n = CC.find (cc self) n
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
open Sidekick_base_term
|
||||
|
||||
[@@@ocaml.warning "-32"]
|
||||
|
||||
type 'a or_error = ('a, string) CCResult.t
|
||||
|
||||
module E = CCResult
|
||||
|
|
@ -391,16 +393,50 @@ let conv_ty = Conv.conv_ty
|
|||
let conv_term = Conv.conv_term
|
||||
|
||||
(* instantiate solver here *)
|
||||
module Solver = Sidekick_msat_solver.Make(struct
|
||||
include Sidekick_base_term
|
||||
module Solver_arg = struct
|
||||
include Sidekick_base_term
|
||||
|
||||
let cc_view = Term.cc_view
|
||||
module Proof = struct
|
||||
type t = Default
|
||||
let default=Default
|
||||
let pp out _ = Fmt.string out "default"
|
||||
end
|
||||
end)
|
||||
let cc_view = Term.cc_view
|
||||
module Proof = struct
|
||||
type t = Default
|
||||
let default=Default
|
||||
let pp out _ = Fmt.string out "default"
|
||||
end
|
||||
end
|
||||
module Solver = Sidekick_msat_solver.Make(Solver_arg)
|
||||
|
||||
module Check_cc = struct
|
||||
module SI = Solver.Solver_internal
|
||||
module CC = Solver.Solver_internal.CC
|
||||
module MCC = Sidekick_mini_cc.Make(Solver_arg)
|
||||
|
||||
let pp_c out c = Fmt.fprintf out "(@[%a@])" (Util.pp_list ~sep:" ∨ " Lit.pp) c
|
||||
|
||||
let add_cc_lit (cc:MCC.t) (lit:Lit.t) : unit =
|
||||
let t = Lit.term lit in
|
||||
MCC.add_lit cc t (Lit.sign lit)
|
||||
|
||||
(* check that this is a proper CC conflict *)
|
||||
let check_conflict si _cc (confl:Lit.t list) : unit =
|
||||
Log.debugf 15 (fun k->k "(@[check-cc-conflict@ %a@])" pp_c confl);
|
||||
let tst = SI.tst si in
|
||||
let cc = MCC.create tst in
|
||||
(* add [¬confl] and check it's unsat *)
|
||||
List.iter (fun lit -> add_cc_lit cc @@ Lit.neg lit) confl;
|
||||
if MCC.check_sat cc then (
|
||||
Error.errorf "@[<2>check-cc-conflict:@ @[clause %a@]@ \
|
||||
is not a UF conflict (negation is sat)@]" pp_c confl
|
||||
) else (
|
||||
Log.debugf 15 (fun k->k "(@[check-cc-conflict.ok@ %a@])" pp_c confl);
|
||||
)
|
||||
|
||||
let theory =
|
||||
Solver.mk_theory ~name:"cc-check"
|
||||
~create_and_setup:(fun si ->
|
||||
Solver.Solver_internal.on_cc_conflict si (check_conflict si))
|
||||
()
|
||||
|
||||
end
|
||||
|
||||
(* TODO
|
||||
(* check SMT model *)
|
||||
|
|
@ -501,7 +537,7 @@ let solve
|
|||
(* process a single statement *)
|
||||
let process_stmt
|
||||
?hyps
|
||||
?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?check
|
||||
?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?(check=false)
|
||||
?time ?memory ?progress
|
||||
(solver:Solver.t)
|
||||
(stmt:Ast.statement) : unit or_error =
|
||||
|
|
@ -532,7 +568,7 @@ let process_stmt
|
|||
raise Exit
|
||||
| A.CheckSat ->
|
||||
solve
|
||||
?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress
|
||||
?gc ?restarts ?dot_proof ~check ?pp_model ?time ?memory ?progress
|
||||
~assumptions:[] ?hyps
|
||||
solver;
|
||||
E.return()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ val conv_ty : Ast.Ty.t -> Ty.t
|
|||
|
||||
val conv_term : Term.state -> Ast.term -> Term.t
|
||||
|
||||
module Check_cc : sig
|
||||
(** theory that check validity of conflicts *)
|
||||
val theory : Solver.theory
|
||||
end
|
||||
|
||||
val process_stmt :
|
||||
?hyps:Lit.t list Vec.t ->
|
||||
?gc:bool ->
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
(public_name sidekick.smtlib)
|
||||
(libraries containers zarith msat sidekick.core sidekick.util
|
||||
sidekick.msat-solver sidekick.base-term sidekick.th-bool-static
|
||||
msat.backend)
|
||||
sidekick.mini-cc msat.backend)
|
||||
(flags :standard -open Sidekick_util))
|
||||
|
||||
(menhir
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue