refactor: use hyper-res steps in proofs

- accelerates proof checking significantly
- provide a way to expand hyper-res steps into individual resolutions
  (eg for the Coq backend)
This commit is contained in:
Simon Cruanes 2019-02-15 19:06:39 -06:00 committed by Guillaume Bury
parent b2cec9eaa2
commit e30c54e11b
4 changed files with 112 additions and 57 deletions

View file

@ -134,7 +134,8 @@ module Make(S : Msat.S)(A : Arg with type hyp := S.clause
let c = P.conclusion p in let c = P.conclusion p in
let () = elim_duplicate fmt clause c l in let () = elim_duplicate fmt clause c l in
clean t fmt [c] clean t fmt [c]
| P.Resolution (p1, p2, a) -> | P.Hyper_res hr ->
let (p1, p2, a) = P.res_of_hyper_res hr in
let c1 = P.conclusion p1 in let c1 = P.conclusion p1 in
let c2 = P.conclusion p2 in let c2 = P.conclusion p2 in
if resolution fmt clause c1 c2 a then clean t fmt [c1; c2] if resolution fmt clause c1 c2 a then clean t fmt [c1; c2]

View file

@ -58,10 +58,9 @@ module Make(S : Msat.S)(A : Arg with type atom := S.atom
module P = S.Proof module P = S.Proof
let node_id n = Clause.name n.P.conclusion let node_id n = Clause.name n.P.conclusion
let res_node_id n = (node_id n) ^ "_res"
let proof_id p = node_id (P.expand p) let proof_id p = node_id (P.expand p)
let res_nn_id n1 n2 = node_id n1 ^ "_" ^ node_id n2 ^ "_res"
let res_np_id n1 n2 = node_id n1 ^ "_" ^ proof_id n2 ^ "_res"
let print_clause fmt c = let print_clause fmt c =
let v = Clause.atoms c in let v = Clause.atoms c in
@ -80,9 +79,11 @@ module Make(S : Msat.S)(A : Arg with type atom := S.atom
let print_edges fmt n = let print_edges fmt n =
match P.(n.step) with match P.(n.step) with
| P.Resolution (p1, p2, _) -> | P.Hyper_res {P.hr_init; hr_steps} ->
print_edge fmt (res_node_id n) (proof_id p1); print_edge fmt (res_np_id n hr_init) (proof_id hr_init);
print_edge fmt (res_node_id n) (proof_id p2) List.iter
(fun (_,p2) -> print_edge fmt (res_np_id n p2) (proof_id p2))
hr_steps;
| _ -> () | _ -> ()
let table_options fmt color = let table_options fmt color =
@ -129,11 +130,15 @@ module Make(S : Msat.S)(A : Arg with type atom := S.atom
((fun fmt () -> (Format.fprintf fmt "%s" (node_id n))) :: ((fun fmt () -> (Format.fprintf fmt "%s" (node_id n))) ::
List.map (ttify A.print_atom) l); List.map (ttify A.print_atom) l);
print_edge fmt (node_id n) (node_id (P.expand p)) print_edge fmt (node_id n) (node_id (P.expand p))
| P.Resolution (_, _, a) -> | P.Hyper_res {P.hr_init; hr_steps} ->
print_dot_node fmt (node_id n) "GREY" P.(n.conclusion) "Resolution" "GREY" print_dot_node fmt (node_id n) "GREY" P.(n.conclusion) "Resolution" "GREY"
[(fun fmt () -> (Format.fprintf fmt "%s" (node_id n)))]; [(fun fmt () -> (Format.fprintf fmt "%s" (node_id n)))];
print_dot_res_node fmt (res_node_id n) a; print_edge fmt (node_id n) (res_np_id n hr_init);
print_edge fmt (node_id n) (res_node_id n) List.iter
(fun (a,p2) ->
print_dot_res_node fmt (res_np_id n p2) a;
print_edge fmt (node_id n) (res_np_id n p2))
hr_steps
let print_node fmt n = let print_node fmt n =
print_contents fmt n; print_contents fmt n;

View file

@ -302,6 +302,8 @@ module Make(Plugin : PLUGIN)
let debug_a out vec = let debug_a out vec =
Array.iter (fun a -> Format.fprintf out "%a@ " debug a) vec Array.iter (fun a -> Format.fprintf out "%a@ " debug a) vec
let debug_l out l =
List.iter (fun a -> Format.fprintf out "%a@ " debug a) l
module Set = Set.Make(struct type t=atom let compare=compare end) module Set = Set.Make(struct type t=atom let compare=compare end)
end end
@ -360,6 +362,7 @@ module Make(Plugin : PLUGIN)
let[@inline] equal c1 c2 = c1.cid = c2.cid let[@inline] equal c1 c2 = c1.cid = c2.cid
let[@inline] hash c = Hashtbl.hash c.cid let[@inline] hash c = Hashtbl.hash c.cid
let[@inline] atoms c = c.atoms let[@inline] atoms c = c.atoms
let[@inline] atoms_seq c = Sequence.of_array c.atoms
let[@inline] atoms_l c = Array.to_list c.atoms let[@inline] atoms_l c = Array.to_list c.atoms
let flag_attached = 0b1 let flag_attached = 0b1
@ -424,9 +427,10 @@ module Make(Plugin : PLUGIN)
let error_res_f msg = Format.kasprintf (fun s -> raise (Resolution_error s)) msg let error_res_f msg = Format.kasprintf (fun s -> raise (Resolution_error s)) msg
let[@inline] cleanup_ (a:atom) = Var.clear a.var let[@inline] clear_var_of_ (a:atom) = Var.clear a.var
(* Compute resolution of 2 clauses *) (* Compute resolution of 2 clauses.
returns [pivots, resulting_atoms] *)
let resolve (c1:clause) (c2:clause) : atom list * atom list = let resolve (c1:clause) (c2:clause) : atom list * atom list =
(* invariants: only atoms in [c2] are marked, and the pivot is (* invariants: only atoms in [c2] are marked, and the pivot is
cleared when traversing [c1] *) cleared when traversing [c1] *)
@ -438,7 +442,7 @@ module Make(Plugin : PLUGIN)
if Atom.seen a then l if Atom.seen a then l
else if Atom.seen a.neg then ( else if Atom.seen a.neg then (
pivots := a.var.pa :: !pivots; pivots := a.var.pa :: !pivots;
cleanup_ a; clear_var_of_ a;
l l
) else a::l) ) else a::l)
[] c1.atoms [] c1.atoms
@ -446,7 +450,7 @@ module Make(Plugin : PLUGIN)
let l = let l =
Array.fold_left (fun l a -> if Atom.seen a then a::l else l) l c2.atoms Array.fold_left (fun l a -> if Atom.seen a then a::l else l) l c2.atoms
in in
Array.iter cleanup_ c2.atoms; Array.iter clear_var_of_ c2.atoms;
!pivots, l !pivots, l
(* [find_dups c] returns a list of duplicate atoms, and the deduplicated list *) (* [find_dups c] returns a list of duplicate atoms, and the deduplicated list *)
@ -462,15 +466,15 @@ module Make(Plugin : PLUGIN)
)) ))
([], []) c.atoms ([], []) c.atoms
in in
Array.iter cleanup_ c.atoms; Array.iter clear_var_of_ c.atoms;
res res
(* do [c1] and [c2] have the same lits, modulo reordering and duplicates? *) (* do [c1] and [c2] have the same lits, modulo reordering and duplicates? *)
let same_lits (c1:atom array) (c2:atom array): bool = let same_lits (c1:atom Sequence.t) (c2:atom Sequence.t): bool =
let subset a b = let subset a b =
Array.iter Atom.mark b; Sequence.iter Atom.mark b;
let res = Array.for_all Atom.seen a in let res = Sequence.for_all Atom.seen a in
Array.iter cleanup_ b; Sequence.iter clear_var_of_ b;
res res
in in
subset c1 c2 && subset c2 c1 subset c1 c2 && subset c2 c1
@ -533,7 +537,12 @@ module Make(Plugin : PLUGIN)
| Assumption | Assumption
| Lemma of lemma | Lemma of lemma
| Duplicate of t * atom list | Duplicate of t * atom list
| Resolution of t * t * atom | Hyper_res of hyper_res_step
and hyper_res_step = {
hr_init: t;
hr_steps: (atom * t) list; (* list of pivot+clause to resolve against [init] *)
}
let[@inline] conclusion (p:t) : clause = p let[@inline] conclusion (p:t) : clause = p
@ -544,31 +553,51 @@ module Make(Plugin : PLUGIN)
rs_pivot: atom; rs_pivot: atom;
} }
let rec chain_res (c:clause) (hist:_ list) : res_step = (* find pivots for resolving [l] with [init], and also return
match hist with the atoms of the conclusion *)
| d :: r -> let find_pivots (init:clause) (l:clause list) : _ * (atom * t) list =
Log.debugf 5 Log.debugf 15
(fun k -> k "(@[sat.analyze.resolving@ :c1 %a@ :c2 %a@])" Clause.debug c Clause.debug d); (fun k->k "(@[proof.find-pivots@ :init %a@ :l %a@])"
begin match resolve c d with Clause.debug init (Format.pp_print_list Clause.debug) l);
| [a], l -> Array.iter Atom.mark init.atoms;
begin match r with let steps =
| [] -> {rs_res=l; rs_c1=c; rs_c2=d; rs_pivot=a} List.map
| _ -> (fun c ->
let new_clause = Clause.make ~flags:c.flags l (History [c; d]) in let pivot =
chain_res new_clause r match
end Sequence.of_array c.atoms
| _ -> |> Sequence.filter (fun a -> Atom.seen (Atom.neg a))
error_res_f "@[<2>clause mismatch while resolving@ %a@ and %a@]" |> Sequence.to_list
Clause.debug c Clause.debug d with
end | [a] -> a
| _ -> | [] ->
error_res_f "bad history" error_res_f "(@[proof.expand.pivot_missing@ %a@])" Clause.debug c
| pivots ->
error_res_f "(@[proof.expand.multiple_pivots@ %a@ :pivots %a@])"
Clause.debug c Atom.debug_l pivots
in
Array.iter Atom.mark c.atoms; (* add atoms to result *)
clear_var_of_ pivot;
Atom.abs pivot, c)
l
in
(* cleanup *)
let res = ref [] in
let cleanup_a_ a =
if Atom.seen a then (
res := a :: !res;
clear_var_of_ a
)
in
Array.iter cleanup_a_ init.atoms;
List.iter (fun c -> Array.iter cleanup_a_ c.atoms) l;
!res, steps
let expand conclusion = let expand conclusion =
Log.debugf 5 (fun k -> k "(@[sat.proof.expand@ @[%a@]@])" Clause.debug conclusion); Log.debugf 5 (fun k -> k "(@[sat.proof.expand@ @[%a@]@])" Clause.debug conclusion);
match conclusion.cpremise with match conclusion.cpremise with
| Lemma l -> | Lemma l ->
{conclusion; step = Lemma l; } { conclusion; step = Lemma l; }
| Local -> | Local ->
{ conclusion; step = Assumption; } { conclusion; step = Assumption; }
| Hyp l -> | Hyp l ->
@ -577,40 +606,51 @@ module Make(Plugin : PLUGIN)
error_res_f "@[empty history for clause@ %a@]" Clause.debug conclusion error_res_f "@[empty history for clause@ %a@]" Clause.debug conclusion
| History [c] -> | History [c] ->
let duplicates, res = find_dups c in let duplicates, res = find_dups c in
assert (same_lits (Array.of_list res) conclusion.atoms); assert (same_lits (Sequence.of_list res) (Clause.atoms_seq conclusion));
{ conclusion; step = Duplicate (c, duplicates) } { conclusion; step = Duplicate (c, duplicates) }
| History (c :: ([_] as r)) -> | History (c :: ([_] as r)) ->
let rs = chain_res c r in let res, steps = find_pivots c r in
assert (same_lits (Array.of_list rs.rs_res) conclusion.atoms); assert (same_lits (Sequence.of_list res) (Clause.atoms_seq conclusion));
{ conclusion; step = Resolution (rs.rs_c1, rs.rs_c2, rs.rs_pivot); } { conclusion; step = Hyper_res { hr_init=c; hr_steps=steps; }; }
| History (c :: r) -> | History (c :: r) ->
let rs = chain_res c r in let res, steps = find_pivots c r in
conclusion.cpremise <- History [rs.rs_c1; rs.rs_c2]; assert (same_lits (Sequence.of_list res) (Clause.atoms_seq conclusion));
assert (same_lits (Array.of_list rs.rs_res) conclusion.atoms); { conclusion; step = Hyper_res {hr_init=c; hr_steps=steps}; }
{ conclusion; step = Resolution (rs.rs_c1, rs.rs_c2, rs.rs_pivot); }
| Empty_premise -> raise Solver_intf.No_proof | Empty_premise -> raise Solver_intf.No_proof
let rec res_of_hyper_res (hr: hyper_res_step) : _ * _ * atom =
let {hr_init=c1; hr_steps=l} = hr in
match l with
| [] -> assert false
| [a, c2] -> c1, c2, a (* done *)
| (a,c2) :: steps' ->
(* resolve [c1] with [c2], then resolve that against [steps] *)
let pivots, l = resolve c1 c2 in
assert (match pivots with [a'] -> Atom.equal a a' | _ -> false);
let c_1_2 = Clause.make_removable l (History [c1; c2]) in
res_of_hyper_res {hr_init=c_1_2; hr_steps=steps'}
(* Proof nodes manipulation *) (* Proof nodes manipulation *)
let is_leaf = function let is_leaf = function
| Hypothesis _ | Hypothesis _
| Assumption | Assumption
| Lemma _ -> true | Lemma _ -> true
| Duplicate _ | Duplicate _
| Resolution _ -> false | Hyper_res _ -> false
let parents = function let parents = function
| Hypothesis _ | Hypothesis _
| Assumption | Assumption
| Lemma _ -> [] | Lemma _ -> []
| Duplicate (p, _) -> [p] | Duplicate (p, _) -> [p]
| Resolution (p, p', _) -> [p; p'] | Hyper_res {hr_init; hr_steps} -> hr_init :: List.map snd hr_steps
let expl = function let expl = function
| Hypothesis _ -> "hypothesis" | Hypothesis _ -> "hypothesis"
| Assumption -> "assumption" | Assumption -> "assumption"
| Lemma _ -> "lemma" | Lemma _ -> "lemma"
| Duplicate _ -> "duplicate" | Duplicate _ -> "duplicate"
| Resolution _ -> "resolution" | Hyper_res _ -> "hyper-resolution"
(* Compute unsat-core (* Compute unsat-core
TODO: replace visited bool by a int unique to each call TODO: replace visited bool by a int unique to each call
@ -658,9 +698,9 @@ module Make(Plugin : PLUGIN)
begin match node.step with begin match node.step with
| Duplicate (p1, _) -> | Duplicate (p1, _) ->
Stack.push (Enter p1) s Stack.push (Enter p1) s
| Resolution (p1, p2, _) -> | Hyper_res {hr_init=p1; hr_steps=l} ->
Stack.push (Enter p2) s; List.iter (fun (_,p2) -> Stack.push (Enter p2) s) l;
Stack.push (Enter p1) s Stack.push (Enter p1) s;
| Hypothesis _ | Assumption | Lemma _ -> () | Hypothesis _ | Assumption | Lemma _ -> ()
end end
end; end;

View file

@ -286,9 +286,12 @@ module type PROOF = sig
| Duplicate of t * atom list | Duplicate of t * atom list
(** The conclusion is obtained by eliminating multiple occurences of the atom in (** The conclusion is obtained by eliminating multiple occurences of the atom in
the conclusion of the provided proof. *) the conclusion of the provided proof. *)
| Resolution of t * t * atom | Hyper_res of hyper_res_step
(** The conclusion can be deduced by performing a resolution between the conclusions
of the two given proofs. The atom on which to perform the resolution is also given. *) and hyper_res_step = {
hr_init: t;
hr_steps: (atom * t) list; (* list of pivot+clause to resolve against [init] *)
}
(** {3 Proof building functions} *) (** {3 Proof building functions} *)
@ -303,6 +306,12 @@ module type PROOF = sig
val prove_atom : atom -> t option val prove_atom : atom -> t option
(** Given an atom [a], returns a proof of the clause [[a]] if [a] is true at level 0 *) (** Given an atom [a], returns a proof of the clause [[a]] if [a] is true at level 0 *)
val res_of_hyper_res : hyper_res_step -> t * t * atom
(** Turn an hyper resolution step into a resolution step.
The conclusion can be deduced by performing a resolution between the conclusions
of the two given proofs.
The atom on which to perform the resolution is also given. *)
(** {3 Proof Nodes} *) (** {3 Proof Nodes} *)
val parents : step -> t list val parents : step -> t list