diff --git a/src/base/Sidekick_base.ml b/src/base/Sidekick_base.ml index 5d7df622..d4cf4ed7 100644 --- a/src/base/Sidekick_base.ml +++ b/src/base/Sidekick_base.ml @@ -42,3 +42,14 @@ let th_bool_static : Solver.theory = Th_bool.theory_static let th_data : Solver.theory = Th_data.theory let th_lra : Solver.theory = Th_lra.theory let th_ty_unin : Solver.theory = Th_ty_unin.theory + +(** All constant decoders *) +let const_decoders = + List.flatten + [ + Uconst.const_decoders; + LRA_term.const_decoders; + Ty.const_decoders; + (* TODO Th_data *) + Form.const_decoders; + ] diff --git a/src/base/Uconst.ml b/src/base/Uconst.ml index b295f67e..37478704 100644 --- a/src/base/Uconst.ml +++ b/src/base/Uconst.ml @@ -42,7 +42,7 @@ let const_decoders : Const.decoders = ops, Ser_decode.( fun dec_t -> - let+ uc_id = ID.deser and+ uc_ty = dec_t in + let+ uc_id, uc_ty = tup2 ID.deser dec_t in Uconst { uc_id; uc_ty }) ); ] diff --git a/src/core/t_trace_reader.ml b/src/core/t_trace_reader.ml index 4e1df71f..a81e1948 100644 --- a/src/core/t_trace_reader.ml +++ b/src/core/t_trace_reader.ml @@ -9,7 +9,7 @@ type const_decoders = Const.decoders type t = { tst: Term.store; src: Tr.Source.t; - cache: (Term.t, string) result ID_cache.t; + cache: (Term.t, Dec.Error.t) result ID_cache.t; mutable const_decode: (Const.Ops.t * (Term.t Dec.t -> Const.view Dec.t)) Util.Str_map.t; (** tag -> const decoder *) @@ -75,22 +75,32 @@ let decode_term (self : t) ~read_subterm ~tag : Term.t Dec.t = let+ c_view = reflect_or_fail (c_dec read_subterm) view in let const = Const.make c_view ops ~ty in Term.const self.tst const)) - | "Tf@" -> assert false (* TODO *) + | "Tf@" -> + Dec.( + let+ f = dict_field "f" read_subterm + and+ l = dict_field "l" (list read_subterm) + and+ acc0 = dict_field "a0" read_subterm in + Term.app_fold self.tst ~f l ~acc0) | _ -> Dec.failf "unknown tag %S for a term" tag -let rec read_term (self : t) (id : term_ref) : _ result = +let rec read_term_err (self : t) (id : term_ref) : _ result = (* decoder for subterms *) let read_subterm : Term.t Dec.t = Dec.( let* id = int in - match read_term self id with - | Ok x -> return x - | Error e -> fail e) + return_result_err @@ read_term_err self id) in ID_cache.get self.cache id ~compute:(fun id -> match Tr.Source.get_entry self.src id with - | None -> Error (Printf.sprintf "invalid entry: %d" id) + | None -> + Error + (Dec.Error.of_string + (Printf.sprintf "invalid entry: %d" id) + (Ser_value.int id)) | Some (tag, v) -> let dec = decode_term self ~tag ~read_subterm in - Dec.run dec v |> Result.map_error Dec.Error.to_string) + Dec.run dec v) + +let read_term self id = + Result.map_error Dec.Error.to_string @@ read_term_err self id diff --git a/src/core/t_trace_reader.mli b/src/core/t_trace_reader.mli index 3d31ca65..22d919bc 100644 --- a/src/core/t_trace_reader.mli +++ b/src/core/t_trace_reader.mli @@ -11,3 +11,4 @@ val create : val add_const_decoders : t -> const_decoders -> unit val read_term : t -> term_ref -> (Term.t, string) result +val read_term_err : t -> term_ref -> (Term.t, Ser_decode.Error.t) result diff --git a/src/core/t_tracer.ml b/src/core/t_tracer.ml index cb9738e8..f712043e 100644 --- a/src/core/t_tracer.ml +++ b/src/core/t_tracer.ml @@ -42,7 +42,14 @@ let emit_term_ (self : state) (t : Term.t) = in V.dict_of_list fields ) | T.E_app_fold { f; args; acc0 } -> - "Tf@", V.(list [ loop' f; list (List.map loop' args); loop' acc0 ]) + ( "Tf@", + V.( + dict_of_list + [ + "f", loop' f; + "l", list (List.map loop' args); + "a0", loop' acc0; + ]) ) | T.E_lam (name, ty, bod) -> "Tl", V.(list [ string name; loop' ty; loop' bod ]) | T.E_pi (name, ty, bod) -> diff --git a/src/main/dune b/src/main/dune index 7b2db52b..7727142d 100644 --- a/src/main/dune +++ b/src/main/dune @@ -15,5 +15,6 @@ (name show_trace) (modules show_trace) (modes native) - (libraries containers sidekick.util sidekick.core sidekick.trace) + (libraries containers sidekick.util sidekick.core sidekick.trace + sidekick.smt-solver sidekick-base) (flags :standard -safe-string -color always -open Sidekick_util)) diff --git a/src/main/show_trace.ml b/src/main/show_trace.ml index 0d941371..46d59efd 100644 --- a/src/main/show_trace.ml +++ b/src/main/show_trace.ml @@ -1,21 +1,43 @@ +open Sidekick_core open Sidekick_trace +module Smt = Sidekick_smt_solver -let show_file file : unit = +let show_file ~dump file : unit = Log.debugf 1 (fun k -> k "(@[show-file %S@])" file); let src = Source.of_string_using_bencode @@ CCIO.File.read_exn file in + let tst = Term.Store.create () in + + (* trace reader *) + let t_reader = + Smt.Trace_reader.create tst src + ~const_decoders: + [ + Term.const_decoders; Box.const_decoders; Sidekick_base.const_decoders; + ] + in + Source.iter_all src (fun i ~tag v -> - Format.printf "[%d]: %S %a@." i tag Sidekick_util.Ser_value.pp v) + Log.debugf 10 (fun k -> + k "(@[show-trace[%d]@ :tag %S@ :val %a@])" i tag Ser_value.pp v); + if dump then Format.printf "[%d]: %S %a@." i tag Ser_value.pp v; + + match Smt.Trace_reader.decode t_reader ~tag v with + | Some e -> Fmt.printf "[%d]: %a@." i Smt.Trace_reader.pp_entry e + | None -> ()) let () = let files = ref [] in + let dump = ref false in let opts = [ + "--dump", Arg.Set dump, " dump each raw entry"; ( "--bt", Arg.Unit (fun () -> Printexc.record_backtrace true), " enable backtraces" ); + "-d", Arg.Int Log.set_debug, " debug level"; ] |> Arg.align in Arg.parse opts (fun f -> files := f :: !files) "show_trace [file]+"; let files = List.rev !files in - List.iter show_file files + List.iter (show_file ~dump:!dump) files diff --git a/src/smt/Sidekick_smt_solver.ml b/src/smt/Sidekick_smt_solver.ml index ea9cc2a6..b1452b9b 100644 --- a/src/smt/Sidekick_smt_solver.ml +++ b/src/smt/Sidekick_smt_solver.ml @@ -17,6 +17,7 @@ module Theory_id = Theory_id module Preprocess = Preprocess module Find_foreign = Find_foreign module Tracer = Tracer +module Trace_reader = Trace_reader type theory = Theory.t type solver = Solver.t diff --git a/src/smt/trace_reader.ml b/src/smt/trace_reader.ml new file mode 100644 index 00000000..85ccf19c --- /dev/null +++ b/src/smt/trace_reader.ml @@ -0,0 +1,55 @@ +open Sidekick_core +module Tr = Sidekick_trace + +type entry = Assert of Term.t | Assert_clause of Lit.t list + +let pp_entry out = function + | Assert t -> Fmt.fprintf out "(@[assert@ %a@])" Term.pp t + | Assert_clause c -> + Fmt.fprintf out "(@[assert-c@ %a@])" (Fmt.Dump.list Lit.pp) c + +type t = { tst: Term.store; src: Tr.Source.t; t_dec: Term.Trace_reader.t } + +let create ?const_decoders tst src : t = + let t_dec = Term.Trace_reader.create ?const_decoders tst ~source:src in + { tst; src; t_dec } + +let add_const_decoders self c = Term.Trace_reader.add_const_decoders self.t_dec c + +let dec_t (self : t) = + Ser_decode.( + let* i = int in + return_result @@ Term.Trace_reader.read_term self.t_dec i) + +let dec_c (self : t) = + Ser_decode.( + let dec_lit = + let+ b, t = tup2 bool @@ dec_t self in + Lit.atom self.tst ~sign:b t + in + list dec_lit) + +let decode (self : t) ~tag v = + Log.debugf 30 (fun k -> + k "(@[trace-reader.decode@ :tag %S@ :val %a@])" tag Ser_value.pp v); + match tag with + | "Asst" -> + (match Ser_decode.(run (dec_t self) v) with + | Ok t -> Some (Assert t) + | Error err -> + Fmt.eprintf "cannot decode entry with tag %S:@ %a@." tag + Ser_decode.Error.pp err; + None) + | "AssC" -> + Ser_decode.( + (match run (dec_c self) v with + | Ok c -> Some (Assert_clause c) + | Error err -> + Fmt.eprintf "cannot decode entry with tag %S:@ %a@." tag + Ser_decode.Error.pp err; + None)) + | _ -> None + +let decode_entry self id : _ option = + let tag, v = Tr.Source.get_entry_exn self.src id in + decode self ~tag v diff --git a/src/smt/trace_reader.mli b/src/smt/trace_reader.mli new file mode 100644 index 00000000..e26c9a2f --- /dev/null +++ b/src/smt/trace_reader.mli @@ -0,0 +1,17 @@ +(** Read trace *) + +open Sidekick_core +module Tr = Sidekick_trace + +type entry = Assert of Term.t | Assert_clause of Lit.t list + +val pp_entry : entry Fmt.printer + +type t + +val create : + ?const_decoders:Const.decoders list -> Term.store -> Tr.Source.t -> t + +val add_const_decoders : t -> Const.decoders -> unit +val decode : t -> tag:string -> Ser_value.t -> entry option +val decode_entry : t -> Tr.Entry_id.t -> entry option diff --git a/src/smt/tracer.ml b/src/smt/tracer.ml index 540b03a4..678e36f2 100644 --- a/src/smt/tracer.ml +++ b/src/smt/tracer.ml @@ -2,8 +2,6 @@ open Sidekick_core module Tr = Sidekick_trace module V = Ser_value -type Tr.entry_view += Assert of Term.t | Assert_clause of Lit.t list - class type t = object inherit Term.Tracer.t diff --git a/src/smt/tracer.mli b/src/smt/tracer.mli index bad089f2..ef8588a0 100644 --- a/src/smt/tracer.mli +++ b/src/smt/tracer.mli @@ -1,8 +1,6 @@ open Sidekick_core module Tr = Sidekick_trace -type Tr.entry_view += Assert of Term.t | Assert_clause of Lit.t list - class type t = object inherit Term.Tracer.t diff --git a/src/trace/sink.ml b/src/trace/sink.ml index f35544c8..9365a7b8 100644 --- a/src/trace/sink.ml +++ b/src/trace/sink.ml @@ -24,18 +24,19 @@ let null : t = end) let of_out_channel_using_bencode (oc : out_channel) : t = - let id_ = ref 0 in + (* id: offset in the channel *) + let off = ref 0 in let buf = Buffer.create 128 in (module struct let emit ~tag (v : Ser_value.t) = assert (Buffer.length buf = 0); - let id = Entry_id.of_int_unsafe !id_ in + let id = Entry_id.of_int_unsafe !off in (* add tag+id around *) let v' = Ser_value.(list [ int id; string tag; v ]) in - incr id_; Sidekick_bencode.Encode.to_buffer buf v'; Buffer.add_char buf '\n'; Buffer.output_buffer oc buf; + off := !off + Buffer.length buf; Buffer.clear buf; id end) diff --git a/src/util/ser_decode.ml b/src/util/ser_decode.ml index 530537c3..4f21e289 100644 --- a/src/util/ser_decode.ml +++ b/src/util/ser_decode.ml @@ -8,10 +8,11 @@ module Error = struct type t = { msg: string; v: V.t; subs: t list } let mk ?(subs = []) msg v : t = { msg; v; subs } + let of_string s v : t = mk s v let pp out (self : t) = let rec pp out self = - Fmt.fprintf out "@[@[<2>%s@ in %a@]" self.msg V.pp self.v; + Fmt.fprintf out "@[@[<2>%s@ in value %a@]" self.msg V.pp self.v; List.iter (fun s -> Fmt.fprintf out "@ @[<2>sub-error:@ %a@]" pp s) self.subs; @@ -28,10 +29,19 @@ type 'a t = { deser: V.t -> 'a } [@@unboxed] let[@inline] fail_ msg v = raise_notrace (Fail (Error.mk msg v)) let[@inline] fail_e e = raise_notrace (Fail e) +let fail_err e = { deser = (fun _ -> fail_e e) } let return x = { deser = (fun _ -> x) } let fail s = { deser = (fun v -> fail_ s v) } let failf fmt = Printf.ksprintf fail fmt +let return_result = function + | Ok x -> return x + | Error s -> fail s + +let return_result_err = function + | Ok x -> return x + | Error e -> fail_err e + let unwrap_opt msg = function | Some x -> return x | None -> fail msg @@ -43,6 +53,8 @@ let bool = deser = (function | V.Bool b -> b + | V.Int 1 -> true + | V.Int 0 -> false | v -> fail_ "expected bool" v); } diff --git a/src/util/ser_decode.mli b/src/util/ser_decode.mli index 7f14f789..fff60b4d 100644 --- a/src/util/ser_decode.mli +++ b/src/util/ser_decode.mli @@ -9,6 +9,8 @@ module Error : sig include Sidekick_sigs.PRINT with type t := t val to_string : t -> string + + val of_string : string -> Ser_value.t -> t end (** {2 Main combinators *) @@ -20,8 +22,11 @@ val int : int t val bool : bool t val string : string t val return : 'a -> 'a t +val return_result : ('a, string) result -> 'a t +val return_result_err : ('a, Error.t) result -> 'a t val fail : string -> 'a t val failf : ('a, unit, string, 'b t) format4 -> 'a +val fail_err : Error.t -> 'a t val unwrap_opt : string -> 'a option -> 'a t (** Unwrap option, or fail *) val any : Ser_value.t t