diff --git a/src/pp/containers_pp.ml b/src/pp/containers_pp.ml index 0b9ea16e..8ca900c0 100644 --- a/src/pp/containers_pp.ml +++ b/src/pp/containers_pp.ml @@ -1,4 +1,5 @@ module B = Buffer +module Int_map = Map.Make (CCInt) module Out = struct type t = { @@ -24,10 +25,43 @@ module Out = struct end module Ext = struct + type view = .. + type 'a key = { id: int; inject: 'a -> view; extract: view -> 'a option } + type map = view Int_map.t + + let empty : map = Int_map.empty + + let get k (self : map) : _ option = + try k.extract @@ Int_map.find k.id self with Not_found -> None + + let add k v self : map = Int_map.add k.id (k.inject v) self + type 'a t = { - pre: Out.t -> 'a -> unit; (** Printed before the wrapped value. *) - post: Out.t -> 'a -> unit; (** Printed after the wrapped value. *) + name: string; + k: 'a key; + width: 'a -> int; + pre: Out.t -> inside:'a option -> 'a -> unit; + post: Out.t -> inside:'a option -> 'a -> unit; } + + let key_counter_ = ref 0 + + let make (type a) ?(width = fun _ -> 0) ~name ~pre ~post () : a t = + let module M = struct + type view += V of a + end in + let k = + { + id = !key_counter_; + inject = (fun x -> M.V x); + extract = + (function + | M.V x -> Some x + | _ -> None); + } + in + incr key_counter_; + { name; k; width; pre; post } end type t = { @@ -63,7 +97,7 @@ let rec debug out (self : t) : unit = | Group d -> Format.fprintf out "(@[group@ %a@])" debug d | Fill { sep = _; l } -> Format.fprintf out "(@[fill@ %a@])" (Format.pp_print_list debug) l - | Wrap (_, _, d) -> Format.fprintf out "(@[ext@ %a@])" debug d + | Wrap (e, _, d) -> Format.fprintf out "(@[ext.%s@ %a@])" e.name debug d let nil : t = { view = Nil; wfl = 0 } let newline : t = { view = Newline 1; wfl = 1 } @@ -98,7 +132,10 @@ let group d : t = | Group _ -> d | _ -> { view = Group d; wfl = d.wfl } -let ext ext v d : t = { view = Wrap (ext, v, d); wfl = d.wfl } +let ext (ext : _ Ext.t) v d : t = + let wfl = d.wfl + ext.width v in + { view = Wrap (ext, v, d); wfl } + let ( ^ ) = append let text_sub_ s i len : t = { view = Text_sub (s, i, len); wfl = len } @@ -135,7 +172,7 @@ let textf fmt = Format.kasprintf text fmt module Flatten = struct let to_out (out : Out.t) (self : t) : unit = - let rec loop (d : t) = + let rec loop (ext_map : Ext.map) (d : t) = match d.view with | Nil | Newline 0 -> () | Char c -> out.char c @@ -144,25 +181,27 @@ module Flatten = struct for _i = 1 to n do out.char ' ' done - | Nest (_, x) -> loop x + | Nest (_, x) -> loop ext_map x | Append (x, y) -> - loop x; - loop y + loop ext_map x; + loop ext_map y | Text s | Text_zero_width s -> out.string s | Text_sub (s, i, len) -> out.sub_string s i len - | Group x -> loop x + | Group x -> loop ext_map x | Fill { sep; l } -> List.iteri (fun i x -> - if i > 0 then loop sep; - loop x) + if i > 0 then loop ext_map sep; + loop ext_map x) l | Wrap (ext, v, d) -> - ext.pre out v; - loop d; - ext.post out v + let inside = Ext.get ext.k ext_map in + ext.pre out ~inside v; + let ext_map' = Ext.add ext.k v ext_map in + loop ext_map' d; + ext.post out ~inside v in - loop self + loop Ext.empty self let to_buffer buf (self : t) : unit = let out = Out.of_buffer buf in @@ -175,7 +214,7 @@ module Flatten = struct end module Pretty = struct - type st = { out: Out.t; width: int } + type st = { out: Out.t; width: int; ext_map: Ext.map } (** Add [i] spaces of indentation. *) let add_indent st (i : int) = @@ -218,9 +257,11 @@ module Pretty = struct l; !n | Wrap (ext, v, d) -> - ext.pre st.out v; - let n = pp_flatten st d in - ext.post st.out v; + let inside = Ext.get ext.k st.ext_map in + ext.pre st.out ~inside v; + let st' = { st with ext_map = Ext.add ext.k v st.ext_map } in + let n = pp_flatten st' d in + ext.post st.out ~inside v; n (** Does [x] fit in the current line when flattened, given that [k] chars @@ -238,51 +279,54 @@ module Pretty = struct match stack with | [] -> () | (i, d) :: stack_tl -> - pp_rec_top st ~k ~i d (fun k -> pp_rec st k stack_tl) + pp_rec_top st ~k ~i d (fun st k -> pp_rec st k stack_tl) (** Print [d] at indentation [i], with [k] chars already printed on the current line, then calls [kont] with the new [k]. *) - and pp_rec_top st ~k ~i d (kont : int -> unit) : unit = + and pp_rec_top st ~k ~i d (kont : st -> int -> unit) : unit = match d.view with - | Nil -> kont k + | Nil -> kont st k | Char c -> st.out.char c; - kont (k + 1) + kont st (k + 1) | Newline _ -> pp_newline st i; - kont i + kont st i | Nest (j, x) -> pp_rec_top st ~k ~i:(i + j) x kont | Append (x, y) -> (* print [x], then print [y] *) - pp_rec_top st ~k ~i x (fun k -> pp_rec_top st ~k ~i y kont) + pp_rec_top st ~k ~i x (fun st k -> pp_rec_top st ~k ~i y kont) | Text s -> st.out.string s; - kont (k + String.length s) + kont st (k + String.length s) | Text_zero_width s -> st.out.string s; - kont k + kont st k | Text_sub (s, i, len) -> st.out.sub_string s i len; - kont (k + len) + kont st (k + len) | Group x -> if fits_flattened st k x then ( (* print flattened *) let w_x = pp_flatten st x in assert (w_x = x.wfl); - kont (k + w_x) + kont st (k + w_x) ) else pp_rec_top st ~k ~i x kont | Fill { sep; l } -> pp_fill st ~k ~i sep l kont | Wrap (ext, v, d) -> - ext.pre st.out v; - pp_rec_top st ~k ~i d (fun k -> - ext.post st.out v; - kont k) + let old_ext_map = st.ext_map in + let inside = Ext.get ext.k st.ext_map in + ext.pre st.out ~inside v; + let st' = { st with ext_map = Ext.add ext.k v st.ext_map } in + pp_rec_top st' ~k ~i d (fun st k -> + ext.post st.out ~inside v; + kont { st with ext_map = old_ext_map } k) - and pp_fill st ~k ~i sep l (kont : int -> unit) : unit = + and pp_fill st ~k ~i sep l (kont : st -> int -> unit) : unit = (* [k] is the current offset in the line *) - let rec loop idx k l = + let rec loop st idx k l = match l with | x :: tl -> if fits_flattened st k x then ( @@ -295,24 +339,24 @@ module Pretty = struct in let w_x = pp_flatten st x in assert (w_x = x.wfl); - loop (idx + 1) (k + w_x + w_sep) tl + loop st (idx + 1) (k + w_x + w_sep) tl ) else ( (* print, followed by a newline and resume filling with [k=i] *) - let pp_and_continue k = - pp_rec_top st ~k ~i x (fun k -> loop (idx + 1) k tl) + let pp_and_continue st k = + pp_rec_top st ~k ~i x (fun st k -> loop st (idx + 1) k tl) in if idx > 0 then (* separator, then item *) pp_rec_top st ~k ~i sep pp_and_continue else - pp_and_continue k + pp_and_continue st k ) - | [] -> kont k + | [] -> kont st k in - loop 0 k l + loop st 0 k l let to_out ~width out (self : t) : unit = - let st = { out; width } in + let st = { out; width; ext_map = Ext.empty } in pp_rec st 0 [ 0, self ] let to_buffer ~width (buf : Buffer.t) (self : t) : unit = @@ -464,12 +508,15 @@ module Term_color = struct Buffer.add_string buf "m"; Buffer.contents buf - (* TODO: handle nested styles *) let ext_style_ : style list Ext.t = - { - pre = (fun out l -> Out.string out (string_of_style_list l)); - post = (fun out _l -> Out.string out reset); - } + Ext.make ~name:"termcolor" + ~pre:(fun out ~inside:_ l -> Out.string out (string_of_style_list l)) + ~post:(fun out ~inside _l -> + let style = + CCOption.map_or ~default:reset string_of_style_list inside + in + Out.string out style) + () (** Set the foreground color. *) let color (c : color) (d : t) : t = ext ext_style_ [ `FG c ] d diff --git a/src/pp/containers_pp.mli b/src/pp/containers_pp.mli index fe502348..80cce5ae 100644 --- a/src/pp/containers_pp.mli +++ b/src/pp/containers_pp.mli @@ -120,15 +120,26 @@ end might be annotated with ANSI-terminal colors, or with HTML tags. *) module Ext : sig - type 'a t = { - pre: Out.t -> 'a -> unit; (** Printed before the wrapped value. *) - post: Out.t -> 'a -> unit; (** Printed after the wrapped value. *) - } + type 'a t + + val make : + ?width:('a -> int) -> + name:string -> + pre:(Out.t -> inside:'a option -> 'a -> unit) -> + post:(Out.t -> inside:'a option -> 'a -> unit) -> + unit -> + 'a t (** An extension is a custom document node. It takes a value of type ['a], and a document [d], and can output what it wants based on the custom value before and after [d] is printed. - The extension is considered to have width [0]. *) + The extension is considered to have width [0], unless [width] + is specified. + + @param pre called before the wrapped value is printed + @param post called after the wrapped value is printed + + *) end val ext : 'a Ext.t -> 'a -> t -> t