diff --git a/src/util/Sidekick_util.ml b/src/util/Sidekick_util.ml index fda8f3a4..d1509ab4 100644 --- a/src/util/Sidekick_util.ml +++ b/src/util/Sidekick_util.ml @@ -24,5 +24,6 @@ module Hash = Hash module Profile = Profile module Chunk_stack = Chunk_stack module Ser_value = Ser_value +module Ser_decode = Ser_decode let[@inline] ( let@ ) f x = f x diff --git a/src/util/ser_decode.ml b/src/util/ser_decode.ml new file mode 100644 index 00000000..17b00881 --- /dev/null +++ b/src/util/ser_decode.ml @@ -0,0 +1,149 @@ +open struct + module Error_ = Error + module Fmt = CCFormat + module V = Ser_value +end + +module Error = struct + type t = { msg: string; v: V.t; subs: t list } + + let mk ?(subs = []) msg v : t = { msg; v; subs } + + let pp out (self : t) = + let rec pp out self = + Fmt.fprintf out "@[@[<2>%s@ in %a@]" self.msg V.pp self.v; + List.iter + (fun s -> Fmt.fprintf out "@ @[<2>sub-error:@ %a@]" pp s) + self.subs; + Fmt.fprintf out "@]" + in + Fmt.fprintf out "@[<2>Ser_decode.error:@ %a@]" pp self + + let to_string = Fmt.to_string pp +end + +exception Fail of Error.t + +type 'a t = { deser: V.t -> 'a } [@@unboxed] + +let[@inline] fail_ msg v = raise_notrace (Fail (Error.mk msg v)) +let[@inline] fail_e e = raise_notrace (Fail e) +let return x = { deser = (fun _ -> x) } +let fail s = { deser = (fun v -> fail_ s v) } +let any = { deser = (fun v -> v) } + +let bool = + { + deser = + (function + | V.Bool b -> b + | v -> fail_ "expected bool" v); + } + +let int = + { + deser = + (function + | V.Int i -> i + | v -> fail_ "expected int" v); + } + +let string = + { + deser = + (function + | V.Str s | V.Bytes s -> s + | v -> fail_ "expected string" v); + } + +let list d = + { + deser = + (function + | V.List l -> List.map (fun x -> d.deser x) l + | v -> fail_ "expected list" v); + } + +let dict_field name d = + { + deser = + (function + | V.Dict m as v -> + (match Util.Str_map.find_opt name m with + | None -> fail_ (Printf.sprintf "did not find key %S" name) v + | Some x -> d.deser x) + | v -> fail_ "expected dict" v); + } + +let dict_field_opt name d = + { + deser = + (function + | V.Dict m -> + (match Util.Str_map.find_opt name m with + | None -> None + | Some x -> Some (d.deser x)) + | v -> fail_ "expected dict" v); + } + +let both a b = + { + deser = + (fun v -> + let xa = a.deser v in + let xb = b.deser v in + xa, xb); + } + +let ( >>= ) d f = + { + deser = + (fun v -> + let x = d.deser v in + (f x).deser v); + } + +let ( >|= ) d f = + { + deser = + (fun v -> + let x = d.deser v in + f x); + } + +let try_l l = + { + deser = + (fun v -> + let subs = ref [] in + match + CCList.find_map + (fun d -> + match d.deser v with + | x -> Some x + | exception Fail err -> + subs := err :: !subs; + None) + l + with + | Some x -> x + | None -> fail_e (Error.mk "all decoders failed" v ~subs:!subs)); + } + +module Infix = struct + let ( >>= ) = ( >>= ) + let ( >|= ) = ( >|= ) + let ( and+ ) = both + let ( and* ) = both + let ( let+ ) = ( >|= ) + let ( let* ) = ( >>= ) +end + +include Infix + +let run d v = try Ok (d.deser v) with Fail err -> Error err + +let run_exn d v = + try d.deser v + with Fail err -> + Error_.errorf "ser_decode: failed to decode:@ %a" Error.pp err diff --git a/src/util/ser_decode.mli b/src/util/ser_decode.mli new file mode 100644 index 00000000..b5086920 --- /dev/null +++ b/src/util/ser_decode.mli @@ -0,0 +1,44 @@ +(** Decoders for {!Ser_value}. + + Combinators to decode values. *) + +type +'a t +(** Decode a value of type ['a] *) + +val int : int t +val bool : bool t +val string : string t +val return : 'a -> 'a t +val fail : string -> 'a t +val any : Ser_value.t t +val list : 'a t -> 'a list t +val dict_field : string -> 'a t -> 'a t +val dict_field_opt : string -> 'a t -> 'a option t +val both : 'a t -> 'b t -> ('a * 'b) t +val try_l : 'a t list -> 'a t + +module Infix : sig + val ( >|= ) : 'a t -> ('a -> 'b) -> 'b t + val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t + val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t + val ( and+ ) : 'a t -> 'b t -> ('a * 'b) t + val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t + val ( and* ) : 'a t -> 'b t -> ('a * 'b) t +end + +include module type of Infix + +(** {2 Deserializing} *) + +module Error : sig + type t + + include Sidekick_sigs.PRINT with type t := t + + val to_string : t -> string +end + +val run : 'a t -> Ser_value.t -> ('a, Error.t) result + +val run_exn : 'a t -> Ser_value.t -> 'a +(** @raise Error.Error in case of failure *)