diff --git a/src/lwt/lwt_pipe.ml b/src/lwt/lwt_pipe.ml index f91b89fd..c48209c1 100644 --- a/src/lwt/lwt_pipe.ml +++ b/src/lwt/lwt_pipe.ml @@ -60,38 +60,26 @@ let ret_end = Lwt.return `End exception Closed module Pipe = struct - (* messages given to writers through the condition *) - type 'a msg = - | Send of 'a step Lwt.u (* send directly to reader *) - | SendQueue (* push into queue *) - | Close (* close *) - - type 'a inner_buf = - | Buf of 'a step Queue.t * int (* buf, max size *) - | NoBuf - type ('a, +'perm) t = { close : unit Lwt.u; closed : unit Lwt.t; - lock : Lwt_mutex.t; - buf : 'a inner_buf; - cond : 'a msg Lwt_condition.t; + buf : + [`Item of 'a step + | `Block of 'a step * unit Lwt.u + ] Queue.t; (* actions queued *) + max_size : int; + box : 'a step Lwt.u Lwt_mvar.t; mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *) } constraint 'perm = [< `r | `w] let create ?(max_size=0) () = - let buf = match max_size with - | 0 -> NoBuf - | n when n < 0 -> invalid_arg "max_size" - | n -> Buf (Queue.create (), n) - in let closed, close = Lwt.wait () in { close; closed; - buf; - lock=Lwt_mutex.create(); - cond=Lwt_condition.create(); + buf = Queue.create (); + max_size; + box=Lwt_mvar.create_empty (); keep=[]; } @@ -103,65 +91,52 @@ module Pipe = struct if is_closed p then Lwt.return_unit else ( Lwt.wakeup p.close (); (* evaluate *) - Lwt_condition.broadcast p.cond Close; Lwt.join p.keep; ) let close_async p = Lwt.async (fun () -> close p) - let on_close p = p.closed + let wait p = Lwt.map (fun _ -> ()) p.closed (* try to take next element from buffer *) - let try_next_buf t = match t.buf with - | NoBuf -> None - | Buf (q, _) -> - if Queue.is_empty q then None - else Some (Queue.pop q) - - (* returns true if it could push successfully *) - let try_push_buf t x = match t.buf with - | NoBuf -> false - | Buf (q, max_size) when Queue.length q = max_size -> false - | Buf (q, _) -> Queue.push x q; true + let try_next_buf t = + if Queue.is_empty t.buf then None + else Some (Queue.pop t.buf) (* read next one *) let read t = - Lwt_mutex.with_lock t.lock - (fun () -> - match try_next_buf t with - | None when is_closed t -> ret_end (* end of stream *) - | None -> - let fut, send = Lwt.wait () in - Lwt_condition.signal t.cond (Send send); - fut - | Some x -> - Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *) - Lwt.return x - ) + match try_next_buf t with + | None when is_closed t -> ret_end (* end of stream *) + | None -> + let fut, send = Lwt.wait () in + Lwt_mvar.put t.box send >>= fun () -> + fut + | Some (`Item x) -> Lwt.return x + | Some (`Block (x, signal_done)) -> + Lwt.wakeup signal_done (); (* signal the writer it's done *) + Lwt.return x + + (* TODO: signal writers when their value has less than max_size + steps before being read *) (* write a value *) let write t x = - let rec try_write () = - if is_closed t then Lwt.fail Closed - else if try_push_buf t x - then Lwt.return_unit (* into buffer, do not wait *) - else ( - (* wait for readers to consume the queue *) - Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg -> - match msg with - | Send s -> - Lwt.wakeup s x; (* sync with reader *) - Lwt.return_unit - | SendQueue -> try_write () (* try again! *) - | Close -> Lwt.fail Closed - ) - in - Lwt_mutex.with_lock t.lock try_write + if is_closed t then Lwt.fail Closed + else if Queue.length t.buf < t.max_size + then ( + Queue.push (`Item x) t.buf; + Lwt.return_unit (* into buffer, do not wait *) + ) else ( + let is_done, signal_done = Lwt.wait () in + Queue.push (`Block (x, signal_done)) t.buf; + is_done + ) let rec connect_rec r w = read r >>= function | `End -> Lwt.return_unit - | (`Error _ | `Ok _) as step -> + | `Error _ as step -> write w step + | `Ok _ as step -> write w step >>= fun () -> connect_rec r w @@ -170,20 +145,20 @@ module Pipe = struct keep b fut (* close a when b closes *) - let close_when_closed a b = - Lwt.on_success b.closed - (fun () -> close_async a) + let link_close p ~after = + Lwt.on_termination after.closed + (fun _ -> close_async p) - (* close a when every member of l closes *) - let close_when_all_closed a l = - let n = ref (List.length l) in + (* close a when every member of after closes *) + let link_close_l p ~after = + let n = ref (List.length after) in List.iter - (fun p -> Lwt.on_success p.closed - (fun () -> - decr n; - if !n = 0 then close_async a - ) - ) l + (fun p' -> Lwt.on_termination p'.closed + (fun _ -> + decr n; + if !n = 0 then close_async p + ) + ) after end module Writer = struct @@ -203,12 +178,12 @@ module Writer = struct let rec fwd () = Pipe.read b >>= function | `Ok x -> write a (f x) >>= fwd - | `Error msg -> write_error a msg >>= fwd + | `Error msg -> write_error a msg >>= fun _ -> Pipe.close a | `End -> Lwt.return_unit in Pipe.keep b (fwd()); (* when a gets closed, close b too *) - Lwt.on_success (Pipe.on_close a) (fun () -> Pipe.close_async b); + Pipe.link_close b ~after:a; b let send_all l = @@ -222,7 +197,7 @@ module Writer = struct in (* do not GC before res dies; close res when any outputx is closed *) Pipe.keep res (fwd ()); - List.iter (Pipe.close_when_closed res) l; + List.iter (fun out -> Pipe.link_close res ~after:out) l; res let send_both a b = send_all [a; b] @@ -238,7 +213,7 @@ module Reader = struct let rec fwd () = Pipe.read a >>= function | `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd - | (`Error _) as e -> Pipe.write b e >>= fwd + | (`Error _) as e -> Pipe.write b e >>= fun _ -> Pipe.close b | `End -> Pipe.close b in Pipe.keep b (fwd()); @@ -253,7 +228,7 @@ module Reader = struct | None -> fwd() | Some y -> Pipe.write b (`Ok y) >>= fwd end - | (`Error _) as e -> Pipe.write b e >>= fwd + | (`Error _) as e -> Pipe.write b e >>= fun _ -> Pipe.close b | `End -> Pipe.close b in Pipe.keep b (fwd()); @@ -289,12 +264,24 @@ module Reader = struct let res = Pipe.create () in List.iter (fun p -> Pipe.connect p res) l; (* connect res' input to all members of l; close res when they all close *) - Pipe.close_when_all_closed res l; + Pipe.link_close_l res ~after:l; res let merge_both a b = merge_all [a; b] + + let append a b = + let c = Pipe.create () in + Pipe.connect a c; + Lwt.on_success (Pipe.wait a) + (fun () -> + Pipe.connect b c; + Pipe.link_close c ~after:b (* once a and b finished, c is too *) + ); + c end +let connect = Pipe.connect + (** {2 Conversions} *) let of_list l : _ Reader.t = @@ -326,10 +313,10 @@ let of_string a = Pipe.keep p (send 0); p -let to_rev_list r = +let to_list_rev r = Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r -let to_list r = to_rev_list r >>|= List.rev +let to_list r = to_list_rev r >>|= List.rev let to_list_exn r = to_list r >>= function @@ -381,6 +368,7 @@ module IO = struct let p = Pipe.create () in Pipe.keep p ( Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ -> + Lwt_io.flush oc >>= fun () -> Pipe.close p ); p @@ -389,6 +377,7 @@ module IO = struct let p = Pipe.create () in Pipe.keep p ( Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ -> + Lwt_io.flush oc >>= fun () -> Pipe.close p ); p diff --git a/src/lwt/lwt_pipe.mli b/src/lwt/lwt_pipe.mli index 836977db..ae10cbd7 100644 --- a/src/lwt/lwt_pipe.mli +++ b/src/lwt/lwt_pipe.mli @@ -31,6 +31,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. {- Pipe: a possibly buffered channel through which readers and writer communicate} {- Reader: accepts values, produces effects} {- Writer: yield values} + + @since NEXT_RELEASE *) type 'a or_error = [`Ok of 'a | `Error of string] @@ -64,7 +66,7 @@ module Pipe : sig val close_async : _ t -> unit (** Same as {!close} but closes in the background *) - val on_close : _ t -> unit Lwt.t + val wait : _ t -> unit Lwt.t (** Evaluates once the pipe closes *) val create : ?max_size:int -> unit -> ('a, 'perm) t @@ -124,8 +126,13 @@ module Reader : sig val merge_all : 'a t list -> 'a t (** Merge all the input streams @raise Invalid_argument if the list is empty *) + + val append : 'a t -> 'a t -> 'a t end +val connect : 'a Reader.t -> 'a Writer.t -> unit +(** Handy synonym to {!Pipe.connect} *) + (** {2 Conversions} *) val of_list : 'a list -> 'a Reader.t @@ -134,7 +141,7 @@ val of_array : 'a array -> 'a Reader.t val of_string : string -> char Reader.t -val to_rev_list : 'a Reader.t -> 'a list LwtErr.t +val to_list_rev : 'a Reader.t -> 'a list LwtErr.t val to_list : 'a Reader.t -> 'a list LwtErr.t