diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index 59d5a653..c3645053 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -11,8 +11,12 @@ module type ARG = sig open Sidekick_core module T : TERM + module Lit : LIT with module T = T type proof - module P : PROOF with type term = T.Term.t and type t = proof + module P : PROOF + with type term = T.Term.t + and type t = proof + and type lit = Lit.t val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t @@ -28,61 +32,28 @@ module Make(A : ARG) : S with module T = A.T and type proof = A.proof + and module Lit = A.Lit and module P = A.P = struct module T = A.T module P = A.P module Ty = T.Ty module Term = T.Term + module Lit = A.Lit type term = Term.t type ty = Ty.t - type proof = P.t - - module Lit_ = struct - module T = T - type t = { - lit_term: term; - lit_sign : bool - } - - let[@inline] neg l = {l with lit_sign=not l.lit_sign} - let[@inline] sign t = t.lit_sign - let[@inline] abs t = {t with lit_sign=true} - let[@inline] term (t:t): term = t.lit_term - let[@inline] signed_term t = term t, sign t - - let make ~sign t = {lit_sign=sign; lit_term=t} - - let atom tst ?(sign=true) (t:term) : t = - let t, sign' = Term.abs tst t in - let sign = if not sign' then not sign else sign in - make ~sign t - - let equal a b = - a.lit_sign = b.lit_sign && - Term.equal a.lit_term b.lit_term - - let hash a = - let sign = a.lit_sign in - CCHash.combine3 2 (CCHash.bool sign) (Term.hash a.lit_term) - - let pp out l = - if l.lit_sign then Term.pp out l.lit_term - else Format.fprintf out "(@[@<1>¬@ %a@])" Term.pp l.lit_term - - let norm_sign l = if l.lit_sign then l, true else neg l, false - end - - type lit = Lit_.t + type proof = A.proof + type dproof = proof -> unit + type lit = Lit.t (* actions from msat *) - type msat_acts = (lit, P.t) Sidekick_sat.acts + type msat_acts = (lit, proof) Sidekick_sat.acts (* the full argument to the congruence closure *) module CC_actions = struct module T = T module P = P - module Lit = Lit_ + module Lit = Lit type nonrec proof = proof let cc_view = A.cc_view @@ -90,10 +61,12 @@ module Make(A : ARG) module T = T module P = P module Lit = Lit + type nonrec proof = proof + type dproof = proof -> unit type t = msat_acts - let[@inline] raise_conflict (a:t) lits pr = + let[@inline] raise_conflict (a:t) lits (dp:dproof) = let (module A) = a in - A.raise_conflict lits pr + A.raise_conflict lits dp let[@inline] propagate (a:t) lit ~reason = let (module A) = a in let reason = Sidekick_sat.Consequence reason in @@ -109,9 +82,12 @@ module Make(A : ARG) module Solver_internal = struct module T = T module P = P - module Lit = Lit_ + module Lit = Lit module CC = CC module N = CC.N + type formula = Lit.t + type nonrec proof = proof + type dproof = proof -> unit type term = Term.t type ty = Ty.t type lit = Lit.t @@ -136,7 +112,7 @@ module Make(A : ARG) mutable hooks: hook list; cache: Term.t Term.Tbl.t; } - and hook = t -> term -> (term * P.t) option + and hook = t -> term -> (term * dproof) option let create tst ty_st : t = {tst; ty_st; hooks=[]; cache=Term.Tbl.create 32;} @@ -145,8 +121,8 @@ module Make(A : ARG) let add_hook self f = self.hooks <- f :: self.hooks let clear self = Term.Tbl.clear self.cache - let normalize (self:t) (t:Term.t) : (Term.t * P.t) option = - let sub_proofs_ = ref [] in + let normalize (self:t) (t:Term.t) : (Term.t * dproof) option = + let sub_proofs_: dproof list ref = ref [] in (* compute and cache normal form of [t] *) let rec aux t : Term.t = @@ -172,15 +148,22 @@ module Make(A : ARG) let u = aux t in if Term.equal t u then None else ( - (* proof: [sub_proofs |- t=u] by CC *) - let pr = P.cc_imply_l !sub_proofs_ t u in - Some (u, pr) + (* proof: [sub_proofs |- t=u] by CC + subproof *) + let emit_proof p = + if not (T.Term.equal t u) then ( + P.begin_subproof p; + List.iter (fun dp -> dp p) !sub_proofs_; + P.lemma_preprocess p t u; + P.end_subproof p; + ) + in + Some (u, emit_proof) ) let normalize_t self t = match normalize self t with - | None -> t, P.refl t | Some (u,pr) -> u, pr + | None -> t, (fun _ -> ()) end type simplify_hook = Simplify.hook @@ -188,6 +171,7 @@ module Make(A : ARG) tst: Term.store; (** state for managing terms *) ty_st: Ty.store; cc: CC.t lazy_t; (** congruence closure *) + proof: proof; (** proof logger *) stat: Stat.t; count_axiom: int Stat.counter; count_preprocess_clause: int Stat.counter; @@ -197,7 +181,7 @@ module Make(A : ARG) simp: Simplify.t; mutable preprocess: preprocess_hook list; mutable mk_model: model_hook list; - preprocess_cache: (Term.t * P.t list) Term.Tbl.t; + preprocess_cache: (Term.t * dproof list) Term.Tbl.t; mutable t_defs : (term*term) list; (* term definitions *) mutable th_states : th_states; (** Set of theories *) mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list; @@ -208,8 +192,8 @@ module Make(A : ARG) and preprocess_hook = t -> mk_lit:(term -> lit) -> - add_clause:(lit list -> P.t -> unit) -> - term -> (term * P.t) option + add_clause:(lit list -> dproof -> unit) -> + term -> (term * dproof) option and model_hook = recurse:(t -> CC.N.t -> term) -> @@ -220,13 +204,12 @@ module Make(A : ARG) module Formula = struct include Lit let norm lit = - let lit', sign = norm_sign lit in + let lit', sign = Lit.norm_sign lit in lit', if sign then Sidekick_sat.Same_sign else Sidekick_sat.Negated end module Eq_class = CC.N module Expl = CC.Expl - - type proof = P.t + module Proof = P let[@inline] cc (t:t) = Lazy.force t.cc let[@inline] tst t = t.tst @@ -238,7 +221,7 @@ module Make(A : ARG) let simplifier self = self.simp let simplify_t self (t:Term.t) : _ option = Simplify.normalize self.simp t - let simp_t self (t:Term.t) : Term.t * P.t = Simplify.normalize_t self.simp t + let simp_t self (t:Term.t) : Term.t * dproof = Simplify.normalize_t self.simp t let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f @@ -263,23 +246,23 @@ module Make(A : ARG) let[@inline] propagate_l self acts p cs proof : unit = propagate self acts p ~reason:(fun()->cs,proof) - let add_sat_clause_ self (acts:actions) ~keep lits (proof:P.t) : unit = + let add_sat_clause_ self (acts:actions) ~keep lits (proof:dproof) : unit = let (module A) = acts in Stat.incr self.count_axiom; A.add_clause ~keep lits proof - let preprocess_term_ (self:t) ~add_clause (t:term) : term * proof = + let preprocess_term_ (self:t) ~add_clause (t:term) : term * dproof = let mk_lit t = Lit.atom self.tst t in (* no further simplification *) (* compute and cache normal form [u] of [t]. Also cache a list of proofs [ps] such that [ps |- t=u] by CC. *) - let rec aux t : term * proof list = + let rec aux t : term * dproof list = match Term.Tbl.find self.preprocess_cache t with | u, ps -> u, ps | exception Not_found -> - let sub_p: P.t list ref = ref [] in + let sub_p: _ list ref = ref [] in (* try rewrite at root *) let t1 = aux_rec ~sub_p t self.preprocess in @@ -338,32 +321,38 @@ module Make(A : ARG) let u, ps_t1_u = aux t1 in - let pr_t_u = + let emit_proof_t_eq_u = if t != u then ( let hyps = if t == t1 then ps_t1_u else p_t_t1 :: ps_t1_u in - P.cc_imply_l hyps t u - ) else P.refl u + let emit_proof p = + P.begin_subproof p; + List.iter (fun dp -> dp p) hyps; + P.lemma_preprocess p t u; + P.end_subproof p; + in + emit_proof + ) else (fun _->()) in - u, pr_t_u + u, emit_proof_t_eq_u (* return preprocessed lit + proof they are equal *) - let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit * proof = + let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit * dproof = let t, p = Lit.term lit |> preprocess_term_ self ~add_clause in let lit' = Lit.atom self.tst ~sign:(Lit.sign lit) t in if not (Lit.equal lit lit') then ( Log.debugf 10 - (fun k->k "(@[msat-solver.preprocess.lit@ :lit %a@ :into %a@ :proof %a@])" - Lit.pp lit Lit.pp lit' (P.pp_debug ~sharing:false) p); + (fun k->k "(@[msat-solver.preprocess.lit@ :lit %a@ :into %a@])" + Lit.pp lit Lit.pp lit'); ); lit', p (* add a clause using [acts] *) - let add_clause_ self acts lits (proof:P.t) : unit = + let add_clause_ self acts lits (proof:dproof) : unit = Stat.incr self.count_preprocess_clause; add_sat_clause_ self acts ~keep:true lits proof @@ -375,13 +364,13 @@ module Make(A : ARG) in lit - let[@inline] preprocess_term self ~add_clause (t:term) : term * proof = + let[@inline] preprocess_term self ~add_clause (t:term) : term * dproof = preprocess_term_ self ~add_clause t - let[@inline] add_clause_temp self acts lits (proof:P.t) : unit = + let[@inline] add_clause_temp self acts lits (proof:dproof) : unit = add_sat_clause_ self acts ~keep:false lits proof - let[@inline] add_clause_permanent self acts lits (proof:P.t) : unit = + let[@inline] add_clause_permanent self acts lits (proof:dproof) : unit = add_sat_clause_ self acts ~keep:true lits proof let[@inline] add_lit _self (acts:actions) lit : unit = @@ -487,7 +476,7 @@ module Make(A : ARG) let[@inline] final_check (self:t) (acts:_ Sidekick_sat.acts) : unit = check_ ~final:true self acts - let create ~stat (tst:Term.store) (ty_st:Ty.store) () : t = + let create ~stat ~proof (tst:Term.store) (ty_st:Ty.store) () : t = let rec self = { tst; ty_st; @@ -495,6 +484,7 @@ module Make(A : ARG) (* lazily tie the knot *) CC.create ~size:`Big self.tst; ); + proof; th_states=Ths_nil; stat; simp=Simplify.create tst ty_st; @@ -514,7 +504,6 @@ module Make(A : ARG) ignore (Lazy.force @@ self.cc : CC.t); self end - module Lit = Solver_internal.Lit (** the parametrized SAT Solver *) module Sat_solver = Sidekick_sat.Make_cdcl_t(Solver_internal) @@ -704,12 +693,12 @@ module Make(A : ARG) let add_theory_l self = List.iter (add_theory self) (* create a new solver *) - let create ?(stat=Stat.global) ?size ?store_proof ~theories tst ty_st () : t = + let create ?(stat=Stat.global) ?size ~proof ~theories tst ty_st () : t = Log.debug 5 "msat-solver.create"; - let si = Solver_internal.create ~stat tst ty_st () in + let si = Solver_internal.create ~stat ~proof tst ty_st () in let self = { si; - solver=Sat_solver.create ?store_proof ?size si; + solver=Sat_solver.create ~proof ?size si; stat; count_clause=Stat.mk_int stat "solver.add-clause"; count_solve=Stat.mk_int stat "solver.solve"; @@ -718,9 +707,10 @@ module Make(A : ARG) (* assert [true] and [not false] *) begin let tst = Solver_internal.tst self.si in + let t_true = Term.bool tst true in Sat_solver.assume self.solver [ - [Lit.atom tst @@ Term.bool tst true]; - ] P.true_is_true + [Lit.atom tst t_true]; + ] (fun p -> P.lemma_true p t_true) end; self @@ -756,12 +746,12 @@ module Make(A : ARG) CC.set_as_lit cc (CC.add_term cc sub ) (Sat_solver.Atom.formula store atom); ()) - let rec mk_atom_lit self lit : Atom.t * P.t = + let rec mk_atom_lit self lit : Atom.t * dproof = let lit, proof = preprocess_lit_ self lit in add_bool_subterms_ self (Lit.term lit); Sat_solver.make_atom self.solver lit, proof - and preprocess_lit_ self lit : Lit.t * P.t = + and preprocess_lit_ self lit : Lit.t * dproof = Solver_internal.preprocess_lit_ ~add_clause:(fun lits proof -> (* recursively add these sub-literals, so they're also properly processed *) @@ -771,22 +761,17 @@ module Make(A : ARG) List.map (fun lit -> let a, pr = mk_atom_lit self lit in - if not (P.is_trivial_refl pr) then ( + (* FIXME if not (P.is_trivial_refl pr) then ( *) pr_l := pr :: !pr_l; - ); + (* ); *) a) lits in - (* do paramodulation if needed *) - let proof = - if !pr_l=[] then proof - else P.(hres_l proof (List.rev_map p1 !pr_l)) - in - let proof = P.nn proof in (* normalize lits *) - Sat_solver.add_clause self.solver atoms proof) + let emit_proof p = List.iter (fun dp -> dp p) !pr_l; in + Sat_solver.add_clause self.solver atoms emit_proof) self.si lit - let[@inline] mk_atom_t self ?sign t : Atom.t * P.t = + let[@inline] mk_atom_t self ?sign t : Atom.t * dproof = let lit = Lit.atom (tst self) ?sign t in mk_atom_lit self lit @@ -832,7 +817,6 @@ module Make(A : ARG) type res = | Sat of Model.t | Unsat of { - proof: Pre_proof.t option lazy_t; unsat_core: Atom.t list lazy_t; } | Unknown of Unknown.t @@ -843,12 +827,12 @@ module Make(A : ARG) let pp_stats out (self:t) : unit = Stat.pp_all out (Stat.all @@ stats self) - let add_clause (self:t) (c:Atom.t IArray.t) (proof:P.t) : unit = + let add_clause (self:t) (c:Atom.t IArray.t) (proof:dproof) : unit = Stat.incr self.count_clause; Log.debugf 50 (fun k-> let store = Sat_solver.store self.solver in - k "(@[solver.add-clause@ %a@ :proof %a@])" - (Util.pp_iarray (Sat_solver.Atom.pp store)) c (P.pp_debug ~sharing:false) proof); + k "(@[solver.add-clause@ %a@])" + (Util.pp_iarray (Sat_solver.Atom.pp store)) c); let pb = Profile.begin_ "add-clause" in Sat_solver.add_clause_a self.solver (c:> Atom.t array) proof; Profile.exit pb @@ -856,9 +840,13 @@ module Make(A : ARG) let add_clause_l self c p = add_clause self (IArray.of_list c) p let assert_terms self c = - let p = P.assertion_c_l (List.map P.lit_a c) in - let c = CCList.map (mk_atom_t' self) c in - add_clause_l self c p + let c = CCList.map (fun t -> Lit.atom (tst self) t) c in + let emit_proof p = + P.emit_input_clause p (Iter.of_list c) + in + (* FIXME: just emit proofs on the fly? *) + let c = CCList.map (mk_atom_lit' self) c in + add_clause_l self c emit_proof let assert_term self t = assert_terms self [t] @@ -872,7 +860,8 @@ module Make(A : ARG) (* first, add all literals to the model using the given propositional model [lits]. *) lits - (fun {Lit.lit_term=t;lit_sign=sign} -> + (fun lit -> + let t, sign = Lit.signed_term lit in M.replace model t (Term.bool tst sign)); (* compute a value for [n]. *) @@ -938,17 +927,9 @@ module Make(A : ARG) do_on_exit (); Sat m | Sat_solver.Unsat (module UNSAT) -> - let proof = lazy ( - try - let pr = UNSAT.get_proof () in - let store = Sat_solver.store self.solver in - if check then Sat_solver.Proof.check store pr; - Some (Pre_proof.make self.solver pr (List.rev self.si.t_defs)) - with Sidekick_sat.Solver_intf.No_proof -> None - ) in let unsat_core = lazy (UNSAT.unsat_assumptions ()) in do_on_exit (); - Unsat {proof; unsat_core} + Unsat {unsat_core} let mk_theory (type st) ~name ~create_and_setup