From e177534a4616dafab2aa51b827d77af5f1966b94 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 16 Feb 2022 14:20:27 -0500 Subject: [PATCH] chore: ugprade bare encoding --- src/base/Proof.ml | 7 +- src/proof-trace-dump/proof_trace_dump.ml | 2 +- src/proof-trace/proof_ser.ml | 276 ++++++++++++++--------- 3 files changed, 174 insertions(+), 111 deletions(-) diff --git a/src/base/Proof.ml b/src/base/Proof.ml index 755429d4..6f243a84 100644 --- a/src/base/Proof.ml +++ b/src/base/Proof.ml @@ -48,6 +48,7 @@ type t = { mutable enabled : bool; config: Config.t; buf: Buffer.t; + out: Proof_ser.Bare.Encode.t; mutable storage: Storage.t; mutable dispose: unit -> unit; mutable steps_writer: CS.Writer.t; @@ -90,10 +91,12 @@ let create ?(config=Config.default) () : t = let dispose () = close_out oc in Storage.On_disk (file, oc), w, dispose in + let buf = Buffer.create 1024 in { enabled=config.Config.enabled; config; next_id=1; - buf=Buffer.create 1_024; + buf; + out=Proof_ser.Bare.Encode.of_buffer buf; map_term=Term.Tbl.create 32; map_fun=Fun.Tbl.create 32; steps_writer; storage; dispose; @@ -118,7 +121,7 @@ let[@inline] alloc_id (self:t) : Proof_ser.ID.t = let emit_step_ (self:t) (step:Proof_ser.Step.t) : unit = if enabled self then ( Buffer.clear self.buf; - Proof_ser.Step.encode self.buf step; + Proof_ser.Step.encode self.out step; Chunk_stack.Writer.add_buffer self.steps_writer self.buf; ) diff --git a/src/proof-trace-dump/proof_trace_dump.ml b/src/proof-trace-dump/proof_trace_dump.ml index 352f02a6..7883e355 100644 --- a/src/proof-trace-dump/proof_trace_dump.ml +++ b/src/proof-trace-dump/proof_trace_dump.ml @@ -20,7 +20,7 @@ let parse_file () : unit = CS.Reader.next reader ~finish:(fun () -> ()) ~yield:(fun b i _len -> - let decode = {Proof_ser.Bare.Decode.bs=b; off=i} in + let decode = Proof_ser.Bare.Decode.of_bytes b ~off:i in let step = Proof_ser.Step.decode decode in incr n; if not !quiet then ( diff --git a/src/proof-trace/proof_ser.ml b/src/proof-trace/proof_ser.ml index 8aa7b08f..471c3b82 100644 --- a/src/proof-trace/proof_ser.ml +++ b/src/proof-trace/proof_ser.ml @@ -1,3 +1,4 @@ +(* generated from "proof_ser.bare" using bare-codegen *) [@@@ocaml.warning "-26-27"] (* embedded runtime library *) @@ -5,27 +6,45 @@ module Bare = struct module String_map = Map.Make(String) -let spf = Printf.sprintf +module type INPUT = sig + val read_byte : unit -> char + val read_i16 : unit -> int + val read_i32 : unit -> int32 + val read_i64 : unit -> int64 + val read_exact : bytes -> int -> int -> unit +end +type input = (module INPUT) + +let input_of_bytes ?(off=0) ?len (b:bytes) : input = + let off = ref off in + let len = match len with + | None -> Bytes.length b - !off + | Some l -> l + in + if !off + len > Bytes.length b then invalid_arg "input_of_bytes"; + let[@inline] check_ n = if !off + n > len then invalid_arg "input exhausted" in + let module M = struct + let read_byte () = check_ 1; let c = Bytes.get b !off in incr off; c + let read_i16 () = check_ 2; let r = Bytes.get_int16_le b !off in off := !off + 2; r + let read_i32 () = check_ 4; let r = Bytes.get_int32_le b !off in off := !off + 4; r + let read_i64 () = check_ 8; let r = Bytes.get_int64_le b !off in off := !off + 8; r + let read_exact into i len = check_ len; Bytes.blit b !off into i len; off := !off + len + end in + (module M) module Decode = struct - exception Error of string + type t = input - type t = { - bs: bytes; - mutable off: int; - } + let[@inline] of_input (i:input) : t = i + let of_bytes ?off ?len b = of_input (input_of_bytes ?off ?len b) + let of_string ?off ?len s = of_bytes ?off ?len (Bytes.unsafe_of_string s) type 'a dec = t -> 'a - let fail_ e = raise (Error e) - let fail_eof_ what = - fail_ (spf "unexpected end of input, expected %s" what) - let uint (self:t) : int64 = let rec loop () = - if self.off >= Bytes.length self.bs then fail_eof_ "uint"; - let c = Char.code (Bytes.get self.bs self.off) in - self.off <- 1 + self.off; (* consume *) + let c = let (module M) = self in M.read_byte() in + let c = Char.code c in if c land 0b1000_0000 <> 0 then ( let rest = loop() in let c = Int64.of_int (c land 0b0111_1111) in @@ -51,33 +70,20 @@ module Decode = struct in res - let u8 self : char = - let x = Bytes.get self.bs self.off in - self.off <- self.off + 1; - x - let i8 = u8 + let i8 (self:t) : char = let (module M) = self in M.read_byte() + let u8 = i8 - let u16 self = - let x = Bytes.get_int16_le self.bs self.off in - self.off <- self.off + 2; - x - let i16 = u16 + let i16 (self:t) = let (module M) = self in M.read_i16() + let u16 = i16 - let u32 self = - let x = Bytes.get_int32_le self.bs self.off in - self.off <- self.off + 4; - x - let i32 = u32 + let i32 (self:t) = let (module M) = self in M.read_i32() + let u32 = i32 - let u64 self = - let i = Bytes.get_int64_le self.bs self.off in - self.off <- 8 + self.off; - i - let i64 = u64 + let i64 (self:t) = let (module M) = self in M.read_i64() + let u64 = i64 - let bool self : bool = - let c = Bytes.get self.bs self.off in - self.off <- 1 + self.off; + let[@inline] bool self : bool = + let c = i8 self in Char.code c <> 0 let f32 (self:t) : float = @@ -88,15 +94,16 @@ module Decode = struct let i = i64 self in Int64.float_of_bits i - let data_of ~size self : bytes = - let s = Bytes.sub self.bs self.off size in - self.off <- self.off + size; - s + let data_of ~size (self:t) : bytes = + let b = Bytes.create size in + let (module M) = self in + M.read_exact b 0 size; + b let data self : bytes = let size = uint self in if Int64.compare size (Int64.of_int Sys.max_string_length) > 0 then - fail_ "string too large"; + invalid_arg "Decode.data: string too large"; let size = Int64.to_int size in (* fits, because of previous test *) data_of ~size self @@ -108,10 +115,33 @@ module Decode = struct if Char.code c = 0 then None else Some (dec self) end -module Encode = struct - type t = Buffer.t +module type OUTPUT = sig + val write_byte : char -> unit + val write_i16 : int -> unit + val write_i32 : int32 -> unit + val write_i64 : int64 -> unit + val write_exact : bytes -> int -> int -> unit + val flush : unit -> unit +end - let of_buffer buf : t = buf +type output = (module OUTPUT) + +let output_of_buffer (buf:Buffer.t) : output = + let module M = struct + let[@inline] write_byte c = Buffer.add_char buf c + let[@inline] write_i16 c = Buffer.add_int16_le buf c + let[@inline] write_i32 c = Buffer.add_int32_le buf c + let[@inline] write_i64 c = Buffer.add_int64_le buf c + let write_exact b i len = Buffer.add_subbytes buf b i len + let flush _ = () + end in + (module M) + +module Encode = struct + type t = output + + let[@inline] of_output (o:output) : t = o + let[@inline] of_buffer buf : t = of_output @@ output_of_buffer buf type 'a enc = t -> 'a -> unit @@ -127,12 +157,14 @@ module Encode = struct if !i = j then ( continue := false; let j = I.to_int j in - Buffer.add_char self (unsafe_chr j) + let (module M) = self in + M.write_byte (unsafe_chr j) ) else ( (* set bit 8 to [1] *) let lsb = I.to_int (I.logor 0b1000_0000L j) in let lsb = (unsafe_chr lsb) in - Buffer.add_char self lsb; + let (module M) = self in + M.write_byte lsb; i := I.shift_right_logical !i 7; ) done @@ -142,28 +174,30 @@ module Encode = struct let ui = logxor (shift_left i 1) (shift_right i 63) in uint self ui - let u8 self x = Buffer.add_char self x - let i8 = u8 - let u16 self x = Buffer.add_int16_le self x - let i16 = u16 - let u32 self x = Buffer.add_int32_le self x - let i32 = u32 - let u64 self x = Buffer.add_int64_le self x - let i64 = u64 + let[@inline] i8 (self:t) x = let (module M) = self in M.write_byte x + let u8 = i8 + let[@inline] i16 (self:t) x = let (module M) = self in M.write_i16 x + let u16 = i16 + let[@inline] i32 (self:t) x = let (module M) = self in M.write_i32 x + let u32 = i32 + let[@inline] i64 (self:t) x = let (module M) = self in M.write_i64 x + let u64 = i64 - let bool self x = Buffer.add_char self (if x then Char.chr 1 else Char.chr 0) + let bool self x = i8 self (if x then Char.chr 1 else Char.chr 0) - let f64 (self:t) x = Buffer.add_int64_le self (Int64.bits_of_float x) + let f64 (self:t) x = i64 self (Int64.bits_of_float x) - let data_of ~size self x = + let data_of ~size (self:t) x = if size <> Bytes.length x then failwith "invalid length for Encode.data_of"; - Buffer.add_bytes self x + let (module M) = self in + M.write_exact x 0 size - let data self x = + let data (self:t) x = uint self (Int64.of_int (Bytes.length x)); - Buffer.add_bytes self x + let (module M) = self in + M.write_exact x 0 (Bytes.length x) - let string self x = data self (Bytes.unsafe_of_string x) + let[@inline] string self x = data self (Bytes.unsafe_of_string x) let[@inline] optional enc self x : unit = match x with @@ -208,19 +242,21 @@ end let to_string (e:'a Encode.enc) (x:'a) = let buf = Buffer.create 32 in - e buf x; + e (Encode.of_buffer buf) x; Buffer.contents buf -let of_bytes_exn ?(off=0) dec bs = - let i = {Decode.bs; off} in +let of_bytes_exn ?off ?len dec b = + let i = Decode.of_bytes ?off ?len b in dec i -let of_bytes ?off dec bs = - try Ok (of_bytes_exn ?off dec bs) - with Decode.Error e -> Error e +let of_bytes ?off ?len dec bs = + try Ok (of_bytes_exn ?off ?len dec bs) + with + | Invalid_argument e | Failure e -> Error e + | End_of_file -> Error "end of file" -let of_string_exn dec s = of_bytes_exn dec (Bytes.unsafe_of_string s) -let of_string dec s = of_bytes dec (Bytes.unsafe_of_string s) +let of_string_exn ?off ?len dec s = of_bytes_exn ?off ?len dec (Bytes.unsafe_of_string s) +let of_string ?off ?len dec s = of_bytes ?off ?len dec (Bytes.unsafe_of_string s) (*$inject @@ -231,7 +267,7 @@ let of_string dec s = of_bytes dec (Bytes.unsafe_of_string s) Buffer.contents buf let of_s f x = - let i = {Decode.off=0; bs=Bytes.unsafe_of_string x} in + let i = Decode.of_string x in f i *) @@ -258,24 +294,47 @@ let of_string dec s = of_bytes dec (Bytes.unsafe_of_string s) 1 (let s = to_s Encode.int (-1209433446454112432L) in 0x1 land (Char.code s.[0])) *) -(*$Q +(*$Q & ~count:1000 Q.(int64) (fun s -> \ s = (of_s Decode.uint (to_s Encode.uint s))) + Q.(small_nat) (fun n -> \ + let n = Int64.of_int n in \ + n = (of_s Decode.uint (to_s Encode.uint n))) *) -(*$Q +(*$Q & ~count:1000 Q.(int64) (fun s -> \ s = (of_s Decode.int (to_s Encode.int s))) + Q.(small_signed_int) (fun n -> \ + let n = Int64.of_int n in \ + n = (of_s Decode.int (to_s Encode.int n))) *) -(* TODO: some tests with qtest *) +(*$R + for i=0 to 1_000 do + let i = Int64.of_int i in + assert_equal ~printer:Int64.to_string i (of_s Decode.int (to_s Encode.int i)) + done +*) + +(*$R + for i=0 to 1_000 do + let i = Int64.of_int i in + assert_equal ~printer:Int64.to_string i (of_s Decode.uint (to_s Encode.uint i)) + done +*) + +(*$Q & ~count:1000 + Q.(string) (fun s -> \ + s = (of_s Decode.string (to_s Encode.string s))) +*) end module ID = struct type t = int32 - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = Bare.Decode.i32 dec @@ -290,7 +349,7 @@ end module Lit = struct type t = ID.t - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = ID.decode dec @@ -307,11 +366,11 @@ module Clause = struct lits: Lit.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let lits = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> Lit.decode dec)) in {lits; } @@ -337,7 +396,7 @@ module Step_input = struct c: Clause.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let c = Clause.decode dec in {c; } @@ -360,12 +419,12 @@ module Step_rup = struct hyps: ID.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let res = Clause.decode dec in let hyps = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> ID.decode dec)) in {res; hyps; } @@ -394,7 +453,7 @@ module Step_bridge_lit_expr = struct expr: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let lit = Lit.decode dec in let expr = ID.decode dec in {lit; expr; } @@ -417,11 +476,11 @@ module Step_cc = struct eqns: ID.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let eqns = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> ID.decode dec)) in {eqns; } @@ -449,13 +508,13 @@ module Step_preprocess = struct using: ID.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let t = ID.decode dec in let u = ID.decode dec in let using = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> ID.decode dec)) in {t; u; using; } @@ -487,13 +546,13 @@ module Step_clause_rw = struct using: ID.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let c = ID.decode dec in let res = Clause.decode dec in let using = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> ID.decode dec)) in {c; res; using; } @@ -523,7 +582,7 @@ module Step_unsat = struct c: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let c = ID.decode dec in {c; } @@ -546,7 +605,7 @@ module Step_proof_p1 = struct c: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let rw_with = ID.decode dec in let c = ID.decode dec in {rw_with; c; } @@ -570,7 +629,7 @@ module Step_proof_r1 = struct c: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let unit = ID.decode dec in let c = ID.decode dec in {unit; c; } @@ -595,7 +654,7 @@ module Step_proof_res = struct c2: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let pivot = ID.decode dec in let c1 = ID.decode dec in @@ -626,11 +685,11 @@ module Step_bool_tauto = struct lits: Lit.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let lits = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> Lit.decode dec)) in {lits; } @@ -657,12 +716,12 @@ module Step_bool_c = struct exprs: ID.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let rule = Bare.Decode.string dec in let exprs = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> ID.decode dec)) in {rule; exprs; } @@ -690,7 +749,7 @@ module Step_true = struct true_: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let true_ = ID.decode dec in {true_; } @@ -712,7 +771,7 @@ module Fun_decl = struct f: string; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let f = Bare.Decode.string dec in {f; } @@ -735,7 +794,7 @@ module Expr_def = struct rhs: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let c = ID.decode dec in let rhs = ID.decode dec in {c; rhs; } @@ -758,7 +817,7 @@ module Expr_bool = struct b: bool; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let b = Bare.Decode.bool dec in {b; } @@ -782,7 +841,7 @@ module Expr_if = struct else_: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let cond = ID.decode dec in let then_ = ID.decode dec in @@ -813,7 +872,7 @@ module Expr_not = struct f: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let f = ID.decode dec in {f; } @@ -836,7 +895,7 @@ module Expr_eq = struct rhs: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let lhs = ID.decode dec in let rhs = ID.decode dec in {lhs; rhs; } @@ -860,12 +919,12 @@ module Expr_app = struct args: ID.t array; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let f = ID.decode dec in let args = (let len = Bare.Decode.uint dec in - if len>Int64.of_int Sys.max_array_length then raise (Bare.Decode.Error"array too big"); + if len>Int64.of_int Sys.max_array_length then invalid_arg "array too big"; Array.init (Int64.to_int len) (fun _ -> ID.decode dec)) in {f; args; } @@ -894,7 +953,7 @@ module Expr_isa = struct arg: ID.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let c = ID.decode dec in let arg = ID.decode dec in {c; arg; } @@ -937,7 +996,7 @@ module Step_view = struct | Expr_app of Expr_app.t - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let tag = Bare.Decode.uint dec in match tag with @@ -962,7 +1021,8 @@ module Step_view = struct | 18L -> Expr_isa (Expr_isa.decode dec) | 19L -> Expr_eq (Expr_eq.decode dec) | 20L -> Expr_app (Expr_app.decode dec) - | _ -> raise (Bare.Decode.Error(Printf.sprintf "unknown union tag Step_view.t: %Ld" tag)) + | _ -> invalid_arg + (Printf.sprintf "unknown union tag Step_view.t: %Ld" tag) let encode (enc: Bare.Encode.t) (self: t) : unit = @@ -1086,7 +1146,7 @@ module Step = struct view: Step_view.t; } - (** @raise Bare.Decode.Error in case of error. *) + (** @raise Invalid_argument in case of error. *) let decode (dec: Bare.Decode.t) : t = let id = ID.decode dec in let view = Step_view.decode dec in {id; view; }