feat: decode proofs from traces; print them in show_trace

This commit is contained in:
Simon Cruanes 2022-10-13 00:03:08 -04:00
parent 4e1272d64a
commit fb8614f304
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
16 changed files with 187 additions and 23 deletions

View file

@ -24,6 +24,8 @@ let add_const_decoders (self : t) (decs : Const.decoders) : unit =
self.const_decode <- Util.Str_map.add tag (ops, dec) self.const_decode)
decs
let store self = self.tst
let create ?(const_decoders = []) ~source tst : t =
let self =
{
@ -104,3 +106,13 @@ let rec read_term_err (self : t) (id : term_ref) : _ result =
let read_term self id =
Result.map_error Dec.Error.to_string @@ read_term_err self id
let read_term_exn self id =
match read_term_err self id with
| Ok t -> t
| Error err -> Error.errorf "term_reader: %a" Dec.Error.pp err
let deref_term self t =
match T_ref.as_ref t with
| None -> t
| Some id -> read_term_exn self id

View file

@ -9,6 +9,15 @@ type t
val create :
?const_decoders:const_decoders list -> source:Tr.Source.t -> Term.store -> t
val store : t -> Term.store
val add_const_decoders : t -> const_decoders -> unit
val read_term : t -> term_ref -> (Term.t, string) result
val read_term_err : t -> term_ref -> (Term.t, Ser_decode.Error.t) result
val read_term_exn : t -> term_ref -> Term.t
(** @raise Error.Error if it fails *)
val deref_term : t -> Term.t -> Term.t
(** [deref_term reader t] dereferences the root node of [t].
If [t] is [Ref id], this returns the term at [id] instead.
@raise Error.Error if it fails *)

View file

@ -16,5 +16,5 @@
(modules show_trace)
(modes native)
(libraries containers sidekick.util sidekick.core sidekick.trace
sidekick.smt-solver sidekick-base)
sidekick.smt-solver sidekick.proof sidekick-base)
(flags :standard -safe-string -color always -open Sidekick_util))

View file

@ -1,10 +1,16 @@
open Sidekick_core
open Sidekick_trace
module Proof = Sidekick_proof
module Smt = Sidekick_smt_solver
type state = { dump: bool; src: Source.t; t_reader: Smt.Trace_reader.t }
type state = {
dump: bool;
src: Source.t;
t_reader: Smt.Trace_reader.t;
p_reader: Proof.Trace_reader.t;
}
let show_sat (_self : state) ~tag v : unit =
let show_sat (self : state) ~tag v : unit =
match tag with
| "AssCSat" ->
(match
@ -64,8 +70,11 @@ let show_file ~dump file : unit =
~const_decoders:
[ Sidekick_core.const_decoders; Sidekick_base.const_decoders ]
in
let p_reader =
Proof.Trace_reader.create ~src (Smt.Trace_reader.term_trace_reader t_reader)
in
let st = { t_reader; src; dump } in
let st = { t_reader; src; dump; p_reader } in
Source.iter_all src (fun i ~tag v -> show_event st i ~tag v)
let () =

View file

@ -22,7 +22,13 @@ type delayed = unit -> t
let rec pp out = function
| P_ref r -> Fmt.fprintf out "!%d" r
| P_local id -> Fmt.fprintf out "s%d" id
| P_apply r -> Fmt.fprintf out "%s" r.rule_name
| P_apply r ->
Fmt.fprintf out "%s{@[" r.rule_name;
if r.premises <> [] then
Fmt.fprintf out "@ :prem %a" Fmt.Dump.(list int) r.premises;
if r.term_args <> [] then
Fmt.fprintf out "@ :ts %a" Fmt.Dump.(list Term.pp) r.term_args;
Fmt.fprintf out "@]}"
| P_let (bs, bod) ->
let pp_b out (x, t) = Fmt.fprintf out "s%d := %a" x pp t in
Fmt.fprintf out "(@[let %a@ in %a@])"
@ -52,6 +58,7 @@ let apply_rule ?(lits = []) ?(terms = []) ?(substs = []) ?(premises = [])
}
module V = Ser_value
module Dec = Ser_decode
let ser_apply_rule (tracer : Term.Tracer.t) (r : rule_apply) : Ser_value.t =
let { rule_name; lit_args; subst_args; term_args; premises; indices } = r in
@ -97,3 +104,51 @@ let rec to_ser (tracer : Term.Tracer.t) t : Ser_value.t =
| P_let (bs, bod) ->
let ser_b (x, t) = list [ int x; recurse t ] in
list [ int 3; list (List.map ser_b bs); recurse bod ])
let deser_apply_rule (t_read : Term.Trace_reader.t) : rule_apply Ser_decode.t =
let open Dec.Infix in
let tst = Term.Trace_reader.store t_read in
let dec_t =
let* i = Dec.int in
match Term.Trace_reader.read_term_err t_read i with
| Ok t -> Dec.return t
| Error e -> Dec.fail_err e
in
let dec_lit : Lit.t Dec.t =
let+ sign, t = Dec.tup2 Dec.bool dec_t in
Lit.atom ~sign tst t
in
let dec_premise : step_id Dec.t = Dec.int in
let dec_indice : step_id Dec.t = Dec.int in
let dec_subst : _ Dec.t = Dec.delay (fun () -> assert false (* TODO *)) in
let+ rule_name = Dec.dict_field "name" Dec.string
and+ lit_args = Dec.dict_field_or [] "lits" (Dec.list dec_lit)
and+ term_args = Dec.dict_field_or [] "t" (Dec.list dec_t)
and+ subst_args = Dec.dict_field_or [] "su" (Dec.list dec_subst)
and+ indices = Dec.dict_field_or [] "idx" (Dec.list dec_indice)
and+ premises = Dec.dict_field_or [] "ps" (Dec.list dec_premise) in
{ rule_name; lit_args; subst_args; term_args; premises; indices }
let rec deser (t_read : Term.Trace_reader.t) : t Ser_decode.t =
let open Dec.Infix in
let* l = Dec.list Dec.any in
match l with
| [ V.Int 0; v ] ->
let+ i = Dec.reflect_or_fail Dec.int v in
P_ref i
| [ V.Int 1; v ] ->
let+ i = Dec.reflect_or_fail Dec.int v in
P_local i
| [ V.Int 2; v ] ->
let+ r = Dec.reflect_or_fail (deser_apply_rule t_read) v in
P_apply r
| [ V.Int 3; bs; body ] ->
let dec_b = Dec.tup2 Dec.int (deser t_read) in
let+ bs = Dec.reflect_or_fail (Dec.list dec_b) bs
and+ body = Dec.reflect_or_fail (deser t_read) body in
P_let (bs, body)
| _ -> Dec.failf "unknown proof-term %a" (Fmt.Dump.list V.pp) l

