wip: fix bugs in Lwt_pipe

This commit is contained in:
Simon Cruanes 2015-02-19 18:31:49 +01:00
parent c6b23890ec
commit e41faaf91e
2 changed files with 81 additions and 85 deletions

View file

@ -60,38 +60,26 @@ let ret_end = Lwt.return `End
exception Closed exception Closed
module Pipe = struct module Pipe = struct
(* messages given to writers through the condition *)
type 'a msg =
| Send of 'a step Lwt.u (* send directly to reader *)
| SendQueue (* push into queue *)
| Close (* close *)
type 'a inner_buf =
| Buf of 'a step Queue.t * int (* buf, max size *)
| NoBuf
type ('a, +'perm) t = { type ('a, +'perm) t = {
close : unit Lwt.u; close : unit Lwt.u;
closed : unit Lwt.t; closed : unit Lwt.t;
lock : Lwt_mutex.t; buf :
buf : 'a inner_buf; [`Item of 'a step
cond : 'a msg Lwt_condition.t; | `Block of 'a step * unit Lwt.u
] Queue.t; (* actions queued *)
max_size : int;
box : 'a step Lwt.u Lwt_mvar.t;
mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *) mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *)
} constraint 'perm = [< `r | `w] } constraint 'perm = [< `r | `w]
let create ?(max_size=0) () = let create ?(max_size=0) () =
let buf = match max_size with
| 0 -> NoBuf
| n when n < 0 -> invalid_arg "max_size"
| n -> Buf (Queue.create (), n)
in
let closed, close = Lwt.wait () in let closed, close = Lwt.wait () in
{ {
close; close;
closed; closed;
buf; buf = Queue.create ();
lock=Lwt_mutex.create(); max_size;
cond=Lwt_condition.create(); box=Lwt_mvar.create_empty ();
keep=[]; keep=[];
} }
@ -103,65 +91,52 @@ module Pipe = struct
if is_closed p then Lwt.return_unit if is_closed p then Lwt.return_unit
else ( else (
Lwt.wakeup p.close (); (* evaluate *) Lwt.wakeup p.close (); (* evaluate *)
Lwt_condition.broadcast p.cond Close;
Lwt.join p.keep; Lwt.join p.keep;
) )
let close_async p = Lwt.async (fun () -> close p) let close_async p = Lwt.async (fun () -> close p)
let on_close p = p.closed let wait p = Lwt.map (fun _ -> ()) p.closed
(* try to take next element from buffer *) (* try to take next element from buffer *)
let try_next_buf t = match t.buf with let try_next_buf t =
| NoBuf -> None if Queue.is_empty t.buf then None
| Buf (q, _) -> else Some (Queue.pop t.buf)
if Queue.is_empty q then None
else Some (Queue.pop q)
(* returns true if it could push successfully *)
let try_push_buf t x = match t.buf with
| NoBuf -> false
| Buf (q, max_size) when Queue.length q = max_size -> false
| Buf (q, _) -> Queue.push x q; true
(* read next one *) (* read next one *)
let read t = let read t =
Lwt_mutex.with_lock t.lock match try_next_buf t with
(fun () -> | None when is_closed t -> ret_end (* end of stream *)
match try_next_buf t with | None ->
| None when is_closed t -> ret_end (* end of stream *) let fut, send = Lwt.wait () in
| None -> Lwt_mvar.put t.box send >>= fun () ->
let fut, send = Lwt.wait () in fut
Lwt_condition.signal t.cond (Send send); | Some (`Item x) -> Lwt.return x
fut | Some (`Block (x, signal_done)) ->
| Some x -> Lwt.wakeup signal_done (); (* signal the writer it's done *)
Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *) Lwt.return x
Lwt.return x
) (* TODO: signal writers when their value has less than max_size
steps before being read *)
(* write a value *) (* write a value *)
let write t x = let write t x =
let rec try_write () = if is_closed t then Lwt.fail Closed
if is_closed t then Lwt.fail Closed else if Queue.length t.buf < t.max_size
else if try_push_buf t x then (
then Lwt.return_unit (* into buffer, do not wait *) Queue.push (`Item x) t.buf;
else ( Lwt.return_unit (* into buffer, do not wait *)
(* wait for readers to consume the queue *) ) else (
Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg -> let is_done, signal_done = Lwt.wait () in
match msg with Queue.push (`Block (x, signal_done)) t.buf;
| Send s -> is_done
Lwt.wakeup s x; (* sync with reader *) )
Lwt.return_unit
| SendQueue -> try_write () (* try again! *)
| Close -> Lwt.fail Closed
)
in
Lwt_mutex.with_lock t.lock try_write
let rec connect_rec r w = let rec connect_rec r w =
read r >>= function read r >>= function
| `End -> Lwt.return_unit | `End -> Lwt.return_unit
| (`Error _ | `Ok _) as step -> | `Error _ as step -> write w step
| `Ok _ as step ->
write w step >>= fun () -> write w step >>= fun () ->
connect_rec r w connect_rec r w
@ -170,20 +145,20 @@ module Pipe = struct
keep b fut keep b fut
(* close a when b closes *) (* close a when b closes *)
let close_when_closed a b = let link_close p ~after =
Lwt.on_success b.closed Lwt.on_termination after.closed
(fun () -> close_async a) (fun _ -> close_async p)
(* close a when every member of l closes *) (* close a when every member of after closes *)
let close_when_all_closed a l = let link_close_l p ~after =
let n = ref (List.length l) in let n = ref (List.length after) in
List.iter List.iter
(fun p -> Lwt.on_success p.closed (fun p' -> Lwt.on_termination p'.closed
(fun () -> (fun _ ->
decr n; decr n;
if !n = 0 then close_async a if !n = 0 then close_async p
) )
) l ) after
end end
module Writer = struct module Writer = struct
@ -203,12 +178,12 @@ module Writer = struct
let rec fwd () = let rec fwd () =
Pipe.read b >>= function Pipe.read b >>= function
| `Ok x -> write a (f x) >>= fwd | `Ok x -> write a (f x) >>= fwd
| `Error msg -> write_error a msg >>= fwd | `Error msg -> write_error a msg >>= fun _ -> Pipe.close a
| `End -> Lwt.return_unit | `End -> Lwt.return_unit
in in
Pipe.keep b (fwd()); Pipe.keep b (fwd());
(* when a gets closed, close b too *) (* when a gets closed, close b too *)
Lwt.on_success (Pipe.on_close a) (fun () -> Pipe.close_async b); Pipe.link_close b ~after:a;
b b
let send_all l = let send_all l =
@ -222,7 +197,7 @@ module Writer = struct
in in
(* do not GC before res dies; close res when any outputx is closed *) (* do not GC before res dies; close res when any outputx is closed *)
Pipe.keep res (fwd ()); Pipe.keep res (fwd ());
List.iter (Pipe.close_when_closed res) l; List.iter (fun out -> Pipe.link_close res ~after:out) l;
res res
let send_both a b = send_all [a; b] let send_both a b = send_all [a; b]
@ -238,7 +213,7 @@ module Reader = struct
let rec fwd () = let rec fwd () =
Pipe.read a >>= function Pipe.read a >>= function
| `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd | `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd
| (`Error _) as e -> Pipe.write b e >>= fwd | (`Error _) as e -> Pipe.write b e >>= fun _ -> Pipe.close b
| `End -> Pipe.close b | `End -> Pipe.close b
in in
Pipe.keep b (fwd()); Pipe.keep b (fwd());
@ -253,7 +228,7 @@ module Reader = struct
| None -> fwd() | None -> fwd()
| Some y -> Pipe.write b (`Ok y) >>= fwd | Some y -> Pipe.write b (`Ok y) >>= fwd
end end
| (`Error _) as e -> Pipe.write b e >>= fwd | (`Error _) as e -> Pipe.write b e >>= fun _ -> Pipe.close b
| `End -> Pipe.close b | `End -> Pipe.close b
in in
Pipe.keep b (fwd()); Pipe.keep b (fwd());
@ -289,12 +264,24 @@ module Reader = struct
let res = Pipe.create () in let res = Pipe.create () in
List.iter (fun p -> Pipe.connect p res) l; List.iter (fun p -> Pipe.connect p res) l;
(* connect res' input to all members of l; close res when they all close *) (* connect res' input to all members of l; close res when they all close *)
Pipe.close_when_all_closed res l; Pipe.link_close_l res ~after:l;
res res
let merge_both a b = merge_all [a; b] let merge_both a b = merge_all [a; b]
let append a b =
let c = Pipe.create () in
Pipe.connect a c;
Lwt.on_success (Pipe.wait a)
(fun () ->
Pipe.connect b c;
Pipe.link_close c ~after:b (* once a and b finished, c is too *)
);
c
end end
let connect = Pipe.connect
(** {2 Conversions} *) (** {2 Conversions} *)
let of_list l : _ Reader.t = let of_list l : _ Reader.t =
@ -326,10 +313,10 @@ let of_string a =
Pipe.keep p (send 0); Pipe.keep p (send 0);
p p
let to_rev_list r = let to_list_rev r =
Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r
let to_list r = to_rev_list r >>|= List.rev let to_list r = to_list_rev r >>|= List.rev
let to_list_exn r = let to_list_exn r =
to_list r >>= function to_list r >>= function
@ -381,6 +368,7 @@ module IO = struct
let p = Pipe.create () in let p = Pipe.create () in
Pipe.keep p ( Pipe.keep p (
Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ -> Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ ->
Lwt_io.flush oc >>= fun () ->
Pipe.close p Pipe.close p
); );
p p
@ -389,6 +377,7 @@ module IO = struct
let p = Pipe.create () in let p = Pipe.create () in
Pipe.keep p ( Pipe.keep p (
Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ -> Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ ->
Lwt_io.flush oc >>= fun () ->
Pipe.close p Pipe.close p
); );
p p

