breaking: pp: modify Ext.t so it takes surrounding value

The type is now opaque, using a smart constructor, and is passed
the value used in the closest surrounding call to this extension,
if any. It is used by `Term_color` to properly restore ANSI
style in nested situations.
This commit is contained in:
Simon Cruanes 2023-11-13 22:22:34 -05:00
parent 1508b6c940
commit 0d273d886f
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 110 additions and 52 deletions

View file

@ -1,4 +1,5 @@
module B = Buffer module B = Buffer
module Int_map = Map.Make (CCInt)
module Out = struct module Out = struct
type t = { type t = {
@ -24,10 +25,43 @@ module Out = struct
end end
module Ext = struct module Ext = struct
type view = ..
type 'a key = { id: int; inject: 'a -> view; extract: view -> 'a option }
type map = view Int_map.t
let empty : map = Int_map.empty
let get k (self : map) : _ option =
try k.extract @@ Int_map.find k.id self with Not_found -> None
let add k v self : map = Int_map.add k.id (k.inject v) self
type 'a t = { type 'a t = {
pre: Out.t -> 'a -> unit; (** Printed before the wrapped value. *) name: string;
post: Out.t -> 'a -> unit; (** Printed after the wrapped value. *) k: 'a key;
width: 'a -> int;
pre: Out.t -> inside:'a option -> 'a -> unit;
post: Out.t -> inside:'a option -> 'a -> unit;
} }
let key_counter_ = ref 0
let make (type a) ?(width = fun _ -> 0) ~name ~pre ~post () : a t =
let module M = struct
type view += V of a
end in
let k =
{
id = !key_counter_;
inject = (fun x -> M.V x);
extract =
(function
| M.V x -> Some x
| _ -> None);
}
in
incr key_counter_;
{ name; k; width; pre; post }
end end
type t = { type t = {
@ -63,7 +97,7 @@ let rec debug out (self : t) : unit =
| Group d -> Format.fprintf out "(@[group@ %a@])" debug d | Group d -> Format.fprintf out "(@[group@ %a@])" debug d
| Fill { sep = _; l } -> | Fill { sep = _; l } ->
Format.fprintf out "(@[fill@ %a@])" (Format.pp_print_list debug) l Format.fprintf out "(@[fill@ %a@])" (Format.pp_print_list debug) l
| Wrap (_, _, d) -> Format.fprintf out "(@[ext@ %a@])" debug d | Wrap (e, _, d) -> Format.fprintf out "(@[ext.%s@ %a@])" e.name debug d
let nil : t = { view = Nil; wfl = 0 } let nil : t = { view = Nil; wfl = 0 }
let newline : t = { view = Newline 1; wfl = 1 } let newline : t = { view = Newline 1; wfl = 1 }
@ -98,7 +132,10 @@ let group d : t =
| Group _ -> d | Group _ -> d
| _ -> { view = Group d; wfl = d.wfl } | _ -> { view = Group d; wfl = d.wfl }
let ext ext v d : t = { view = Wrap (ext, v, d); wfl = d.wfl } let ext (ext : _ Ext.t) v d : t =
let wfl = d.wfl + ext.width v in
{ view = Wrap (ext, v, d); wfl }
let ( ^ ) = append let ( ^ ) = append
let text_sub_ s i len : t = { view = Text_sub (s, i, len); wfl = len } let text_sub_ s i len : t = { view = Text_sub (s, i, len); wfl = len }
@ -135,7 +172,7 @@ let textf fmt = Format.kasprintf text fmt
module Flatten = struct module Flatten = struct
let to_out (out : Out.t) (self : t) : unit = let to_out (out : Out.t) (self : t) : unit =
let rec loop (d : t) = let rec loop (ext_map : Ext.map) (d : t) =
match d.view with match d.view with
| Nil | Newline 0 -> () | Nil | Newline 0 -> ()
| Char c -> out.char c | Char c -> out.char c
@ -144,25 +181,27 @@ module Flatten = struct
for _i = 1 to n do for _i = 1 to n do
out.char ' ' out.char ' '
done done
| Nest (_, x) -> loop x | Nest (_, x) -> loop ext_map x
| Append (x, y) -> | Append (x, y) ->
loop x; loop ext_map x;
loop y loop ext_map y
| Text s | Text_zero_width s -> out.string s | Text s | Text_zero_width s -> out.string s
| Text_sub (s, i, len) -> out.sub_string s i len | Text_sub (s, i, len) -> out.sub_string s i len
| Group x -> loop x | Group x -> loop ext_map x
| Fill { sep; l } -> | Fill { sep; l } ->
List.iteri List.iteri
(fun i x -> (fun i x ->
if i > 0 then loop sep; if i > 0 then loop ext_map sep;
loop x) loop ext_map x)
l l
| Wrap (ext, v, d) -> | Wrap (ext, v, d) ->
ext.pre out v; let inside = Ext.get ext.k ext_map in
loop d; ext.pre out ~inside v;
ext.post out v let ext_map' = Ext.add ext.k v ext_map in
loop ext_map' d;
ext.post out ~inside v
in in
loop self loop Ext.empty self
let to_buffer buf (self : t) : unit = let to_buffer buf (self : t) : unit =
let out = Out.of_buffer buf in let out = Out.of_buffer buf in
@ -175,7 +214,7 @@ module Flatten = struct
end end
module Pretty = struct module Pretty = struct
type st = { out: Out.t; width: int } type st = { out: Out.t; width: int; ext_map: Ext.map }
(** Add [i] spaces of indentation. *) (** Add [i] spaces of indentation. *)
let add_indent st (i : int) = let add_indent st (i : int) =
@ -218,9 +257,11 @@ module Pretty = struct
l; l;
!n !n
| Wrap (ext, v, d) -> | Wrap (ext, v, d) ->
ext.pre st.out v; let inside = Ext.get ext.k st.ext_map in
let n = pp_flatten st d in ext.pre st.out ~inside v;
ext.post st.out v; let st' = { st with ext_map = Ext.add ext.k v st.ext_map } in
let n = pp_flatten st' d in
ext.post st.out ~inside v;
n n
(** Does [x] fit in the current line when flattened, given that [k] chars (** Does [x] fit in the current line when flattened, given that [k] chars
@ -238,51 +279,54 @@ module Pretty = struct
match stack with match stack with
| [] -> () | [] -> ()
| (i, d) :: stack_tl -> | (i, d) :: stack_tl ->
pp_rec_top st ~k ~i d (fun k -> pp_rec st k stack_tl) pp_rec_top st ~k ~i d (fun st k -> pp_rec st k stack_tl)
(** Print [d] at indentation [i], with [k] chars already printed (** Print [d] at indentation [i], with [k] chars already printed
on the current line, then calls [kont] with the on the current line, then calls [kont] with the
new [k]. *) new [k]. *)
and pp_rec_top st ~k ~i d (kont : int -> unit) : unit = and pp_rec_top st ~k ~i d (kont : st -> int -> unit) : unit =
match d.view with match d.view with
| Nil -> kont k | Nil -> kont st k
| Char c -> | Char c ->
st.out.char c; st.out.char c;
kont (k + 1) kont st (k + 1)
| Newline _ -> | Newline _ ->
pp_newline st i; pp_newline st i;
kont i kont st i
| Nest (j, x) -> pp_rec_top st ~k ~i:(i + j) x kont | Nest (j, x) -> pp_rec_top st ~k ~i:(i + j) x kont
| Append (x, y) -> | Append (x, y) ->
(* print [x], then print [y] *) (* print [x], then print [y] *)
pp_rec_top st ~k ~i x (fun k -> pp_rec_top st ~k ~i y kont) pp_rec_top st ~k ~i x (fun st k -> pp_rec_top st ~k ~i y kont)
| Text s -> | Text s ->
st.out.string s; st.out.string s;
kont (k + String.length s) kont st (k + String.length s)
| Text_zero_width s -> | Text_zero_width s ->
st.out.string s; st.out.string s;
kont k kont st k
| Text_sub (s, i, len) -> | Text_sub (s, i, len) ->
st.out.sub_string s i len; st.out.sub_string s i len;
kont (k + len) kont st (k + len)
| Group x -> | Group x ->
if fits_flattened st k x then ( if fits_flattened st k x then (
(* print flattened *) (* print flattened *)
let w_x = pp_flatten st x in let w_x = pp_flatten st x in
assert (w_x = x.wfl); assert (w_x = x.wfl);
kont (k + w_x) kont st (k + w_x)
) else ) else
pp_rec_top st ~k ~i x kont pp_rec_top st ~k ~i x kont
| Fill { sep; l } -> pp_fill st ~k ~i sep l kont | Fill { sep; l } -> pp_fill st ~k ~i sep l kont
| Wrap (ext, v, d) -> | Wrap (ext, v, d) ->
ext.pre st.out v; let old_ext_map = st.ext_map in
pp_rec_top st ~k ~i d (fun k -> let inside = Ext.get ext.k st.ext_map in
ext.post st.out v; ext.pre st.out ~inside v;
kont k) let st' = { st with ext_map = Ext.add ext.k v st.ext_map } in
pp_rec_top st' ~k ~i d (fun st k ->
ext.post st.out ~inside v;
kont { st with ext_map = old_ext_map } k)
and pp_fill st ~k ~i sep l (kont : int -> unit) : unit = and pp_fill st ~k ~i sep l (kont : st -> int -> unit) : unit =
(* [k] is the current offset in the line *) (* [k] is the current offset in the line *)
let rec loop idx k l = let rec loop st idx k l =
match l with match l with
| x :: tl -> | x :: tl ->
if fits_flattened st k x then ( if fits_flattened st k x then (
@ -295,24 +339,24 @@ module Pretty = struct
in in
let w_x = pp_flatten st x in let w_x = pp_flatten st x in
assert (w_x = x.wfl); assert (w_x = x.wfl);
loop (idx + 1) (k + w_x + w_sep) tl loop st (idx + 1) (k + w_x + w_sep) tl
) else ( ) else (
(* print, followed by a newline and resume filling with [k=i] *) (* print, followed by a newline and resume filling with [k=i] *)
let pp_and_continue k = let pp_and_continue st k =
pp_rec_top st ~k ~i x (fun k -> loop (idx + 1) k tl) pp_rec_top st ~k ~i x (fun st k -> loop st (idx + 1) k tl)
in in
if idx > 0 then if idx > 0 then
(* separator, then item *) (* separator, then item *)
pp_rec_top st ~k ~i sep pp_and_continue pp_rec_top st ~k ~i sep pp_and_continue
else else
pp_and_continue k pp_and_continue st k
) )
| [] -> kont k | [] -> kont st k
in in
loop 0 k l loop st 0 k l
let to_out ~width out (self : t) : unit = let to_out ~width out (self : t) : unit =
let st = { out; width } in let st = { out; width; ext_map = Ext.empty } in
pp_rec st 0 [ 0, self ] pp_rec st 0 [ 0, self ]
let to_buffer ~width (buf : Buffer.t) (self : t) : unit = let to_buffer ~width (buf : Buffer.t) (self : t) : unit =
@ -464,12 +508,15 @@ module Term_color = struct
Buffer.add_string buf "m"; Buffer.add_string buf "m";
Buffer.contents buf Buffer.contents buf
(* TODO: handle nested styles *)
let ext_style_ : style list Ext.t = let ext_style_ : style list Ext.t =
{ Ext.make ~name:"termcolor"
pre = (fun out l -> Out.string out (string_of_style_list l)); ~pre:(fun out ~inside:_ l -> Out.string out (string_of_style_list l))
post = (fun out _l -> Out.string out reset); ~post:(fun out ~inside _l ->
} let style =
CCOption.map_or ~default:reset string_of_style_list inside
in
Out.string out style)
()
(** Set the foreground color. *) (** Set the foreground color. *)
let color (c : color) (d : t) : t = ext ext_style_ [ `FG c ] d let color (c : color) (d : t) : t = ext ext_style_ [ `FG c ] d

View file

@ -120,15 +120,26 @@ end
might be annotated with ANSI-terminal colors, or might be annotated with ANSI-terminal colors, or
with HTML tags. *) with HTML tags. *)
module Ext : sig module Ext : sig
type 'a t = { type 'a t
pre: Out.t -> 'a -> unit; (** Printed before the wrapped value. *)
post: Out.t -> 'a -> unit; (** Printed after the wrapped value. *) val make :
} ?width:('a -> int) ->
name:string ->
pre:(Out.t -> inside:'a option -> 'a -> unit) ->
post:(Out.t -> inside:'a option -> 'a -> unit) ->
unit ->
'a t
(** An extension is a custom document node. It takes a value of type ['a], (** An extension is a custom document node. It takes a value of type ['a],
and a document [d], and can output what it wants based and a document [d], and can output what it wants based
on the custom value before and after [d] is printed. on the custom value before and after [d] is printed.
The extension is considered to have width [0]. *) The extension is considered to have width [0], unless [width]
is specified.
@param pre called before the wrapped value is printed
@param post called after the wrapped value is printed
*)
end end
val ext : 'a Ext.t -> 'a -> t -> t val ext : 'a Ext.t -> 'a -> t -> t