From fb8614f304ad221df8fbd89368a0ebcd51ae22cb Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 13 Oct 2022 00:03:08 -0400 Subject: [PATCH] feat: decode proofs from traces; print them in show_trace --- src/core/t_trace_reader.ml | 12 ++++++++ src/core/t_trace_reader.mli | 9 ++++++ src/main/dune | 2 +- src/main/show_trace.ml | 15 ++++++++-- src/proof/pterm.ml | 57 ++++++++++++++++++++++++++++++++++++- src/proof/pterm.mli | 2 ++ src/proof/sidekick_proof.ml | 1 + src/proof/trace_reader.ml | 29 +++++++++++++++++++ src/proof/trace_reader.mli | 16 +++++++++++ src/smt/trace_reader.ml | 35 +++++++++++++++++------ src/smt/trace_reader.mli | 6 +++- src/smt/tracer.ml | 1 - src/th-data/proof_rules.ml | 1 - src/th-lra/proof_rules.ml | 1 - src/util/ser_decode.ml | 9 +++++- src/util/ser_decode.mli | 14 +++++---- 16 files changed, 187 insertions(+), 23 deletions(-) create mode 100644 src/proof/trace_reader.ml create mode 100644 src/proof/trace_reader.mli diff --git a/src/core/t_trace_reader.ml b/src/core/t_trace_reader.ml index a81e1948..e2f8ceac 100644 --- a/src/core/t_trace_reader.ml +++ b/src/core/t_trace_reader.ml @@ -24,6 +24,8 @@ let add_const_decoders (self : t) (decs : Const.decoders) : unit = self.const_decode <- Util.Str_map.add tag (ops, dec) self.const_decode) decs +let store self = self.tst + let create ?(const_decoders = []) ~source tst : t = let self = { @@ -104,3 +106,13 @@ let rec read_term_err (self : t) (id : term_ref) : _ result = let read_term self id = Result.map_error Dec.Error.to_string @@ read_term_err self id + +let read_term_exn self id = + match read_term_err self id with + | Ok t -> t + | Error err -> Error.errorf "term_reader: %a" Dec.Error.pp err + +let deref_term self t = + match T_ref.as_ref t with + | None -> t + | Some id -> read_term_exn self id diff --git a/src/core/t_trace_reader.mli b/src/core/t_trace_reader.mli index 22d919bc..bc1b7055 100644 --- a/src/core/t_trace_reader.mli +++ b/src/core/t_trace_reader.mli @@ -9,6 +9,15 @@ type t val create : ?const_decoders:const_decoders list -> source:Tr.Source.t -> Term.store -> t +val store : t -> Term.store val add_const_decoders : t -> const_decoders -> unit val read_term : t -> term_ref -> (Term.t, string) result val read_term_err : t -> term_ref -> (Term.t, Ser_decode.Error.t) result + +val read_term_exn : t -> term_ref -> Term.t +(** @raise Error.Error if it fails *) + +val deref_term : t -> Term.t -> Term.t +(** [deref_term reader t] dereferences the root node of [t]. + If [t] is [Ref id], this returns the term at [id] instead. + @raise Error.Error if it fails *) diff --git a/src/main/dune b/src/main/dune index cea88d77..8e519ff3 100644 --- a/src/main/dune +++ b/src/main/dune @@ -16,5 +16,5 @@ (modules show_trace) (modes native) (libraries containers sidekick.util sidekick.core sidekick.trace - sidekick.smt-solver sidekick-base) + sidekick.smt-solver sidekick.proof sidekick-base) (flags :standard -safe-string -color always -open Sidekick_util)) diff --git a/src/main/show_trace.ml b/src/main/show_trace.ml index fa58c8b6..38e56702 100644 --- a/src/main/show_trace.ml +++ b/src/main/show_trace.ml @@ -1,10 +1,16 @@ open Sidekick_core open Sidekick_trace +module Proof = Sidekick_proof module Smt = Sidekick_smt_solver -type state = { dump: bool; src: Source.t; t_reader: Smt.Trace_reader.t } +type state = { + dump: bool; + src: Source.t; + t_reader: Smt.Trace_reader.t; + p_reader: Proof.Trace_reader.t; +} -let show_sat (_self : state) ~tag v : unit = +let show_sat (self : state) ~tag v : unit = match tag with | "AssCSat" -> (match @@ -64,8 +70,11 @@ let show_file ~dump file : unit = ~const_decoders: [ Sidekick_core.const_decoders; Sidekick_base.const_decoders ] in + let p_reader = + Proof.Trace_reader.create ~src (Smt.Trace_reader.term_trace_reader t_reader) + in - let st = { t_reader; src; dump } in + let st = { t_reader; src; dump; p_reader } in Source.iter_all src (fun i ~tag v -> show_event st i ~tag v) let () = diff --git a/src/proof/pterm.ml b/src/proof/pterm.ml index 9409f0d4..399827c4 100644 --- a/src/proof/pterm.ml +++ b/src/proof/pterm.ml @@ -22,7 +22,13 @@ type delayed = unit -> t let rec pp out = function | P_ref r -> Fmt.fprintf out "!%d" r | P_local id -> Fmt.fprintf out "s%d" id - | P_apply r -> Fmt.fprintf out "%s" r.rule_name + | P_apply r -> + Fmt.fprintf out "%s{@[" r.rule_name; + if r.premises <> [] then + Fmt.fprintf out "@ :prem %a" Fmt.Dump.(list int) r.premises; + if r.term_args <> [] then + Fmt.fprintf out "@ :ts %a" Fmt.Dump.(list Term.pp) r.term_args; + Fmt.fprintf out "@]}" | P_let (bs, bod) -> let pp_b out (x, t) = Fmt.fprintf out "s%d := %a" x pp t in Fmt.fprintf out "(@[let %a@ in %a@])" @@ -52,6 +58,7 @@ let apply_rule ?(lits = []) ?(terms = []) ?(substs = []) ?(premises = []) } module V = Ser_value +module Dec = Ser_decode let ser_apply_rule (tracer : Term.Tracer.t) (r : rule_apply) : Ser_value.t = let { rule_name; lit_args; subst_args; term_args; premises; indices } = r in @@ -97,3 +104,51 @@ let rec to_ser (tracer : Term.Tracer.t) t : Ser_value.t = | P_let (bs, bod) -> let ser_b (x, t) = list [ int x; recurse t ] in list [ int 3; list (List.map ser_b bs); recurse bod ]) + +let deser_apply_rule (t_read : Term.Trace_reader.t) : rule_apply Ser_decode.t = + let open Dec.Infix in + let tst = Term.Trace_reader.store t_read in + + let dec_t = + let* i = Dec.int in + match Term.Trace_reader.read_term_err t_read i with + | Ok t -> Dec.return t + | Error e -> Dec.fail_err e + in + + let dec_lit : Lit.t Dec.t = + let+ sign, t = Dec.tup2 Dec.bool dec_t in + Lit.atom ~sign tst t + in + let dec_premise : step_id Dec.t = Dec.int in + let dec_indice : step_id Dec.t = Dec.int in + let dec_subst : _ Dec.t = Dec.delay (fun () -> assert false (* TODO *)) in + + let+ rule_name = Dec.dict_field "name" Dec.string + and+ lit_args = Dec.dict_field_or [] "lits" (Dec.list dec_lit) + and+ term_args = Dec.dict_field_or [] "t" (Dec.list dec_t) + and+ subst_args = Dec.dict_field_or [] "su" (Dec.list dec_subst) + and+ indices = Dec.dict_field_or [] "idx" (Dec.list dec_indice) + and+ premises = Dec.dict_field_or [] "ps" (Dec.list dec_premise) in + + { rule_name; lit_args; subst_args; term_args; premises; indices } + +let rec deser (t_read : Term.Trace_reader.t) : t Ser_decode.t = + let open Dec.Infix in + let* l = Dec.list Dec.any in + match l with + | [ V.Int 0; v ] -> + let+ i = Dec.reflect_or_fail Dec.int v in + P_ref i + | [ V.Int 1; v ] -> + let+ i = Dec.reflect_or_fail Dec.int v in + P_local i + | [ V.Int 2; v ] -> + let+ r = Dec.reflect_or_fail (deser_apply_rule t_read) v in + P_apply r + | [ V.Int 3; bs; body ] -> + let dec_b = Dec.tup2 Dec.int (deser t_read) in + let+ bs = Dec.reflect_or_fail (Dec.list dec_b) bs + and+ body = Dec.reflect_or_fail (deser t_read) body in + P_let (bs, body) + | _ -> Dec.failf "unknown proof-term %a" (Fmt.Dump.list V.pp) l diff --git a/src/proof/pterm.mli b/src/proof/pterm.mli index 11b71154..0dd9697f 100644 --- a/src/proof/pterm.mli +++ b/src/proof/pterm.mli @@ -48,3 +48,5 @@ val apply_rule : val to_ser : Term.Tracer.t -> t -> Ser_value.t (** Serialize *) + +val deser : Term.Trace_reader.t -> t Ser_decode.t diff --git a/src/proof/sidekick_proof.ml b/src/proof/sidekick_proof.ml index 0d0bd76c..2f0cf2ef 100644 --- a/src/proof/sidekick_proof.ml +++ b/src/proof/sidekick_proof.ml @@ -4,6 +4,7 @@ module Sat_rules = Sat_rules module Core_rules = Core_rules module Pterm = Pterm module Tracer = Tracer +module Trace_reader = Trace_reader module Arg = Arg type term = Pterm.t diff --git a/src/proof/trace_reader.ml b/src/proof/trace_reader.ml new file mode 100644 index 00000000..9d0727a8 --- /dev/null +++ b/src/proof/trace_reader.ml @@ -0,0 +1,29 @@ +open Sidekick_core +module Tr = Sidekick_trace +module Dec = Ser_decode +open Dec.Infix + +type step_id = Step.id +type t = { src: Tr.Source.t; t_reader: Term.Trace_reader.t } + +let create ~src t_reader : t = { src; t_reader } + +let rec read_step ?(fix = false) (self : t) (id : step_id) : _ result = + match Tr.Source.get_entry self.src id with + | Some ("Pt", v) -> + let res = Dec.run (Pterm.deser self.t_reader) v in + (match res with + | Ok (Pterm.P_ref id') when fix -> + (* read reference recursively *) + read_step ~fix self id' + | _ -> res) + | None -> + Error (Dec.Error.of_string "unknown source entry" (Ser_value.int id)) + | Some (tag, _) -> + Error + (Dec.Error.of_string "expected proof term, wrong tag" + (Ser_value.string tag)) + +let dec_step_id ?fix (self : t) = + let* id = Dec.int in + read_step ?fix self id |> Dec.return_result_err diff --git a/src/proof/trace_reader.mli b/src/proof/trace_reader.mli new file mode 100644 index 00000000..5b1bcb13 --- /dev/null +++ b/src/proof/trace_reader.mli @@ -0,0 +1,16 @@ +open Sidekick_core +module Tr = Sidekick_trace +module Dec = Ser_decode + +type step_id = Step.id +type t + +val create : src:Tr.Source.t -> Term.Trace_reader.t -> t + +val read_step : ?fix:bool -> t -> step_id -> (Pterm.t, Dec.Error.t) result +(** Read a step from the source at the given step id, using the trace reader. + @param fix if true, dereferences in a loop so the returned proof term is + not a Ref. *) + +val dec_step_id : ?fix:bool -> t -> Pterm.t Dec.t +(** Reads an integer, decodes the corresponding entry *) diff --git a/src/smt/trace_reader.ml b/src/smt/trace_reader.ml index 42e20638..20720ebb 100644 --- a/src/smt/trace_reader.ml +++ b/src/smt/trace_reader.ml @@ -1,18 +1,32 @@ open Sidekick_core +module Proof = Sidekick_proof module Tr = Sidekick_trace -type entry = Assert of Term.t | Assert_clause of { id: int; c: Lit.t list } +type entry = + | Assert of Term.t + | Assert_clause of { id: int; c: Lit.t list; p: Proof.Pterm.t option } let pp_entry out = function | Assert t -> Fmt.fprintf out "(@[assert@ %a@])" Term.pp t - | Assert_clause { id; c } -> - Fmt.fprintf out "(@[assert-c[%d]@ %a@])" id (Fmt.Dump.list Lit.pp) c + | Assert_clause { id; c; p } -> + Fmt.fprintf out "(@[assert-c[%d]@ %a@ :proof %a@])" id + (Fmt.Dump.list Lit.pp) c + (Fmt.Dump.option Proof.Pterm.pp) + p -type t = { tst: Term.store; src: Tr.Source.t; t_dec: Term.Trace_reader.t } +type t = { + tst: Term.store; + src: Tr.Source.t; + t_dec: Term.Trace_reader.t; + p_dec: Proof.Trace_reader.t; +} let create ?const_decoders tst src : t = let t_dec = Term.Trace_reader.create ?const_decoders tst ~source:src in - { tst; src; t_dec } + let p_dec = Proof.Trace_reader.create ~src t_dec in + { tst; src; t_dec; p_dec } + +let term_trace_reader self = self.t_dec let add_const_decoders self c = Term.Trace_reader.add_const_decoders self.t_dec c @@ -28,8 +42,13 @@ let dec_c (self : t) = let+ b, t = tup2 bool @@ dec_t self in Lit.atom self.tst ~sign:b t in - let+ id = dict_field "id" int and+ c = dict_field "c" (list dec_lit) in - id, c) + let+ id = dict_field "id" int + and+ c = dict_field "c" (list dec_lit) + and+ p = + dict_field_opt "p" (Proof.Trace_reader.dec_step_id ~fix:true self.p_dec) + in + + id, c, p) let decode (self : t) ~tag v = Log.debugf 30 (fun k -> @@ -45,7 +64,7 @@ let decode (self : t) ~tag v = | "AssC" -> Ser_decode.( (match run (dec_c self) v with - | Ok (id, c) -> Some (Assert_clause { id; c }) + | Ok (id, c, p) -> Some (Assert_clause { id; c; p }) | Error err -> Fmt.eprintf "cannot decode entry with tag %S:@ %a@." tag Ser_decode.Error.pp err; diff --git a/src/smt/trace_reader.mli b/src/smt/trace_reader.mli index 77580e75..02da88d8 100644 --- a/src/smt/trace_reader.mli +++ b/src/smt/trace_reader.mli @@ -1,9 +1,12 @@ (** Read trace *) open Sidekick_core +module Proof = Sidekick_proof module Tr = Sidekick_trace -type entry = Assert of Term.t | Assert_clause of { id: int; c: Lit.t list } +type entry = + | Assert of Term.t + | Assert_clause of { id: int; c: Lit.t list; p: Proof.Pterm.t option } val pp_entry : entry Fmt.printer @@ -13,5 +16,6 @@ val create : ?const_decoders:Const.decoders list -> Term.store -> Tr.Source.t -> t val add_const_decoders : t -> Const.decoders -> unit +val term_trace_reader : t -> Term.Trace_reader.t val decode : t -> tag:string -> Ser_value.t -> entry option val decode_entry : t -> Tr.Entry_id.t -> entry option diff --git a/src/smt/tracer.ml b/src/smt/tracer.ml index 00096473..c1d4ec21 100644 --- a/src/smt/tracer.ml +++ b/src/smt/tracer.ml @@ -44,7 +44,6 @@ class concrete (sink : Tr.Sink.t) : t = class dummy : t = object - inherit Term.Tracer.dummy inherit Sidekick_sat.Tracer.dummy method emit_assert_term _ = Tr.Entry_id.dummy end diff --git a/src/th-data/proof_rules.ml b/src/th-data/proof_rules.ml index 1496d902..76371957 100644 --- a/src/th-data/proof_rules.ml +++ b/src/th-data/proof_rules.ml @@ -1,4 +1,3 @@ -open Sidekick_core module Proof = Sidekick_proof let lemma_isa_cstor ~cstor_t t : Proof.Pterm.t = diff --git a/src/th-lra/proof_rules.ml b/src/th-lra/proof_rules.ml index 81dddc96..8326efab 100644 --- a/src/th-lra/proof_rules.ml +++ b/src/th-lra/proof_rules.ml @@ -1,4 +1,3 @@ -open Sidekick_core module Proof = Sidekick_proof let lemma_lra lits : Proof.Pterm.t = Proof.Pterm.apply_rule "lra.lemma" ~lits diff --git a/src/util/ser_decode.ml b/src/util/ser_decode.ml index 4f21e289..f128ab19 100644 --- a/src/util/ser_decode.ml +++ b/src/util/ser_decode.ml @@ -32,7 +32,8 @@ let[@inline] fail_e e = raise_notrace (Fail e) let fail_err e = { deser = (fun _ -> fail_e e) } let return x = { deser = (fun _ -> x) } let fail s = { deser = (fun v -> fail_ s v) } -let failf fmt = Printf.ksprintf fail fmt +let failf fmt = Fmt.kasprintf fail fmt +let delay f = { deser = (fun v -> (f ()).deser v) } let return_result = function | Ok x -> return x @@ -170,6 +171,12 @@ end include Infix +let dict_field_or default name d = + let+ r = dict_field_opt name d in + match r with + | Some r -> r + | None -> default + let tup2 d1 d2 = let* l = list any in match l with diff --git a/src/util/ser_decode.mli b/src/util/ser_decode.mli index fff60b4d..1c689050 100644 --- a/src/util/ser_decode.mli +++ b/src/util/ser_decode.mli @@ -9,7 +9,6 @@ module Error : sig include Sidekick_sigs.PRINT with type t := t val to_string : t -> string - val of_string : string -> Ser_value.t -> t end @@ -24,19 +23,24 @@ val string : string t val return : 'a -> 'a t val return_result : ('a, string) result -> 'a t val return_result_err : ('a, Error.t) result -> 'a t +val delay : (unit -> 'a t) -> 'a t val fail : string -> 'a t -val failf : ('a, unit, string, 'b t) format4 -> 'a +val failf : ('a, Format.formatter, unit, 'b t) format4 -> 'a val fail_err : Error.t -> 'a t + val unwrap_opt : string -> 'a option -> 'a t (** Unwrap option, or fail *) + val any : Ser_value.t t val list : 'a t -> 'a list t -val tup2 : 'a t -> 'b t -> ('a*'b) t -val tup3 : 'a t -> 'b t -> 'c t -> ('a*'b*'c) t -val tup4 : 'a t -> 'b t -> 'c t -> 'd t -> ('a*'b*'c*'d) t +val tup2 : 'a t -> 'b t -> ('a * 'b) t +val tup3 : 'a t -> 'b t -> 'c t -> ('a * 'b * 'c) t +val tup4 : 'a t -> 'b t -> 'c t -> 'd t -> ('a * 'b * 'c * 'd) t val dict_field : string -> 'a t -> 'a t val dict_field_opt : string -> 'a t -> 'a option t +val dict_field_or : 'a -> string -> 'a t -> 'a t val both : 'a t -> 'b t -> ('a * 'b) t + val reflect : 'a t -> Ser_value.t -> ('a, Error.t) result t (** [reflect dec v] returns the result of decoding [v] with [dec] *)