diff --git a/src/lwt/lwt_pipe.ml b/src/lwt/lwt_pipe.ml index 7998267e..f91b89fd 100644 --- a/src/lwt/lwt_pipe.ml +++ b/src/lwt/lwt_pipe.ml @@ -28,7 +28,6 @@ type 'a or_error = [`Ok of 'a | `Error of string] type 'a step = ['a or_error | `End] let (>>=) = Lwt.(>>=) -let (>|=) = Lwt.(>|=) module LwtErr = struct type 'a t = 'a or_error Lwt.t @@ -54,234 +53,343 @@ module LwtErr = struct ) x end -let step_map f = function - | `Ok x -> `Ok (f x) - | (`Error _ | `End) as e -> e - let (>>|=) = LwtErr.(>|=) let ret_end = Lwt.return `End +exception Closed + module Pipe = struct - type -'a writer = 'a step -> unit Lwt.t - - type +'a reader = unit -> 'a step Lwt.t - (* messages given to writers through the condition *) type 'a msg = | Send of 'a step Lwt.u (* send directly to reader *) | SendQueue (* push into queue *) + | Close (* close *) - type 'a t = { + type 'a inner_buf = + | Buf of 'a step Queue.t * int (* buf, max size *) + | NoBuf + + type ('a, +'perm) t = { + close : unit Lwt.u; + closed : unit Lwt.t; lock : Lwt_mutex.t; - queue : 'a step Queue.t; - max_size : int; + buf : 'a inner_buf; cond : 'a msg Lwt_condition.t; - mutable keep : unit Lwt.t list; (* do not GC *) - } + mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *) + } constraint 'perm = [< `r | `w] - let create ?(max_size=0) () = { - queue=Queue.create(); - max_size; - lock=Lwt_mutex.create(); - cond=Lwt_condition.create(); - keep=[]; - } + let create ?(max_size=0) () = + let buf = match max_size with + | 0 -> NoBuf + | n when n < 0 -> invalid_arg "max_size" + | n -> Buf (Queue.create (), n) + in + let closed, close = Lwt.wait () in + { + close; + closed; + buf; + lock=Lwt_mutex.create(); + cond=Lwt_condition.create(); + keep=[]; + } let keep p fut = p.keep <- fut :: p.keep + let is_closed p = not (Lwt.is_sleeping p.closed) + + let close p = + if is_closed p then Lwt.return_unit + else ( + Lwt.wakeup p.close (); (* evaluate *) + Lwt_condition.broadcast p.cond Close; + Lwt.join p.keep; + ) + + let close_async p = Lwt.async (fun () -> close p) + + let on_close p = p.closed + + (* try to take next element from buffer *) + let try_next_buf t = match t.buf with + | NoBuf -> None + | Buf (q, _) -> + if Queue.is_empty q then None + else Some (Queue.pop q) + + (* returns true if it could push successfully *) + let try_push_buf t x = match t.buf with + | NoBuf -> false + | Buf (q, max_size) when Queue.length q = max_size -> false + | Buf (q, _) -> Queue.push x q; true + (* read next one *) - let reader t () = + let read t = Lwt_mutex.with_lock t.lock (fun () -> - if Queue.is_empty t.queue - then ( + match try_next_buf t with + | None when is_closed t -> ret_end (* end of stream *) + | None -> let fut, send = Lwt.wait () in Lwt_condition.signal t.cond (Send send); fut - ) else ( - (* direct pop *) - assert (t.max_size > 0); - let x = Queue.pop t.queue in + | Some x -> Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *) Lwt.return x - ) ) (* write a value *) - let writer t x = + let write t x = let rec try_write () = - if Queue.length t.queue < t.max_size then ( - Queue.push x t.queue; - Lwt.return_unit - ) else ( + if is_closed t then Lwt.fail Closed + else if try_push_buf t x + then Lwt.return_unit (* into buffer, do not wait *) + else ( (* wait for readers to consume the queue *) Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg -> match msg with - | Send s -> - Lwt.wakeup s x; - Lwt.return_unit - | SendQueue -> try_write () (* try again! *) + | Send s -> + Lwt.wakeup s x; (* sync with reader *) + Lwt.return_unit + | SendQueue -> try_write () (* try again! *) + | Close -> Lwt.fail Closed ) in Lwt_mutex.with_lock t.lock try_write - let create_pair ?max_size () = - let p = create ?max_size () in - reader p, writer p + let rec connect_rec r w = + read r >>= function + | `End -> Lwt.return_unit + | (`Error _ | `Ok _) as step -> + write w step >>= fun () -> + connect_rec r w - let rec connect_ (r:'a reader) (w:'a writer) = - r () >>= function - | `End -> w `End (* then stop *) - | (`Error _ | `Ok _) as step -> w step >>= fun () -> connect_ r w + let connect a b = + let fut = connect_rec a b in + keep b fut - let pipe_into p1 p2 = - connect_ (reader p1) (writer p2) + (* close a when b closes *) + let close_when_closed a b = + Lwt.on_success b.closed + (fun () -> close_async a) + + (* close a when every member of l closes *) + let close_when_all_closed a l = + let n = ref (List.length l) in + List.iter + (fun p -> Lwt.on_success p.closed + (fun () -> + decr n; + if !n = 0 then close_async a + ) + ) l end -let connect r w = Pipe.connect_ r w - module Writer = struct - type -'a t = 'a Pipe.writer + type 'a t = ('a, [`w]) Pipe.t - let write t x = t (`Ok x) + let write t x = Pipe.write t (`Ok x) - let write_error t msg = t (`Error msg) - - let write_end t = t `End + let write_error t msg = Pipe.write t (`Error msg) let rec write_list t l = match l with | [] -> Lwt.return_unit | x :: tail -> write t x >>= fun () -> write_list t tail - let map ~f t x = t (step_map f x) + let map ~f a = + let b = Pipe.create() in + let rec fwd () = + Pipe.read b >>= function + | `Ok x -> write a (f x) >>= fwd + | `Error msg -> write_error a msg >>= fwd + | `End -> Lwt.return_unit + in + Pipe.keep b (fwd()); + (* when a gets closed, close b too *) + Lwt.on_success (Pipe.on_close a) (fun () -> Pipe.close_async b); + b + + let send_all l = + if l = [] then invalid_arg "send_all"; + let res = Pipe.create () in + let rec fwd () = + Pipe.read res >>= function + | `End -> Lwt.return_unit + | `Ok x -> Lwt_list.iter_p (fun p -> write p x) l >>= fwd + | `Error msg -> Lwt_list.iter_p (fun p -> write_error p msg) l >>= fwd + in + (* do not GC before res dies; close res when any outputx is closed *) + Pipe.keep res (fwd ()); + List.iter (Pipe.close_when_closed res) l; + res + + let send_both a b = send_all [a; b] end module Reader = struct - type +'a t = 'a Pipe.reader + type 'a t = ('a, [`r]) Pipe.t - let read t = t () + let read = Pipe.read - let map ~f t () = - t () >|= (step_map f) + let map ~f a = + let b = Pipe.create () in + let rec fwd () = + Pipe.read a >>= function + | `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd + | (`Error _) as e -> Pipe.write b e >>= fwd + | `End -> Pipe.close b + in + Pipe.keep b (fwd()); + b - let rec filter_map ~f t () = - t () >>= function - | `Error msg -> LwtErr.fail msg - | `Ok x -> - begin match f x with - | Some y -> LwtErr.return y - | None -> filter_map ~f t () - end - | `End -> ret_end + let filter_map ~f a = + let b = Pipe.create () in + let rec fwd () = + Pipe.read a >>= function + | `Ok x -> + begin match f x with + | None -> fwd() + | Some y -> Pipe.write b (`Ok y) >>= fwd + end + | (`Error _) as e -> Pipe.write b e >>= fwd + | `End -> Pipe.close b + in + Pipe.keep b (fwd()); + b let rec fold ~f ~x t = - t () >>= function + read t >>= function | `End -> LwtErr.return x | `Error msg -> LwtErr.fail msg | `Ok y -> fold ~f ~x:(f x y) t let rec fold_s ~f ~x t = - t () >>= function + read t >>= function | `End -> LwtErr.return x | `Error msg -> LwtErr.fail msg | `Ok y -> f x y >>= fun x -> fold_s ~f ~x t let rec iter ~f t = - t () >>= function + read t >>= function | `End -> LwtErr.return_unit | `Error msg -> LwtErr.fail msg | `Ok x -> f x; iter ~f t let rec iter_s ~f t = - t () >>= function + read t >>= function | `End -> LwtErr.return_unit | `Error msg -> LwtErr.fail msg | `Ok x -> f x >>= fun () -> iter_s ~f t - let merge a b : _ t = - let r, w = Pipe.create_pair () in - Lwt.async (fun () -> Lwt.join [connect a w; connect b w]); - r + let merge_all l = + if l = [] then invalid_arg "merge_all"; + let res = Pipe.create () in + List.iter (fun p -> Pipe.connect p res) l; + (* connect res' input to all members of l; close res when they all close *) + Pipe.close_when_all_closed res l; + res + + let merge_both a b = merge_all [a; b] end (** {2 Conversions} *) let of_list l : _ Reader.t = - let l = ref l in - fun () -> match !l with - | [] -> ret_end - | x :: tail -> - l := tail; - Lwt.return (`Ok x) + let p = Pipe.create ~max_size:0 () in + Pipe.keep p (Lwt_list.iter_s (Writer.write p) l >>= fun () -> Pipe.close p); + p let of_array a = - let i = ref 0 in - fun () -> - if !i = Array.length a - then ret_end + let p = Pipe.create ~max_size:0 () in + let rec send i = + if i = Array.length a then Pipe.close p else ( - let x = a.(!i) in - incr i; - Lwt.return (`Ok x) + Writer.write p a.(i) >>= fun () -> + send (i+1) ) + in + Pipe.keep p (send 0); + p -let of_string s = - let i = ref 0 in - fun () -> - if !i = String.length s - then ret_end +let of_string a = + let p = Pipe.create ~max_size:0 () in + let rec send i = + if i = String.length a then Pipe.close p else ( - let x = String.get s !i in - incr i; - Lwt.return (`Ok x) + Writer.write p (String.get a i) >>= fun () -> + send (i+1) ) + in + Pipe.keep p (send 0); + p -let to_rev_list w = - Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] w +let to_rev_list r = + Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r -let to_list w = to_rev_list w >>|= List.rev +let to_list r = to_rev_list r >>|= List.rev -let to_list_exn w = - to_list w >>= function +let to_list_exn r = + to_list r >>= function | `Error msg -> Lwt.fail (Failure msg) | `Ok x -> Lwt.return x -let to_buffer buf : _ Writer.t = function - | `Ok c -> - Buffer.add_char buf c; +let to_buffer buf = + let p = Pipe.create () in + Pipe.keep p ( + Reader.iter ~f:(fun c -> Buffer.add_char buf c) p >>= fun _ -> Lwt.return_unit - | `Error _ | `End -> Lwt.return_unit + ); + p -let to_buffer_str buf = function - | `Ok s -> - Buffer.add_string buf s; +let to_buffer_str buf = + let p = Pipe.create () in + Pipe.keep p ( + Reader.iter ~f:(fun s -> Buffer.add_string buf s) p >>= fun _ -> Lwt.return_unit - | `Error _ | `End -> Lwt.return_unit + ); + p (** {2 Basic IO wrappers} *) module IO = struct let read ?(bufsize=4096) ic : _ Reader.t = let buf = Bytes.make bufsize ' ' in - fun () -> + let p = Pipe.create ~max_size:0 () in + let rec send() = Lwt_io.read_into ic buf 0 bufsize >>= fun n -> - if n = 0 then ret_end + if n = 0 then Pipe.close p else - Lwt.return (`Ok (Bytes.sub_string buf 0 n)) + Writer.write p (Bytes.sub_string buf 0 n) >>= fun () -> + send () + in Lwt.async send; + p - let read_lines ic () = - Lwt_io.read_line_opt ic >>= function - | None -> ret_end - | Some line -> Lwt.return (`Ok line) + let read_lines ic = + let p = Pipe.create () in + let rec send () = + Lwt_io.read_line_opt ic >>= function + | None -> Pipe.close p + | Some line -> Writer.write p line >>= fun () -> send () + in + Lwt.async send; + p - let write oc = function - | `Ok s -> Lwt_io.write oc s - | `End | `Error _ -> Lwt.return_unit + let write oc = + let p = Pipe.create () in + Pipe.keep p ( + Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ -> + Pipe.close p + ); + p - let write_lines oc = function - | `Ok l -> Lwt_io.write_line oc l - | `End | `Error _ -> Lwt.return_unit + let write_lines oc = + let p = Pipe.create () in + Pipe.keep p ( + Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ -> + Pipe.close p + ); + p end diff --git a/src/lwt/lwt_pipe.mli b/src/lwt/lwt_pipe.mli index 71bb73d1..836977db 100644 --- a/src/lwt/lwt_pipe.mli +++ b/src/lwt/lwt_pipe.mli @@ -44,22 +44,65 @@ module LwtErr : sig val fail : string -> 'a t end +exception Closed + +module Pipe : sig + type ('a, +'perm) t constraint 'perm = [< `r | `w] + (** A pipe between producers of values of type 'a, and consumers of values + of type 'a. *) + + val keep : _ t -> unit Lwt.t -> unit + (** [keep p fut] adds a pointer from [p] to [fut] so that [fut] is not + garbage-collected before [p] *) + + val is_closed : _ t -> bool + + val close : _ t -> unit Lwt.t + (** [close p] closes [p], which will not accept input anymore. + This sends [`End] to all readers connected to [p] *) + + val close_async : _ t -> unit + (** Same as {!close} but closes in the background *) + + val on_close : _ t -> unit Lwt.t + (** Evaluates once the pipe closes *) + + val create : ?max_size:int -> unit -> ('a, 'perm) t + (** Create a new pipe. + @param max_size size of internal buffer. Default 0. *) + + val connect : ('a, [>`r]) t -> ('a, [>`w]) t -> unit + (** [connect p1 p2] forwards every item output by [p1] into [p2]'s input + until [p1] is closed. *) +end + module Writer : sig - type -'a t + type 'a t = ('a, [`w]) Pipe.t val write : 'a t -> 'a -> unit Lwt.t + (** @raise Pipe.Closed if the writer is closed *) val write_list : 'a t -> 'a list -> unit Lwt.t + (** @raise Pipe.Closed if the writer is closed *) val write_error : _ t -> string -> unit Lwt.t - - val write_end : _ t -> unit Lwt.t + (** @raise Pipe.Closed if the writer is closed *) val map : f:('a -> 'b) -> 'b t -> 'a t + (** Map values before writing them *) + + val send_both : 'a t -> 'a t -> 'a t + (** [send_both a b] returns a writer [c] such that writing to [c] + writes to [a] and [b], and waits for those writes to succeed + before returning *) + + val send_all : 'a t list -> 'a t + (** Generalized version of {!send_both} + @raise Invalid_argument if the list is empty *) end module Reader : sig - type +'a t + type 'a t = ('a, [`r]) Pipe.t val read : 'a t -> 'a step Lwt.t @@ -75,38 +118,14 @@ module Reader : sig val iter_s : f:('a -> unit Lwt.t) -> 'a t -> unit LwtErr.t - val merge : 'a t -> 'a t -> 'a t - (** Merge the two input streams *) + val merge_both : 'a t -> 'a t -> 'a t + (** Merge the two input streams in a non-specified order *) + + val merge_all : 'a t list -> 'a t + (** Merge all the input streams + @raise Invalid_argument if the list is empty *) end -module Pipe : sig - type 'a t - (** A pipe between producers of values of type 'a, and consumers of values - of type 'a. *) - - val reader : 'a t -> 'a Reader.t - - val writer : 'a t -> 'a Writer.t - - val keep : _ t -> unit Lwt.t -> unit - (** [keep p fut] adds a pointer from [p] to [fut] so that [fut] is not - garbage-collected before [p] *) - - val create : ?max_size:int -> unit -> 'a t - (** Create a new pipe. - @param max_size size of internal buffer. Default 0. *) - - val create_pair : ?max_size:int -> unit -> 'a Reader.t * 'a Writer.t - (** Create a pair [r, w] connect by a pipe *) - - val pipe_into : 'a t -> 'a t -> unit Lwt.t - (** [connect p1 p2] forwards every item output by [p1] into [p2]'s input - until [`End] is reached. After [`End] is sent, the process stops. *) -end - -val connect : 'a Reader.t -> 'a Writer.t -> unit Lwt.t -(** [connect r w] sends every item read from [r] into [w] *) - (** {2 Conversions} *) val of_list : 'a list -> 'a Reader.t