refactor: in msat-solver, adapt to new proofs

This commit is contained in:
Simon Cruanes 2021-08-17 23:59:43 -04:00
parent 7bead748a6
commit 6800b44b1c

View file

@ -11,8 +11,12 @@
module type ARG = sig
open Sidekick_core
module T : TERM
module Lit : LIT with module T = T
type proof
module P : PROOF with type term = T.Term.t and type t = proof
module P : PROOF
with type term = T.Term.t
and type t = proof
and type lit = Lit.t
val cc_view : T.Term.t -> (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
@ -28,61 +32,28 @@ module Make(A : ARG)
: S
with module T = A.T
and type proof = A.proof
and module Lit = A.Lit
and module P = A.P
= struct
module T = A.T
module P = A.P
module Ty = T.Ty
module Term = T.Term
module Lit = A.Lit
type term = Term.t
type ty = Ty.t
type proof = P.t
module Lit_ = struct
module T = T
type t = {
lit_term: term;
lit_sign : bool
}
let[@inline] neg l = {l with lit_sign=not l.lit_sign}
let[@inline] sign t = t.lit_sign
let[@inline] abs t = {t with lit_sign=true}
let[@inline] term (t:t): term = t.lit_term
let[@inline] signed_term t = term t, sign t
let make ~sign t = {lit_sign=sign; lit_term=t}
let atom tst ?(sign=true) (t:term) : t =
let t, sign' = Term.abs tst t in
let sign = if not sign' then not sign else sign in
make ~sign t
let equal a b =
a.lit_sign = b.lit_sign &&
Term.equal a.lit_term b.lit_term
let hash a =
let sign = a.lit_sign in
CCHash.combine3 2 (CCHash.bool sign) (Term.hash a.lit_term)
let pp out l =
if l.lit_sign then Term.pp out l.lit_term
else Format.fprintf out "(@[@<1>¬@ %a@])" Term.pp l.lit_term
let norm_sign l = if l.lit_sign then l, true else neg l, false
end
type lit = Lit_.t
type proof = A.proof
type dproof = proof -> unit
type lit = Lit.t
(* actions from msat *)
type msat_acts = (lit, P.t) Sidekick_sat.acts
type msat_acts = (lit, proof) Sidekick_sat.acts
(* the full argument to the congruence closure *)
module CC_actions = struct
module T = T
module P = P
module Lit = Lit_
module Lit = Lit
type nonrec proof = proof
let cc_view = A.cc_view
@ -90,10 +61,12 @@ module Make(A : ARG)
module T = T
module P = P
module Lit = Lit
type nonrec proof = proof
type dproof = proof -> unit
type t = msat_acts
let[@inline] raise_conflict (a:t) lits pr =
let[@inline] raise_conflict (a:t) lits (dp:dproof) =
let (module A) = a in
A.raise_conflict lits pr
A.raise_conflict lits dp
let[@inline] propagate (a:t) lit ~reason =
let (module A) = a in
let reason = Sidekick_sat.Consequence reason in
@ -109,9 +82,12 @@ module Make(A : ARG)
module Solver_internal = struct
module T = T
module P = P
module Lit = Lit_
module Lit = Lit
module CC = CC
module N = CC.N
type formula = Lit.t
type nonrec proof = proof
type dproof = proof -> unit
type term = Term.t
type ty = Ty.t
type lit = Lit.t
@ -136,7 +112,7 @@ module Make(A : ARG)
mutable hooks: hook list;
cache: Term.t Term.Tbl.t;
}
and hook = t -> term -> (term * P.t) option
and hook = t -> term -> (term * dproof) option
let create tst ty_st : t =
{tst; ty_st; hooks=[]; cache=Term.Tbl.create 32;}
@ -145,8 +121,8 @@ module Make(A : ARG)
let add_hook self f = self.hooks <- f :: self.hooks
let clear self = Term.Tbl.clear self.cache
let normalize (self:t) (t:Term.t) : (Term.t * P.t) option =
let sub_proofs_ = ref [] in
let normalize (self:t) (t:Term.t) : (Term.t * dproof) option =
let sub_proofs_: dproof list ref = ref [] in
(* compute and cache normal form of [t] *)
let rec aux t : Term.t =
@ -172,15 +148,22 @@ module Make(A : ARG)
let u = aux t in
if Term.equal t u then None
else (
(* proof: [sub_proofs |- t=u] by CC *)
let pr = P.cc_imply_l !sub_proofs_ t u in
Some (u, pr)
(* proof: [sub_proofs |- t=u] by CC + subproof *)
let emit_proof p =
if not (T.Term.equal t u) then (
P.begin_subproof p;
List.iter (fun dp -> dp p) !sub_proofs_;
P.lemma_preprocess p t u;
P.end_subproof p;
)
in
Some (u, emit_proof)
)
let normalize_t self t =
match normalize self t with
| None -> t, P.refl t
| Some (u,pr) -> u, pr
| None -> t, (fun _ -> ())
end
type simplify_hook = Simplify.hook
@ -188,6 +171,7 @@ module Make(A : ARG)
tst: Term.store; (** state for managing terms *)
ty_st: Ty.store;
cc: CC.t lazy_t; (** congruence closure *)
proof: proof; (** proof logger *)
stat: Stat.t;
count_axiom: int Stat.counter;
count_preprocess_clause: int Stat.counter;
@ -197,7 +181,7 @@ module Make(A : ARG)
simp: Simplify.t;
mutable preprocess: preprocess_hook list;
mutable mk_model: model_hook list;
preprocess_cache: (Term.t * P.t list) Term.Tbl.t;
preprocess_cache: (Term.t * dproof list) Term.Tbl.t;
mutable t_defs : (term*term) list; (* term definitions *)
mutable th_states : th_states; (** Set of theories *)
mutable on_partial_check: (t -> actions -> lit Iter.t -> unit) list;
@ -208,8 +192,8 @@ module Make(A : ARG)
and preprocess_hook =
t ->
mk_lit:(term -> lit) ->
add_clause:(lit list -> P.t -> unit) ->
term -> (term * P.t) option
add_clause:(lit list -> dproof -> unit) ->
term -> (term * dproof) option
and model_hook =
recurse:(t -> CC.N.t -> term) ->
@ -220,13 +204,12 @@ module Make(A : ARG)
module Formula = struct
include Lit
let norm lit =
let lit', sign = norm_sign lit in
let lit', sign = Lit.norm_sign lit in
lit', if sign then Sidekick_sat.Same_sign else Sidekick_sat.Negated
end
module Eq_class = CC.N
module Expl = CC.Expl
type proof = P.t
module Proof = P
let[@inline] cc (t:t) = Lazy.force t.cc
let[@inline] tst t = t.tst
@ -238,7 +221,7 @@ module Make(A : ARG)
let simplifier self = self.simp
let simplify_t self (t:Term.t) : _ option = Simplify.normalize self.simp t
let simp_t self (t:Term.t) : Term.t * P.t = Simplify.normalize_t self.simp t
let simp_t self (t:Term.t) : Term.t * dproof = Simplify.normalize_t self.simp t
let add_simplifier (self:t) f : unit = Simplify.add_hook self.simp f
@ -263,23 +246,23 @@ module Make(A : ARG)
let[@inline] propagate_l self acts p cs proof : unit =
propagate self acts p ~reason:(fun()->cs,proof)
let add_sat_clause_ self (acts:actions) ~keep lits (proof:P.t) : unit =
let add_sat_clause_ self (acts:actions) ~keep lits (proof:dproof) : unit =
let (module A) = acts in
Stat.incr self.count_axiom;
A.add_clause ~keep lits proof
let preprocess_term_ (self:t) ~add_clause (t:term) : term * proof =
let preprocess_term_ (self:t) ~add_clause (t:term) : term * dproof =
let mk_lit t = Lit.atom self.tst t in (* no further simplification *)
(* compute and cache normal form [u] of [t].
Also cache a list of proofs [ps] such
that [ps |- t=u] by CC. *)
let rec aux t : term * proof list =
let rec aux t : term * dproof list =
match Term.Tbl.find self.preprocess_cache t with
| u, ps ->
u, ps
| exception Not_found ->
let sub_p: P.t list ref = ref [] in
let sub_p: _ list ref = ref [] in
(* try rewrite at root *)
let t1 = aux_rec ~sub_p t self.preprocess in
@ -338,32 +321,38 @@ module Make(A : ARG)
let u, ps_t1_u = aux t1 in
let pr_t_u =
let emit_proof_t_eq_u =
if t != u then (
let hyps =
if t == t1 then ps_t1_u
else p_t_t1 :: ps_t1_u in
P.cc_imply_l hyps t u
) else P.refl u
let emit_proof p =
P.begin_subproof p;
List.iter (fun dp -> dp p) hyps;
P.lemma_preprocess p t u;
P.end_subproof p;
in
emit_proof
) else (fun _->())
in
u, pr_t_u
u, emit_proof_t_eq_u
(* return preprocessed lit + proof they are equal *)
let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit * proof =
let preprocess_lit_ (self:t) ~add_clause (lit:lit) : lit * dproof =
let t, p = Lit.term lit |> preprocess_term_ self ~add_clause in
let lit' = Lit.atom self.tst ~sign:(Lit.sign lit) t in
if not (Lit.equal lit lit') then (
Log.debugf 10
(fun k->k "(@[msat-solver.preprocess.lit@ :lit %a@ :into %a@ :proof %a@])"
Lit.pp lit Lit.pp lit' (P.pp_debug ~sharing:false) p);
(fun k->k "(@[msat-solver.preprocess.lit@ :lit %a@ :into %a@])"
Lit.pp lit Lit.pp lit');
);
lit', p
(* add a clause using [acts] *)
let add_clause_ self acts lits (proof:P.t) : unit =
let add_clause_ self acts lits (proof:dproof) : unit =
Stat.incr self.count_preprocess_clause;
add_sat_clause_ self acts ~keep:true lits proof
@ -375,13 +364,13 @@ module Make(A : ARG)
in
lit
let[@inline] preprocess_term self ~add_clause (t:term) : term * proof =
let[@inline] preprocess_term self ~add_clause (t:term) : term * dproof =
preprocess_term_ self ~add_clause t
let[@inline] add_clause_temp self acts lits (proof:P.t) : unit =
let[@inline] add_clause_temp self acts lits (proof:dproof) : unit =
add_sat_clause_ self acts ~keep:false lits proof
let[@inline] add_clause_permanent self acts lits (proof:P.t) : unit =
let[@inline] add_clause_permanent self acts lits (proof:dproof) : unit =
add_sat_clause_ self acts ~keep:true lits proof
let[@inline] add_lit _self (acts:actions) lit : unit =
@ -487,7 +476,7 @@ module Make(A : ARG)
let[@inline] final_check (self:t) (acts:_ Sidekick_sat.acts) : unit =
check_ ~final:true self acts
let create ~stat (tst:Term.store) (ty_st:Ty.store) () : t =
let create ~stat ~proof (tst:Term.store) (ty_st:Ty.store) () : t =
let rec self = {
tst;
ty_st;
@ -495,6 +484,7 @@ module Make(A : ARG)
(* lazily tie the knot *)
CC.create ~size:`Big self.tst;
);
proof;
th_states=Ths_nil;
stat;
simp=Simplify.create tst ty_st;
@ -514,7 +504,6 @@ module Make(A : ARG)
ignore (Lazy.force @@ self.cc : CC.t);
self
end
module Lit = Solver_internal.Lit
(** the parametrized SAT Solver *)
module Sat_solver = Sidekick_sat.Make_cdcl_t(Solver_internal)
@ -704,12 +693,12 @@ module Make(A : ARG)
let add_theory_l self = List.iter (add_theory self)
(* create a new solver *)
let create ?(stat=Stat.global) ?size ?store_proof ~theories tst ty_st () : t =
let create ?(stat=Stat.global) ?size ~proof ~theories tst ty_st () : t =
Log.debug 5 "msat-solver.create";
let si = Solver_internal.create ~stat tst ty_st () in
let si = Solver_internal.create ~stat ~proof tst ty_st () in
let self = {
si;
solver=Sat_solver.create ?store_proof ?size si;
solver=Sat_solver.create ~proof ?size si;
stat;
count_clause=Stat.mk_int stat "solver.add-clause";
count_solve=Stat.mk_int stat "solver.solve";
@ -718,9 +707,10 @@ module Make(A : ARG)
(* assert [true] and [not false] *)
begin
let tst = Solver_internal.tst self.si in
let t_true = Term.bool tst true in
Sat_solver.assume self.solver [
[Lit.atom tst @@ Term.bool tst true];
] P.true_is_true
[Lit.atom tst t_true];
] (fun p -> P.lemma_true p t_true)
end;
self
@ -756,12 +746,12 @@ module Make(A : ARG)
CC.set_as_lit cc (CC.add_term cc sub ) (Sat_solver.Atom.formula store atom);
())
let rec mk_atom_lit self lit : Atom.t * P.t =
let rec mk_atom_lit self lit : Atom.t * dproof =
let lit, proof = preprocess_lit_ self lit in
add_bool_subterms_ self (Lit.term lit);
Sat_solver.make_atom self.solver lit, proof
and preprocess_lit_ self lit : Lit.t * P.t =
and preprocess_lit_ self lit : Lit.t * dproof =
Solver_internal.preprocess_lit_
~add_clause:(fun lits proof ->
(* recursively add these sub-literals, so they're also properly processed *)
@ -771,22 +761,17 @@ module Make(A : ARG)
List.map
(fun lit ->
let a, pr = mk_atom_lit self lit in
if not (P.is_trivial_refl pr) then (
(* FIXME if not (P.is_trivial_refl pr) then ( *)
pr_l := pr :: !pr_l;
);
(* ); *)
a)
lits
in
(* do paramodulation if needed *)
let proof =
if !pr_l=[] then proof
else P.(hres_l proof (List.rev_map p1 !pr_l))
in
let proof = P.nn proof in (* normalize lits *)
Sat_solver.add_clause self.solver atoms proof)
let emit_proof p = List.iter (fun dp -> dp p) !pr_l; in
Sat_solver.add_clause self.solver atoms emit_proof)
self.si lit
let[@inline] mk_atom_t self ?sign t : Atom.t * P.t =
let[@inline] mk_atom_t self ?sign t : Atom.t * dproof =
let lit = Lit.atom (tst self) ?sign t in
mk_atom_lit self lit
@ -832,7 +817,6 @@ module Make(A : ARG)
type res =
| Sat of Model.t
| Unsat of {
proof: Pre_proof.t option lazy_t;
unsat_core: Atom.t list lazy_t;
}
| Unknown of Unknown.t
@ -843,12 +827,12 @@ module Make(A : ARG)
let pp_stats out (self:t) : unit =
Stat.pp_all out (Stat.all @@ stats self)
let add_clause (self:t) (c:Atom.t IArray.t) (proof:P.t) : unit =
let add_clause (self:t) (c:Atom.t IArray.t) (proof:dproof) : unit =
Stat.incr self.count_clause;
Log.debugf 50 (fun k->
let store = Sat_solver.store self.solver in
k "(@[solver.add-clause@ %a@ :proof %a@])"
(Util.pp_iarray (Sat_solver.Atom.pp store)) c (P.pp_debug ~sharing:false) proof);
k "(@[solver.add-clause@ %a@])"
(Util.pp_iarray (Sat_solver.Atom.pp store)) c);
let pb = Profile.begin_ "add-clause" in
Sat_solver.add_clause_a self.solver (c:> Atom.t array) proof;
Profile.exit pb
@ -856,9 +840,13 @@ module Make(A : ARG)
let add_clause_l self c p = add_clause self (IArray.of_list c) p
let assert_terms self c =
let p = P.assertion_c_l (List.map P.lit_a c) in
let c = CCList.map (mk_atom_t' self) c in
add_clause_l self c p
let c = CCList.map (fun t -> Lit.atom (tst self) t) c in
let emit_proof p =
P.emit_input_clause p (Iter.of_list c)
in
(* FIXME: just emit proofs on the fly? *)
let c = CCList.map (mk_atom_lit' self) c in
add_clause_l self c emit_proof
let assert_term self t = assert_terms self [t]
@ -872,7 +860,8 @@ module Make(A : ARG)
(* first, add all literals to the model using the given propositional model
[lits]. *)
lits
(fun {Lit.lit_term=t;lit_sign=sign} ->
(fun lit ->
let t, sign = Lit.signed_term lit in
M.replace model t (Term.bool tst sign));
(* compute a value for [n]. *)
@ -938,17 +927,9 @@ module Make(A : ARG)
do_on_exit ();
Sat m
| Sat_solver.Unsat (module UNSAT) ->
let proof = lazy (
try
let pr = UNSAT.get_proof () in
let store = Sat_solver.store self.solver in
if check then Sat_solver.Proof.check store pr;
Some (Pre_proof.make self.solver pr (List.rev self.si.t_defs))
with Sidekick_sat.Solver_intf.No_proof -> None
) in
let unsat_core = lazy (UNSAT.unsat_assumptions ()) in
do_on_exit ();
Unsat {proof; unsat_core}
Unsat {unsat_core}
let mk_theory (type st)
~name ~create_and_setup