breaking: pp: modify Ext.t so it takes surrounding value

The type is now opaque, using a smart constructor, and is passed
the value used in the closest surrounding call to this extension,
if any. It is used by `Term_color` to properly restore ANSI
style in nested situations.
This commit is contained in:
Simon Cruanes 2023-11-13 22:22:34 -05:00
parent 1508b6c940
commit 0d273d886f
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 110 additions and 52 deletions

View file

@ -1,4 +1,5 @@
module B = Buffer
module Int_map = Map.Make (CCInt)
module Out = struct
type t = {
@ -24,10 +25,43 @@ module Out = struct
end
module Ext = struct
type view = ..
type 'a key = { id: int; inject: 'a -> view; extract: view -> 'a option }
type map = view Int_map.t
let empty : map = Int_map.empty
let get k (self : map) : _ option =
try k.extract @@ Int_map.find k.id self with Not_found -> None
let add k v self : map = Int_map.add k.id (k.inject v) self
type 'a t = {
pre: Out.t -> 'a -> unit; (** Printed before the wrapped value. *)
post: Out.t -> 'a -> unit; (** Printed after the wrapped value. *)
name: string;
k: 'a key;
width: 'a -> int;
pre: Out.t -> inside:'a option -> 'a -> unit;
post: Out.t -> inside:'a option -> 'a -> unit;
}
let key_counter_ = ref 0
let make (type a) ?(width = fun _ -> 0) ~name ~pre ~post () : a t =
let module M = struct
type view += V of a
end in
let k =
{
id = !key_counter_;
inject = (fun x -> M.V x);
extract =
(function
| M.V x -> Some x
| _ -> None);
}
in
incr key_counter_;
{ name; k; width; pre; post }
end
type t = {
@ -63,7 +97,7 @@ let rec debug out (self : t) : unit =
| Group d -> Format.fprintf out "(@[group@ %a@])" debug d
| Fill { sep = _; l } ->
Format.fprintf out "(@[fill@ %a@])" (Format.pp_print_list debug) l
| Wrap (_, _, d) -> Format.fprintf out "(@[ext@ %a@])" debug d
| Wrap (e, _, d) -> Format.fprintf out "(@[ext.%s@ %a@])" e.name debug d
let nil : t = { view = Nil; wfl = 0 }
let newline : t = { view = Newline 1; wfl = 1 }
@ -98,7 +132,10 @@ let group d : t =
| Group _ -> d
| _ -> { view = Group d; wfl = d.wfl }
let ext ext v d : t = { view = Wrap (ext, v, d); wfl = d.wfl }
let ext (ext : _ Ext.t) v d : t =
let wfl = d.wfl + ext.width v in
{ view = Wrap (ext, v, d); wfl }
let ( ^ ) = append
let text_sub_ s i len : t = { view = Text_sub (s, i, len); wfl = len }
@ -135,7 +172,7 @@ let textf fmt = Format.kasprintf text fmt
module Flatten = struct
let to_out (out : Out.t) (self : t) : unit =
let rec loop (d : t) =
let rec loop (ext_map : Ext.map) (d : t) =
match d.view with
| Nil | Newline 0 -> ()
| Char c -> out.char c
@ -144,25 +181,27 @@ module Flatten = struct
for _i = 1 to n do
out.char ' '
done
| Nest (_, x) -> loop x
| Nest (_, x) -> loop ext_map x
| Append (x, y) ->
loop x;
loop y
loop ext_map x;
loop ext_map y
| Text s | Text_zero_width s -> out.string s
| Text_sub (s, i, len) -> out.sub_string s i len
| Group x -> loop x
| Group x -> loop ext_map x
| Fill { sep; l } ->
List.iteri
(fun i x ->
if i > 0 then loop sep;
loop x)
if i > 0 then loop ext_map sep;
loop ext_map x)
l
| Wrap (ext, v, d) ->
ext.pre out v;
loop d;
ext.post out v
let inside = Ext.get ext.k ext_map in
ext.pre out ~inside v;
let ext_map' = Ext.add ext.k v ext_map in
loop ext_map' d;
ext.post out ~inside v
in
loop self
loop Ext.empty self
let to_buffer buf (self : t) : unit =
let out = Out.of_buffer buf in
@ -175,7 +214,7 @@ module Flatten = struct
end
module Pretty = struct
type st = { out: Out.t; width: int }
type st = { out: Out.t; width: int; ext_map: Ext.map }
(** Add [i] spaces of indentation. *)
let add_indent st (i : int) =
@ -218,9 +257,11 @@ module Pretty = struct
l;
!n
| Wrap (ext, v, d) ->
ext.pre st.out v;
let n = pp_flatten st d in
ext.post st.out v;
let inside = Ext.get ext.k st.ext_map in
ext.pre st.out ~inside v;
let st' = { st with ext_map = Ext.add ext.k v st.ext_map } in
let n = pp_flatten st' d in
ext.post st.out ~inside v;
n
(** Does [x] fit in the current line when flattened, given that [k] chars
@ -238,51 +279,54 @@ module Pretty = struct
match stack with
| [] -> ()
| (i, d) :: stack_tl ->
pp_rec_top st ~k ~i d (fun k -> pp_rec st k stack_tl)
pp_rec_top st ~k ~i d (fun st k -> pp_rec st k stack_tl)
(** Print [d] at indentation [i], with [k] chars already printed
on the current line, then calls [kont] with the
new [k]. *)
and pp_rec_top st ~k ~i d (kont : int -> unit) : unit =
and pp_rec_top st ~k ~i d (kont : st -> int -> unit) : unit =
match d.view with
| Nil -> kont k
| Nil -> kont st k
| Char c ->
st.out.char c;
kont (k + 1)
kont st (k + 1)
| Newline _ ->
pp_newline st i;
kont i
kont st i
| Nest (j, x) -> pp_rec_top st ~k ~i:(i + j) x kont
| Append (x, y) ->
(* print [x], then print [y] *)
pp_rec_top st ~k ~i x (fun k -> pp_rec_top st ~k ~i y kont)
pp_rec_top st ~k ~i x (fun st k -> pp_rec_top st ~k ~i y kont)
| Text s ->
st.out.string s;
kont (k + String.length s)
kont st (k + String.length s)
| Text_zero_width s ->
st.out.string s;
kont k
kont st k
| Text_sub (s, i, len) ->
st.out.sub_string s i len;
kont (k + len)
kont st (k + len)
| Group x ->
if fits_flattened st k x then (
(* print flattened *)
let w_x = pp_flatten st x in
assert (w_x = x.wfl);
kont (k + w_x)
kont st (k + w_x)
) else
pp_rec_top st ~k ~i x kont
| Fill { sep; l } -> pp_fill st ~k ~i sep l kont
| Wrap (ext, v, d) ->
ext.pre st.out v;
pp_rec_top st ~k ~i d (fun k ->
ext.post st.out v;
kont k)
let old_ext_map = st.ext_map in
let inside = Ext.get ext.k st.ext_map in
ext.pre st.out ~inside v;
let st' = { st with ext_map = Ext.add ext.k v st.ext_map } in
pp_rec_top st' ~k ~i d (fun st k ->
ext.post st.out ~inside v;
kont { st with ext_map = old_ext_map } k)
and pp_fill st ~k ~i sep l (kont : int -> unit) : unit =
and pp_fill st ~k ~i sep l (kont : st -> int -> unit) : unit =
(* [k] is the current offset in the line *)
let rec loop idx k l =
let rec loop st idx k l =
match l with
| x :: tl ->
if fits_flattened st k x then (
@ -295,24 +339,24 @@ module Pretty = struct
in
let w_x = pp_flatten st x in
assert (w_x = x.wfl);
loop (idx + 1) (k + w_x + w_sep) tl
loop st (idx + 1) (k + w_x + w_sep) tl
) else (
(* print, followed by a newline and resume filling with [k=i] *)
let pp_and_continue k =
pp_rec_top st ~k ~i x (fun k -> loop (idx + 1) k tl)
let pp_and_continue st k =
pp_rec_top st ~k ~i x (fun st k -> loop st (idx + 1) k tl)
in
if idx > 0 then
(* separator, then item *)
pp_rec_top st ~k ~i sep pp_and_continue
else
pp_and_continue k
pp_and_continue st k
)
| [] -> kont k
| [] -> kont st k
in
loop 0 k l
loop st 0 k l
let to_out ~width out (self : t) : unit =
let st = { out; width } in
let st = { out; width; ext_map = Ext.empty } in
pp_rec st 0 [ 0, self ]
let to_buffer ~width (buf : Buffer.t) (self : t) : unit =
@ -464,12 +508,15 @@ module Term_color = struct
Buffer.add_string buf "m";
Buffer.contents buf
(* TODO: handle nested styles *)
let ext_style_ : style list Ext.t =
{
pre = (fun out l -> Out.string out (string_of_style_list l));
post = (fun out _l -> Out.string out reset);
}
Ext.make ~name:"termcolor"
~pre:(fun out ~inside:_ l -> Out.string out (string_of_style_list l))
~post:(fun out ~inside _l ->
let style =
CCOption.map_or ~default:reset string_of_style_list inside
in
Out.string out style)
()
(** Set the foreground color. *)
let color (c : color) (d : t) : t = ext ext_style_ [ `FG c ] d

View file

@ -120,15 +120,26 @@ end
might be annotated with ANSI-terminal colors, or
with HTML tags. *)
module Ext : sig
type 'a t = {
pre: Out.t -> 'a -> unit; (** Printed before the wrapped value. *)
post: Out.t -> 'a -> unit; (** Printed after the wrapped value. *)
}
type 'a t
val make :
?width:('a -> int) ->
name:string ->
pre:(Out.t -> inside:'a option -> 'a -> unit) ->
post:(Out.t -> inside:'a option -> 'a -> unit) ->
unit ->
'a t
(** An extension is a custom document node. It takes a value of type ['a],
and a document [d], and can output what it wants based
on the custom value before and after [d] is printed.
The extension is considered to have width [0]. *)
The extension is considered to have width [0], unless [width]
is specified.
@param pre called before the wrapped value is printed
@param post called after the wrapped value is printed
*)
end
val ext : 'a Ext.t -> 'a -> t -> t