View file

@ -48,3 +48,5 @@ val apply_rule :
val to_ser : Term.Tracer.t -> t -> Ser_value.t
(** Serialize *)
val deser : Term.Trace_reader.t -> t Ser_decode.t

View file

@ -4,6 +4,7 @@ module Sat_rules = Sat_rules
module Core_rules = Core_rules
module Pterm = Pterm
module Tracer = Tracer
module Trace_reader = Trace_reader
module Arg = Arg
type term = Pterm.t

29
src/proof/trace_reader.ml Normal file
View file

@ -0,0 +1,29 @@
open Sidekick_core
module Tr = Sidekick_trace
module Dec = Ser_decode
open Dec.Infix
type step_id = Step.id
type t = { src: Tr.Source.t; t_reader: Term.Trace_reader.t }
let create ~src t_reader : t = { src; t_reader }
let rec read_step ?(fix = false) (self : t) (id : step_id) : _ result =
match Tr.Source.get_entry self.src id with
| Some ("Pt", v) ->
let res = Dec.run (Pterm.deser self.t_reader) v in
(match res with
| Ok (Pterm.P_ref id') when fix ->
(* read reference recursively *)
read_step ~fix self id'
| _ -> res)
| None ->
Error (Dec.Error.of_string "unknown source entry" (Ser_value.int id))
| Some (tag, _) ->
Error
(Dec.Error.of_string "expected proof term, wrong tag"
(Ser_value.string tag))
let dec_step_id ?fix (self : t) =
let* id = Dec.int in
read_step ?fix self id |> Dec.return_result_err

View file

@ -0,0 +1,16 @@
open Sidekick_core
module Tr = Sidekick_trace
module Dec = Ser_decode
type step_id = Step.id
type t
val create : src:Tr.Source.t -> Term.Trace_reader.t -> t
val read_step : ?fix:bool -> t -> step_id -> (Pterm.t, Dec.Error.t) result
(** Read a step from the source at the given step id, using the trace reader.
@param fix if true, dereferences in a loop so the returned proof term is
not a Ref. *)
val dec_step_id : ?fix:bool -> t -> Pterm.t Dec.t
(** Reads an integer, decodes the corresponding entry *)

View file

@ -1,18 +1,32 @@
open Sidekick_core
module Proof = Sidekick_proof
module Tr = Sidekick_trace
type entry = Assert of Term.t | Assert_clause of { id: int; c: Lit.t list }
type entry =
| Assert of Term.t
| Assert_clause of { id: int; c: Lit.t list; p: Proof.Pterm.t option }
let pp_entry out = function
| Assert t -> Fmt.fprintf out "(@[assert@ %a@])" Term.pp t
| Assert_clause { id; c } ->
Fmt.fprintf out "(@[assert-c[%d]@ %a@])" id (Fmt.Dump.list Lit.pp) c
| Assert_clause { id; c; p } ->
Fmt.fprintf out "(@[assert-c[%d]@ %a@ :proof %a@])" id
(Fmt.Dump.list Lit.pp) c
(Fmt.Dump.option Proof.Pterm.pp)
p
type t = { tst: Term.store; src: Tr.Source.t; t_dec: Term.Trace_reader.t }
type t = {
tst: Term.store;
src: Tr.Source.t;
t_dec: Term.Trace_reader.t;
p_dec: Proof.Trace_reader.t;
}
let create ?const_decoders tst src : t =
let t_dec = Term.Trace_reader.create ?const_decoders tst ~source:src in
{ tst; src; t_dec }
let p_dec = Proof.Trace_reader.create ~src t_dec in
{ tst; src; t_dec; p_dec }
let term_trace_reader self = self.t_dec
let add_const_decoders self c =
Term.Trace_reader.add_const_decoders self.t_dec c
@ -28,8 +42,13 @@ let dec_c (self : t) =
let+ b, t = tup2 bool @@ dec_t self in
Lit.atom self.tst ~sign:b t
in
let+ id = dict_field "id" int and+ c = dict_field "c" (list dec_lit) in
id, c)
let+ id = dict_field "id" int
and+ c = dict_field "c" (list dec_lit)
and+ p =
dict_field_opt "p" (Proof.Trace_reader.dec_step_id ~fix:true self.p_dec)
in
id, c, p)
let decode (self : t) ~tag v =
Log.debugf 30 (fun k ->
@ -45,7 +64,7 @@ let decode (self : t) ~tag v =
| "AssC" ->
Ser_decode.(
(match run (dec_c self) v with
| Ok (id, c) -> Some (Assert_clause { id; c })
| Ok (id, c, p) -> Some (Assert_clause { id; c; p })
| Error err ->
Fmt.eprintf "cannot decode entry with tag %S:@ %a@." tag
Ser_decode.Error.pp err;

View file

@ -1,9 +1,12 @@
(** Read trace *)
open Sidekick_core
module Proof = Sidekick_proof
module Tr = Sidekick_trace
type entry = Assert of Term.t | Assert_clause of { id: int; c: Lit.t list }
type entry =
| Assert of Term.t
| Assert_clause of { id: int; c: Lit.t list; p: Proof.Pterm.t option }
val pp_entry : entry Fmt.printer
@ -13,5 +16,6 @@ val create :
?const_decoders:Const.decoders list -> Term.store -> Tr.Source.t -> t
val add_const_decoders : t -> Const.decoders -> unit
val term_trace_reader : t -> Term.Trace_reader.t
val decode : t -> tag:string -> Ser_value.t -> entry option
val decode_entry : t -> Tr.Entry_id.t -> entry option

View file

@ -44,7 +44,6 @@ class concrete (sink : Tr.Sink.t) : t =
class dummy : t =
object
inherit Term.Tracer.dummy
inherit Sidekick_sat.Tracer.dummy
method emit_assert_term _ = Tr.Entry_id.dummy
end

View file

@ -1,4 +1,3 @@
open Sidekick_core
module Proof = Sidekick_proof
let lemma_isa_cstor ~cstor_t t : Proof.Pterm.t =

View file

@ -1,4 +1,3 @@
open Sidekick_core
module Proof = Sidekick_proof
let lemma_lra lits : Proof.Pterm.t = Proof.Pterm.apply_rule "lra.lemma" ~lits

View file

@ -32,7 +32,8 @@ let[@inline] fail_e e = raise_notrace (Fail e)
let fail_err e = { deser = (fun _ -> fail_e e) }
let return x = { deser = (fun _ -> x) }
let fail s = { deser = (fun v -> fail_ s v) }
let failf fmt = Printf.ksprintf fail fmt
let failf fmt = Fmt.kasprintf fail fmt
let delay f = { deser = (fun v -> (f ()).deser v) }
let return_result = function
| Ok x -> return x
@ -170,6 +171,12 @@ end
include Infix
let dict_field_or default name d =
let+ r = dict_field_opt name d in
match r with
| Some r -> r
| None -> default
let tup2 d1 d2 =
let* l = list any in
match l with

View file

@ -9,7 +9,6 @@ module Error : sig
include Sidekick_sigs.PRINT with type t := t
val to_string : t -> string
val of_string : string -> Ser_value.t -> t
end
@ -24,11 +23,14 @@ val string : string t
val return : 'a -> 'a t
val return_result : ('a, string) result -> 'a t
val return_result_err : ('a, Error.t) result -> 'a t
val delay : (unit -> 'a t) -> 'a t
val fail : string -> 'a t
val failf : ('a, unit, string, 'b t) format4 -> 'a
val failf : ('a, Format.formatter, unit, 'b t) format4 -> 'a
val fail_err : Error.t -> 'a t
val unwrap_opt : string -> 'a option -> 'a t
(** Unwrap option, or fail *)
val any : Ser_value.t t
val list : 'a t -> 'a list t
val tup2 : 'a t -> 'b t -> ('a * 'b) t
@ -36,7 +38,9 @@ val tup3 : 'a t -> 'b t -> 'c t -> ('a*'b*'c) t
val tup4 : 'a t -> 'b t -> 'c t -> 'd t -> ('a * 'b * 'c * 'd) t
val dict_field : string -> 'a t -> 'a t
val dict_field_opt : string -> 'a t -> 'a option t
val dict_field_or : 'a -> string -> 'a t -> 'a t
val both : 'a t -> 'b t -> ('a * 'b) t
val reflect : 'a t -> Ser_value.t -> ('a, Error.t) result t
(** [reflect dec v] returns the result of decoding [v] with [dec] *)