Lwt_pipe now with reader/writer subtypes of pipe, better API, safer closing

This commit is contained in:
Simon Cruanes 2015-02-19 18:15:49 +01:00
parent 77b6197c49
commit c6b23890ec
2 changed files with 284 additions and 157 deletions

View file

@ -28,7 +28,6 @@ type 'a or_error = [`Ok of 'a | `Error of string]
type 'a step = ['a or_error | `End] type 'a step = ['a or_error | `End]
let (>>=) = Lwt.(>>=) let (>>=) = Lwt.(>>=)
let (>|=) = Lwt.(>|=)
module LwtErr = struct module LwtErr = struct
type 'a t = 'a or_error Lwt.t type 'a t = 'a or_error Lwt.t
@ -54,35 +53,43 @@ module LwtErr = struct
) x ) x
end end
let step_map f = function
| `Ok x -> `Ok (f x)
| (`Error _ | `End) as e -> e
let (>>|=) = LwtErr.(>|=) let (>>|=) = LwtErr.(>|=)
let ret_end = Lwt.return `End let ret_end = Lwt.return `End
exception Closed
module Pipe = struct module Pipe = struct
type -'a writer = 'a step -> unit Lwt.t
type +'a reader = unit -> 'a step Lwt.t
(* messages given to writers through the condition *) (* messages given to writers through the condition *)
type 'a msg = type 'a msg =
| Send of 'a step Lwt.u (* send directly to reader *) | Send of 'a step Lwt.u (* send directly to reader *)
| SendQueue (* push into queue *) | SendQueue (* push into queue *)
| Close (* close *)
type 'a t = { type 'a inner_buf =
| Buf of 'a step Queue.t * int (* buf, max size *)
| NoBuf
type ('a, +'perm) t = {
close : unit Lwt.u;
closed : unit Lwt.t;
lock : Lwt_mutex.t; lock : Lwt_mutex.t;
queue : 'a step Queue.t; buf : 'a inner_buf;
max_size : int;
cond : 'a msg Lwt_condition.t; cond : 'a msg Lwt_condition.t;
mutable keep : unit Lwt.t list; (* do not GC *) mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *)
} } constraint 'perm = [< `r | `w]
let create ?(max_size=0) () = { let create ?(max_size=0) () =
queue=Queue.create(); let buf = match max_size with
max_size; | 0 -> NoBuf
| n when n < 0 -> invalid_arg "max_size"
| n -> Buf (Queue.create (), n)
in
let closed, close = Lwt.wait () in
{
close;
closed;
buf;
lock=Lwt_mutex.create(); lock=Lwt_mutex.create();
cond=Lwt_condition.create(); cond=Lwt_condition.create();
keep=[]; keep=[];
@ -90,198 +97,299 @@ module Pipe = struct
let keep p fut = p.keep <- fut :: p.keep let keep p fut = p.keep <- fut :: p.keep
let is_closed p = not (Lwt.is_sleeping p.closed)
let close p =
if is_closed p then Lwt.return_unit
else (
Lwt.wakeup p.close (); (* evaluate *)
Lwt_condition.broadcast p.cond Close;
Lwt.join p.keep;
)
let close_async p = Lwt.async (fun () -> close p)
let on_close p = p.closed
(* try to take next element from buffer *)
let try_next_buf t = match t.buf with
| NoBuf -> None
| Buf (q, _) ->
if Queue.is_empty q then None
else Some (Queue.pop q)
(* returns true if it could push successfully *)
let try_push_buf t x = match t.buf with
| NoBuf -> false
| Buf (q, max_size) when Queue.length q = max_size -> false
| Buf (q, _) -> Queue.push x q; true
(* read next one *) (* read next one *)
let reader t () = let read t =
Lwt_mutex.with_lock t.lock Lwt_mutex.with_lock t.lock
(fun () -> (fun () ->
if Queue.is_empty t.queue match try_next_buf t with
then ( | None when is_closed t -> ret_end (* end of stream *)
| None ->
let fut, send = Lwt.wait () in let fut, send = Lwt.wait () in
Lwt_condition.signal t.cond (Send send); Lwt_condition.signal t.cond (Send send);
fut fut
) else ( | Some x ->
(* direct pop *)
assert (t.max_size > 0);
let x = Queue.pop t.queue in
Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *) Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *)
Lwt.return x Lwt.return x
) )
)
(* write a value *) (* write a value *)
let writer t x = let write t x =
let rec try_write () = let rec try_write () =
if Queue.length t.queue < t.max_size then ( if is_closed t then Lwt.fail Closed
Queue.push x t.queue; else if try_push_buf t x
Lwt.return_unit then Lwt.return_unit (* into buffer, do not wait *)
) else ( else (
(* wait for readers to consume the queue *) (* wait for readers to consume the queue *)
Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg -> Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg ->
match msg with match msg with
| Send s -> | Send s ->
Lwt.wakeup s x; Lwt.wakeup s x; (* sync with reader *)
Lwt.return_unit Lwt.return_unit
| SendQueue -> try_write () (* try again! *) | SendQueue -> try_write () (* try again! *)
| Close -> Lwt.fail Closed
) )
in in
Lwt_mutex.with_lock t.lock try_write Lwt_mutex.with_lock t.lock try_write
let create_pair ?max_size () = let rec connect_rec r w =
let p = create ?max_size () in read r >>= function
reader p, writer p | `End -> Lwt.return_unit
| (`Error _ | `Ok _) as step ->
write w step >>= fun () ->
connect_rec r w
let rec connect_ (r:'a reader) (w:'a writer) = let connect a b =
r () >>= function let fut = connect_rec a b in
| `End -> w `End (* then stop *) keep b fut
| (`Error _ | `Ok _) as step -> w step >>= fun () -> connect_ r w
let pipe_into p1 p2 = (* close a when b closes *)
connect_ (reader p1) (writer p2) let close_when_closed a b =
Lwt.on_success b.closed
(fun () -> close_async a)
(* close a when every member of l closes *)
let close_when_all_closed a l =
let n = ref (List.length l) in
List.iter
(fun p -> Lwt.on_success p.closed
(fun () ->
decr n;
if !n = 0 then close_async a
)
) l
end end
let connect r w = Pipe.connect_ r w
module Writer = struct module Writer = struct
type -'a t = 'a Pipe.writer type 'a t = ('a, [`w]) Pipe.t
let write t x = t (`Ok x) let write t x = Pipe.write t (`Ok x)
let write_error t msg = t (`Error msg) let write_error t msg = Pipe.write t (`Error msg)
let write_end t = t `End
let rec write_list t l = match l with let rec write_list t l = match l with
| [] -> Lwt.return_unit | [] -> Lwt.return_unit
| x :: tail -> | x :: tail ->
write t x >>= fun () -> write_list t tail write t x >>= fun () -> write_list t tail
let map ~f t x = t (step_map f x) let map ~f a =
let b = Pipe.create() in
let rec fwd () =
Pipe.read b >>= function
| `Ok x -> write a (f x) >>= fwd
| `Error msg -> write_error a msg >>= fwd
| `End -> Lwt.return_unit
in
Pipe.keep b (fwd());
(* when a gets closed, close b too *)
Lwt.on_success (Pipe.on_close a) (fun () -> Pipe.close_async b);
b
let send_all l =
if l = [] then invalid_arg "send_all";
let res = Pipe.create () in
let rec fwd () =
Pipe.read res >>= function
| `End -> Lwt.return_unit
| `Ok x -> Lwt_list.iter_p (fun p -> write p x) l >>= fwd
| `Error msg -> Lwt_list.iter_p (fun p -> write_error p msg) l >>= fwd
in
(* do not GC before res dies; close res when any outputx is closed *)
Pipe.keep res (fwd ());
List.iter (Pipe.close_when_closed res) l;
res
let send_both a b = send_all [a; b]
end end
module Reader = struct module Reader = struct
type +'a t = 'a Pipe.reader type 'a t = ('a, [`r]) Pipe.t
let read t = t () let read = Pipe.read
let map ~f t () = let map ~f a =
t () >|= (step_map f) let b = Pipe.create () in
let rec fwd () =
Pipe.read a >>= function
| `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd
| (`Error _) as e -> Pipe.write b e >>= fwd
| `End -> Pipe.close b
in
Pipe.keep b (fwd());
b
let rec filter_map ~f t () = let filter_map ~f a =
t () >>= function let b = Pipe.create () in
| `Error msg -> LwtErr.fail msg let rec fwd () =
Pipe.read a >>= function
| `Ok x -> | `Ok x ->
begin match f x with begin match f x with
| Some y -> LwtErr.return y | None -> fwd()
| None -> filter_map ~f t () | Some y -> Pipe.write b (`Ok y) >>= fwd
end end
| `End -> ret_end | (`Error _) as e -> Pipe.write b e >>= fwd
| `End -> Pipe.close b
in
Pipe.keep b (fwd());
b
let rec fold ~f ~x t = let rec fold ~f ~x t =
t () >>= function read t >>= function
| `End -> LwtErr.return x | `End -> LwtErr.return x
| `Error msg -> LwtErr.fail msg | `Error msg -> LwtErr.fail msg
| `Ok y -> fold ~f ~x:(f x y) t | `Ok y -> fold ~f ~x:(f x y) t
let rec fold_s ~f ~x t = let rec fold_s ~f ~x t =
t () >>= function read t >>= function
| `End -> LwtErr.return x | `End -> LwtErr.return x
| `Error msg -> LwtErr.fail msg | `Error msg -> LwtErr.fail msg
| `Ok y -> | `Ok y ->
f x y >>= fun x -> fold_s ~f ~x t f x y >>= fun x -> fold_s ~f ~x t
let rec iter ~f t = let rec iter ~f t =
t () >>= function read t >>= function
| `End -> LwtErr.return_unit | `End -> LwtErr.return_unit
| `Error msg -> LwtErr.fail msg | `Error msg -> LwtErr.fail msg
| `Ok x -> f x; iter ~f t | `Ok x -> f x; iter ~f t
let rec iter_s ~f t = let rec iter_s ~f t =
t () >>= function read t >>= function
| `End -> LwtErr.return_unit | `End -> LwtErr.return_unit
| `Error msg -> LwtErr.fail msg | `Error msg -> LwtErr.fail msg
| `Ok x -> f x >>= fun () -> iter_s ~f t | `Ok x -> f x >>= fun () -> iter_s ~f t
let merge a b : _ t = let merge_all l =
let r, w = Pipe.create_pair () in if l = [] then invalid_arg "merge_all";
Lwt.async (fun () -> Lwt.join [connect a w; connect b w]); let res = Pipe.create () in
r List.iter (fun p -> Pipe.connect p res) l;
(* connect res' input to all members of l; close res when they all close *)
Pipe.close_when_all_closed res l;
res
let merge_both a b = merge_all [a; b]
end end
(** {2 Conversions} *) (** {2 Conversions} *)
let of_list l : _ Reader.t = let of_list l : _ Reader.t =
let l = ref l in let p = Pipe.create ~max_size:0 () in
fun () -> match !l with Pipe.keep p (Lwt_list.iter_s (Writer.write p) l >>= fun () -> Pipe.close p);
| [] -> ret_end p
| x :: tail ->
l := tail;
Lwt.return (`Ok x)
let of_array a = let of_array a =
let i = ref 0 in let p = Pipe.create ~max_size:0 () in
fun () -> let rec send i =
if !i = Array.length a if i = Array.length a then Pipe.close p
then ret_end
else ( else (
let x = a.(!i) in Writer.write p a.(i) >>= fun () ->
incr i; send (i+1)
Lwt.return (`Ok x)
) )
in
Pipe.keep p (send 0);
p
let of_string s = let of_string a =
let i = ref 0 in let p = Pipe.create ~max_size:0 () in
fun () -> let rec send i =
if !i = String.length s if i = String.length a then Pipe.close p
then ret_end
else ( else (
let x = String.get s !i in Writer.write p (String.get a i) >>= fun () ->
incr i; send (i+1)
Lwt.return (`Ok x)
) )
in
Pipe.keep p (send 0);
p
let to_rev_list w = let to_rev_list r =
Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] w Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r
let to_list w = to_rev_list w >>|= List.rev let to_list r = to_rev_list r >>|= List.rev
let to_list_exn w = let to_list_exn r =
to_list w >>= function to_list r >>= function
| `Error msg -> Lwt.fail (Failure msg) | `Error msg -> Lwt.fail (Failure msg)
| `Ok x -> Lwt.return x | `Ok x -> Lwt.return x
let to_buffer buf : _ Writer.t = function let to_buffer buf =
| `Ok c -> let p = Pipe.create () in
Buffer.add_char buf c; Pipe.keep p (
Reader.iter ~f:(fun c -> Buffer.add_char buf c) p >>= fun _ ->
Lwt.return_unit Lwt.return_unit
| `Error _ | `End -> Lwt.return_unit );
p
let to_buffer_str buf = function let to_buffer_str buf =
| `Ok s -> let p = Pipe.create () in
Buffer.add_string buf s; Pipe.keep p (
Reader.iter ~f:(fun s -> Buffer.add_string buf s) p >>= fun _ ->
Lwt.return_unit Lwt.return_unit
| `Error _ | `End -> Lwt.return_unit );
p
(** {2 Basic IO wrappers} *) (** {2 Basic IO wrappers} *)
module IO = struct module IO = struct
let read ?(bufsize=4096) ic : _ Reader.t = let read ?(bufsize=4096) ic : _ Reader.t =
let buf = Bytes.make bufsize ' ' in let buf = Bytes.make bufsize ' ' in
fun () -> let p = Pipe.create ~max_size:0 () in
let rec send() =
Lwt_io.read_into ic buf 0 bufsize >>= fun n -> Lwt_io.read_into ic buf 0 bufsize >>= fun n ->
if n = 0 then ret_end if n = 0 then Pipe.close p
else else
Lwt.return (`Ok (Bytes.sub_string buf 0 n)) Writer.write p (Bytes.sub_string buf 0 n) >>= fun () ->
send ()
in Lwt.async send;
p
let read_lines ic () = let read_lines ic =
let p = Pipe.create () in
let rec send () =
Lwt_io.read_line_opt ic >>= function Lwt_io.read_line_opt ic >>= function
| None -> ret_end | None -> Pipe.close p
| Some line -> Lwt.return (`Ok line) | Some line -> Writer.write p line >>= fun () -> send ()
in
Lwt.async send;
p
let write oc = function let write oc =
| `Ok s -> Lwt_io.write oc s let p = Pipe.create () in
| `End | `Error _ -> Lwt.return_unit Pipe.keep p (
Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ ->
Pipe.close p
);
p
let write_lines oc = function let write_lines oc =
| `Ok l -> Lwt_io.write_line oc l let p = Pipe.create () in
| `End | `Error _ -> Lwt.return_unit Pipe.keep p (
Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ ->
Pipe.close p
);
p
end end