View file

@ -31,6 +31,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{- Pipe: a possibly buffered channel through which readers and writer communicate} {- Pipe: a possibly buffered channel through which readers and writer communicate}
{- Reader: accepts values, produces effects} {- Reader: accepts values, produces effects}
{- Writer: yield values} {- Writer: yield values}
@since NEXT_RELEASE
*) *)
type 'a or_error = [`Ok of 'a | `Error of string] type 'a or_error = [`Ok of 'a | `Error of string]
@ -64,7 +66,7 @@ module Pipe : sig
val close_async : _ t -> unit val close_async : _ t -> unit
(** Same as {!close} but closes in the background *) (** Same as {!close} but closes in the background *)
val on_close : _ t -> unit Lwt.t val wait : _ t -> unit Lwt.t
(** Evaluates once the pipe closes *) (** Evaluates once the pipe closes *)
val create : ?max_size:int -> unit -> ('a, 'perm) t val create : ?max_size:int -> unit -> ('a, 'perm) t
@ -124,8 +126,13 @@ module Reader : sig
val merge_all : 'a t list -> 'a t val merge_all : 'a t list -> 'a t
(** Merge all the input streams (** Merge all the input streams
@raise Invalid_argument if the list is empty *) @raise Invalid_argument if the list is empty *)
val append : 'a t -> 'a t -> 'a t
end end
val connect : 'a Reader.t -> 'a Writer.t -> unit
(** Handy synonym to {!Pipe.connect} *)
(** {2 Conversions} *) (** {2 Conversions} *)
val of_list : 'a list -> 'a Reader.t val of_list : 'a list -> 'a Reader.t
@ -134,7 +141,7 @@ val of_array : 'a array -> 'a Reader.t
val of_string : string -> char Reader.t val of_string : string -> char Reader.t
val to_rev_list : 'a Reader.t -> 'a list LwtErr.t val to_list_rev : 'a Reader.t -> 'a list LwtErr.t
val to_list : 'a Reader.t -> 'a list LwtErr.t val to_list : 'a Reader.t -> 'a list LwtErr.t