diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index ff04c1e9..997f66fe 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -66,7 +66,6 @@ module Make(CC_A: ARG) = struct | E_and of explanation * explanation type repr = node - type conflict = lit list module N = struct type t = node @@ -213,6 +212,7 @@ module Make(CC_A: ARG) = struct undo: (unit -> unit) Backtrack_stack.t; mutable on_merge: ev_on_merge list; mutable on_new_term: ev_on_new_term list; + mutable on_conflict: ev_on_conflict list; mutable ps_lits: lit list; (* TODO: thread it around instead? *) (* proof state *) ps_queue: (node*node) Vec.t; @@ -232,6 +232,7 @@ module Make(CC_A: ARG) = struct and ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit and ev_on_new_term = t -> N.t -> term -> unit + and ev_on_conflict = t -> lit list -> unit let[@inline] size_ (r:repr) = r.n_size let[@inline] true_ cc = Lazy.force cc.true_ @@ -337,11 +338,12 @@ module Make(CC_A: ARG) = struct n.n_expl <- FL_none; end - let raise_conflict (cc:t) (acts:actions) (e:conflict) : _ = + let raise_conflict (cc:t) (acts:actions) (e:lit list) : _ = (* clear tasks queue *) Vec.iter (N.set_field field_is_pending false) cc.pending; Vec.clear cc.pending; Vec.clear cc.combine; + List.iter (fun f -> f cc e) cc.on_conflict; Stat.incr cc.count_conflict; CC_A.Actions.raise_conflict acts e A.Proof.default @@ -813,9 +815,10 @@ module Make(CC_A: ARG) = struct let on_merge cc f = cc.on_merge <- f :: cc.on_merge let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term + let on_conflict cc f = cc.on_conflict <- f :: cc.on_conflict let create ?(stat=Stat.global) - ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) + ?(on_merge=[]) ?(on_new_term=[]) ?(on_conflict=[]) ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in let rec cc = { @@ -824,6 +827,7 @@ module Make(CC_A: ARG) = struct signatures_tbl = Sig_tbl.create size; on_merge; on_new_term; + on_conflict; pending=Vec.create(); combine=Vec.create(); ps_lits=[]; diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index fb11e67e..346ec14e 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -207,8 +207,6 @@ module type CC_S = sig type explanation = Expl.t - type conflict = lit list - (** Accessors *) val term_state : t -> term_state @@ -222,11 +220,13 @@ module type CC_S = sig type ev_on_merge = t -> actions -> N.t -> N.t -> Expl.t -> unit type ev_on_new_term = t -> N.t -> term -> unit + type ev_on_conflict = t -> lit list -> unit val create : ?stat:Stat.t -> ?on_merge:ev_on_merge list -> ?on_new_term:ev_on_new_term list -> + ?on_conflict:ev_on_conflict list -> ?size:[`Small | `Big] -> term_state -> t @@ -239,6 +239,9 @@ module type CC_S = sig val on_new_term : t -> ev_on_new_term -> unit (** Add a function to be called when a new node is created *) + val on_conflict : t -> ev_on_conflict -> unit + (** Called when the congruence closure finds a conflict *) + val set_as_lit : t -> N.t -> lit -> unit (** map the given node to a literal. *) @@ -311,10 +314,6 @@ module type SOLVER_INTERNAL = sig module Expl = CC.Expl module N = CC.N - (** Unsatisfiable conjunction. - Its negation will become a conflict clause *) - type conflict = lit list - val tst : t -> term_state val cc : t -> CC.t @@ -363,7 +362,6 @@ module type SOLVER_INTERNAL = sig (** Propagate a boolean using a unit clause. [expl => lit] must be a theory lemma, that is, a T-tautology *) - val add_clause_temp : t -> actions -> lit list -> unit (** Add local clause to the SAT solver. This clause will be removed when the solver backtracks. *) @@ -411,6 +409,9 @@ module type SOLVER_INTERNAL = sig (** Callback to add data on terms when they are added to the congruence closure *) + val on_cc_conflict : t -> (CC.t -> lit list -> unit) -> unit + (** Callback called on every CC conflict *) + val on_partial_check : t -> (t -> actions -> lit Iter.t -> unit) -> unit (** Register callbacked to be called with the slice of literals newly added on the trail. diff --git a/src/main/main.ml b/src/main/main.ml index 5c951969..5450b70a 100644 --- a/src/main/main.ml +++ b/src/main/main.ml @@ -110,6 +110,10 @@ let main () = in Process.Solver.create ~store_proof:!check ~theories tst () in + if !check then ( + (* might have to check conflicts *) + Solver.add_theory solver Process.Check_cc.theory; + ); let dot_proof = if !p_dot_proof = "" then None else Some !p_dot_proof in Sidekick_smtlib.parse !file >>= fun input -> (* process statements *) diff --git a/src/mini-cc/Mini_cc.ml b/src/mini-cc/Sidekick_mini_cc.ml similarity index 86% rename from src/mini-cc/Mini_cc.ml rename to src/mini-cc/Sidekick_mini_cc.ml index 0b89e069..18a98ca0 100644 --- a/src/mini-cc/Mini_cc.ml +++ b/src/mini-cc/Sidekick_mini_cc.ml @@ -1,8 +1,3 @@ - -type res = - | Sat - | Unsat - module CC_view = Sidekick_core.CC_view module type ARG = sig @@ -23,7 +18,7 @@ module type S = sig val add_lit : t -> term -> bool -> unit val distinct : t -> term list -> unit - val check : t -> res + val check_sat : t -> bool val classes : t -> term Iter.t Iter.t end @@ -42,9 +37,9 @@ module Make(A: ARG) = struct type node = { n_t: term; mutable n_next: node; (* next in class *) - mutable n_size: int; (* size of parent list *) + mutable n_size: int; (* size of class *) mutable n_parents: node list; - mutable n_root: node; + mutable n_root: node; (* root of the class *) } type signature = (fun_, node, node list) CC_view.t @@ -55,17 +50,16 @@ module Make(A: ARG) = struct let[@inline] hash (n:t) = T.hash n.n_t let[@inline] size (n:t) = n.n_size let[@inline] is_root n = n == n.n_root + let[@inline] root n = n.n_root let[@inline] term n = n.n_t let pp out n = T.pp out n.n_t let add_parent (self:t) ~p : unit = - self.n_parents <- p :: self.n_parents; - self.n_size <- 1 + self.n_size; - () + self.n_parents <- p :: self.n_parents let make (t:T.t) : t = let rec n = { - n_t=t; n_size=0; n_next=n; + n_t=t; n_size=1; n_next=n; n_parents=[]; n_root=n; } in n @@ -167,28 +161,19 @@ module Make(A: ARG) = struct | n -> n | exception Not_found -> let node = Node.make t in + T_tbl.add self.tbl t node; (* add sub-terms, and add [t] to their parent list *) sub_ t (fun u -> - let n_u = add_t self u in + let n_u = Node.root @@ add_t self u in Node.add_parent n_u ~p:node); - T_tbl.add self.tbl t node; (* need to compute signature *) Vec.push self.pending node; node - (* find representative *) - let[@inline] find_ (n:node) : node = - let r = n.n_root in - assert (Node.is_root r); - r - let find_t_ (self:t) (t:term): node = - let n = - try T_tbl.find self.tbl t - with Not_found -> Error.errorf "minicc.find_t: no node for %a" T.pp t - in - find_ n + try T_tbl.find self.tbl t |> Node.root + with Not_found -> Error.errorf "mini-cc.find_t: no node for %a" T.pp t (* does this list contain a duplicate? *) let has_dups (l:node list) : bool = @@ -200,7 +185,7 @@ module Make(A: ARG) = struct let check_distinct_ self : unit = Vec.iter (fun r -> - r := List.map find_ !r; + r := List.rev_map Node.root !r; if has_dups !r then raise_notrace E_unsat) self.distinct @@ -232,17 +217,17 @@ module Make(A: ARG) = struct (* reduce to [true] *) let n2 = self.true_ in Log.debugf 5 - (fun k->k "(@[minicc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2); + (fun k->k "(@[mini-cc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2); Vec.push self.combine (n,n2) ) | Some s -> - Log.debugf 5 (fun k->k "(@[minicc.update-sig@ %a@])" Signature.pp s); + Log.debugf 5 (fun k->k "(@[mini-cc.update-sig@ %a@])" Signature.pp s); match Sig_tbl.find self.sig_tbl s with | n2 when Node.equal n n2 -> () | n2 -> (* collision, merge *) Log.debugf 5 - (fun k->k "(@[minicc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2); + (fun k->k "(@[mini-cc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2); Vec.push self.combine (n,n2) | exception Not_found -> Sig_tbl.add self.sig_tbl s n @@ -251,8 +236,8 @@ module Make(A: ARG) = struct (* merge the two classes *) let merge_ self (n1,n2) : unit = - let n1 = find_ n1 in - let n2 = find_ n2 in + let n1 = Node.root n1 in + let n2 = Node.root n2 in if not @@ Node.equal n1 n2 then ( (* merge into largest class, or into a boolean *) let n1, n2 = @@ -260,10 +245,10 @@ module Make(A: ARG) = struct else if is_bool self n2 then n2, n1 else if Node.size n1 > Node.size n2 then n1, n2 else n2, n1 in - Log.debugf 5 (fun k->k "(@[minicc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2); + Log.debugf 5 (fun k->k "(@[mini-cc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2); if is_bool self n1 && is_bool self n2 then ( - Log.debugf 5 (fun k->k "(minicc.conflict.merge-true-false)"); + Log.debugf 5 (fun k->k "(mini-cc.conflict.merge-true-false)"); self.ok <- false; raise E_unsat ); @@ -276,9 +261,14 @@ module Make(A: ARG) = struct (* update root pointer in [n2.class] *) Node.iter_cls n2 (fun n -> n.n_root <- n1); + + (* merge classes [next] pointers *) + let n1_next = n1.n_next in + n1.n_next <- n2.n_next; + n2.n_next <- n1_next; ) - let check_ok_ self = + let[@inline] check_ok_ self = if not self.ok then raise_notrace E_unsat (* fixpoint of the congruence closure *) @@ -309,18 +299,17 @@ module Make(A: ARG) = struct Vec.push self.combine (n,n2) let distinct (self:t) l = - begin match l with - | [] | [_] -> invalid_arg "distinct: need at least 2 terms" - | _ -> () - end; - let l = List.map (add_t self) l in - Vec.push self.distinct (ref l) + match l with + | [] | [_] -> () (* trivial *) + | _ -> + let l = List.rev_map (add_t self) l in + Vec.push self.distinct (ref l) - let check (self:t) : res = - try fixpoint self; Sat + let check_sat (self:t) : bool = + try fixpoint self; true with E_unsat -> self.ok <- false; - Unsat + false let classes self : _ Iter.t = T_tbl.values self.tbl diff --git a/src/mini-cc/Mini_cc.mli b/src/mini-cc/Sidekick_mini_cc.mli similarity index 87% rename from src/mini-cc/Mini_cc.mli rename to src/mini-cc/Sidekick_mini_cc.mli index 53cf6781..4da3afd6 100644 --- a/src/mini-cc/Mini_cc.mli +++ b/src/mini-cc/Sidekick_mini_cc.mli @@ -1,4 +1,3 @@ - (** {1 Mini congruence closure} This implementation is as simple as possible, and doesn't provide @@ -6,10 +5,6 @@ It just decides the satisfiability of a set of (dis)equations. *) -type res = - | Sat - | Unsat - module CC_view = Sidekick_core.CC_view module type ARG = sig @@ -33,7 +28,9 @@ module type S = sig val distinct : t -> term list -> unit (** [distinct cc l] asserts that all terms in [l] are distinct *) - val check : t -> res + val check_sat : t -> bool + (** [check_sat cc] returns [true] if the current state is satisfiable, [false] + if it's unsatisfiable *) val classes : t -> term Iter.t Iter.t (** Traverse the set of classes in the congruence closure. diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index de56191c..ed7841bb 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -133,7 +133,6 @@ module Make(A : ARG) module Expl = CC.Expl type proof = A.Proof.t - type conflict = lit list let[@inline] cc (t:t) = Lazy.force t.cc let[@inline] tst t = t.tst @@ -210,6 +209,7 @@ module Make(A : ARG) let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check let on_cc_new_term self f = CC.on_new_term (cc self) f let on_cc_merge self f = CC.on_merge (cc self) f + let on_cc_conflict self f = CC.on_conflict (cc self) f let cc_add_term self t = CC.add_term (cc self) t let cc_find self n = CC.find (cc self) n diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index fa0c590b..7b3d0f8f 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -2,6 +2,8 @@ open Sidekick_base_term +[@@@ocaml.warning "-32"] + type 'a or_error = ('a, string) CCResult.t module E = CCResult @@ -391,16 +393,50 @@ let conv_ty = Conv.conv_ty let conv_term = Conv.conv_term (* instantiate solver here *) -module Solver = Sidekick_msat_solver.Make(struct - include Sidekick_base_term +module Solver_arg = struct + include Sidekick_base_term - let cc_view = Term.cc_view - module Proof = struct - type t = Default - let default=Default - let pp out _ = Fmt.string out "default" - end - end) + let cc_view = Term.cc_view + module Proof = struct + type t = Default + let default=Default + let pp out _ = Fmt.string out "default" + end +end +module Solver = Sidekick_msat_solver.Make(Solver_arg) + +module Check_cc = struct + module SI = Solver.Solver_internal + module CC = Solver.Solver_internal.CC + module MCC = Sidekick_mini_cc.Make(Solver_arg) + + let pp_c out c = Fmt.fprintf out "(@[%a@])" (Util.pp_list ~sep:" ∨ " Lit.pp) c + + let add_cc_lit (cc:MCC.t) (lit:Lit.t) : unit = + let t = Lit.term lit in + MCC.add_lit cc t (Lit.sign lit) + + (* check that this is a proper CC conflict *) + let check_conflict si _cc (confl:Lit.t list) : unit = + Log.debugf 15 (fun k->k "(@[check-cc-conflict@ %a@])" pp_c confl); + let tst = SI.tst si in + let cc = MCC.create tst in + (* add [¬confl] and check it's unsat *) + List.iter (fun lit -> add_cc_lit cc @@ Lit.neg lit) confl; + if MCC.check_sat cc then ( + Error.errorf "@[<2>check-cc-conflict:@ @[clause %a@]@ \ + is not a UF conflict (negation is sat)@]" pp_c confl + ) else ( + Log.debugf 15 (fun k->k "(@[check-cc-conflict.ok@ %a@])" pp_c confl); + ) + + let theory = + Solver.mk_theory ~name:"cc-check" + ~create_and_setup:(fun si -> + Solver.Solver_internal.on_cc_conflict si (check_conflict si)) + () + +end (* TODO (* check SMT model *) @@ -501,7 +537,7 @@ let solve (* process a single statement *) let process_stmt ?hyps - ?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?check + ?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?(check=false) ?time ?memory ?progress (solver:Solver.t) (stmt:Ast.statement) : unit or_error = @@ -532,7 +568,7 @@ let process_stmt raise Exit | A.CheckSat -> solve - ?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress + ?gc ?restarts ?dot_proof ~check ?pp_model ?time ?memory ?progress ~assumptions:[] ?hyps solver; E.return() diff --git a/src/smtlib/Process.mli b/src/smtlib/Process.mli index d4c58f5d..6d845728 100644 --- a/src/smtlib/Process.mli +++ b/src/smtlib/Process.mli @@ -18,6 +18,11 @@ val conv_ty : Ast.Ty.t -> Ty.t val conv_term : Term.state -> Ast.term -> Term.t +module Check_cc : sig + (** theory that check validity of conflicts *) + val theory : Solver.theory +end + val process_stmt : ?hyps:Lit.t list Vec.t -> ?gc:bool -> diff --git a/src/smtlib/dune b/src/smtlib/dune index c9337849..ac71b585 100644 --- a/src/smtlib/dune +++ b/src/smtlib/dune @@ -3,7 +3,7 @@ (public_name sidekick.smtlib) (libraries containers zarith msat sidekick.core sidekick.util sidekick.msat-solver sidekick.base-term sidekick.th-bool-static - msat.backend) + sidekick.mini-cc msat.backend) (flags :standard -open Sidekick_util)) (menhir