View file

@ -44,22 +44,65 @@ module LwtErr : sig
val fail : string -> 'a t val fail : string -> 'a t
end end
exception Closed
module Pipe : sig
type ('a, +'perm) t constraint 'perm = [< `r | `w]
(** A pipe between producers of values of type 'a, and consumers of values
of type 'a. *)
val keep : _ t -> unit Lwt.t -> unit
(** [keep p fut] adds a pointer from [p] to [fut] so that [fut] is not
garbage-collected before [p] *)
val is_closed : _ t -> bool
val close : _ t -> unit Lwt.t
(** [close p] closes [p], which will not accept input anymore.
This sends [`End] to all readers connected to [p] *)
val close_async : _ t -> unit
(** Same as {!close} but closes in the background *)
val on_close : _ t -> unit Lwt.t
(** Evaluates once the pipe closes *)
val create : ?max_size:int -> unit -> ('a, 'perm) t
(** Create a new pipe.
@param max_size size of internal buffer. Default 0. *)
val connect : ('a, [>`r]) t -> ('a, [>`w]) t -> unit
(** [connect p1 p2] forwards every item output by [p1] into [p2]'s input
until [p1] is closed. *)
end
module Writer : sig module Writer : sig
type -'a t type 'a t = ('a, [`w]) Pipe.t
val write : 'a t -> 'a -> unit Lwt.t val write : 'a t -> 'a -> unit Lwt.t
(** @raise Pipe.Closed if the writer is closed *)
val write_list : 'a t -> 'a list -> unit Lwt.t val write_list : 'a t -> 'a list -> unit Lwt.t
(** @raise Pipe.Closed if the writer is closed *)
val write_error : _ t -> string -> unit Lwt.t val write_error : _ t -> string -> unit Lwt.t
(** @raise Pipe.Closed if the writer is closed *)
val write_end : _ t -> unit Lwt.t
val map : f:('a -> 'b) -> 'b t -> 'a t val map : f:('a -> 'b) -> 'b t -> 'a t
(** Map values before writing them *)
val send_both : 'a t -> 'a t -> 'a t
(** [send_both a b] returns a writer [c] such that writing to [c]
writes to [a] and [b], and waits for those writes to succeed
before returning *)
val send_all : 'a t list -> 'a t
(** Generalized version of {!send_both}
@raise Invalid_argument if the list is empty *)
end end
module Reader : sig module Reader : sig
type +'a t type 'a t = ('a, [`r]) Pipe.t
val read : 'a t -> 'a step Lwt.t val read : 'a t -> 'a step Lwt.t
@ -75,38 +118,14 @@ module Reader : sig
val iter_s : f:('a -> unit Lwt.t) -> 'a t -> unit LwtErr.t val iter_s : f:('a -> unit Lwt.t) -> 'a t -> unit LwtErr.t
val merge : 'a t -> 'a t -> 'a t val merge_both : 'a t -> 'a t -> 'a t
(** Merge the two input streams *) (** Merge the two input streams in a non-specified order *)
val merge_all : 'a t list -> 'a t
(** Merge all the input streams
@raise Invalid_argument if the list is empty *)
end end
module Pipe : sig
type 'a t
(** A pipe between producers of values of type 'a, and consumers of values
of type 'a. *)
val reader : 'a t -> 'a Reader.t
val writer : 'a t -> 'a Writer.t
val keep : _ t -> unit Lwt.t -> unit
(** [keep p fut] adds a pointer from [p] to [fut] so that [fut] is not
garbage-collected before [p] *)
val create : ?max_size:int -> unit -> 'a t
(** Create a new pipe.
@param max_size size of internal buffer. Default 0. *)
val create_pair : ?max_size:int -> unit -> 'a Reader.t * 'a Writer.t
(** Create a pair [r, w] connect by a pipe *)
val pipe_into : 'a t -> 'a t -> unit Lwt.t
(** [connect p1 p2] forwards every item output by [p1] into [p2]'s input
until [`End] is reached. After [`End] is sent, the process stops. *)
end
val connect : 'a Reader.t -> 'a Writer.t -> unit Lwt.t
(** [connect r w] sends every item read from [r] into [w] *)
(** {2 Conversions} *) (** {2 Conversions} *)
val of_list : 'a list -> 'a Reader.t val of_list : 'a list -> 'a Reader.t