mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-06 11:15:31 -05:00
wip: fix bugs in Lwt_pipe
This commit is contained in:
parent
c6b23890ec
commit
e41faaf91e
2 changed files with 81 additions and 85 deletions
|
|
@ -60,38 +60,26 @@ let ret_end = Lwt.return `End
|
||||||
exception Closed
|
exception Closed
|
||||||
|
|
||||||
module Pipe = struct
|
module Pipe = struct
|
||||||
(* messages given to writers through the condition *)
|
|
||||||
type 'a msg =
|
|
||||||
| Send of 'a step Lwt.u (* send directly to reader *)
|
|
||||||
| SendQueue (* push into queue *)
|
|
||||||
| Close (* close *)
|
|
||||||
|
|
||||||
type 'a inner_buf =
|
|
||||||
| Buf of 'a step Queue.t * int (* buf, max size *)
|
|
||||||
| NoBuf
|
|
||||||
|
|
||||||
type ('a, +'perm) t = {
|
type ('a, +'perm) t = {
|
||||||
close : unit Lwt.u;
|
close : unit Lwt.u;
|
||||||
closed : unit Lwt.t;
|
closed : unit Lwt.t;
|
||||||
lock : Lwt_mutex.t;
|
buf :
|
||||||
buf : 'a inner_buf;
|
[`Item of 'a step
|
||||||
cond : 'a msg Lwt_condition.t;
|
| `Block of 'a step * unit Lwt.u
|
||||||
|
] Queue.t; (* actions queued *)
|
||||||
|
max_size : int;
|
||||||
|
box : 'a step Lwt.u Lwt_mvar.t;
|
||||||
mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *)
|
mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *)
|
||||||
} constraint 'perm = [< `r | `w]
|
} constraint 'perm = [< `r | `w]
|
||||||
|
|
||||||
let create ?(max_size=0) () =
|
let create ?(max_size=0) () =
|
||||||
let buf = match max_size with
|
|
||||||
| 0 -> NoBuf
|
|
||||||
| n when n < 0 -> invalid_arg "max_size"
|
|
||||||
| n -> Buf (Queue.create (), n)
|
|
||||||
in
|
|
||||||
let closed, close = Lwt.wait () in
|
let closed, close = Lwt.wait () in
|
||||||
{
|
{
|
||||||
close;
|
close;
|
||||||
closed;
|
closed;
|
||||||
buf;
|
buf = Queue.create ();
|
||||||
lock=Lwt_mutex.create();
|
max_size;
|
||||||
cond=Lwt_condition.create();
|
box=Lwt_mvar.create_empty ();
|
||||||
keep=[];
|
keep=[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -103,65 +91,52 @@ module Pipe = struct
|
||||||
if is_closed p then Lwt.return_unit
|
if is_closed p then Lwt.return_unit
|
||||||
else (
|
else (
|
||||||
Lwt.wakeup p.close (); (* evaluate *)
|
Lwt.wakeup p.close (); (* evaluate *)
|
||||||
Lwt_condition.broadcast p.cond Close;
|
|
||||||
Lwt.join p.keep;
|
Lwt.join p.keep;
|
||||||
)
|
)
|
||||||
|
|
||||||
let close_async p = Lwt.async (fun () -> close p)
|
let close_async p = Lwt.async (fun () -> close p)
|
||||||
|
|
||||||
let on_close p = p.closed
|
let wait p = Lwt.map (fun _ -> ()) p.closed
|
||||||
|
|
||||||
(* try to take next element from buffer *)
|
(* try to take next element from buffer *)
|
||||||
let try_next_buf t = match t.buf with
|
let try_next_buf t =
|
||||||
| NoBuf -> None
|
if Queue.is_empty t.buf then None
|
||||||
| Buf (q, _) ->
|
else Some (Queue.pop t.buf)
|
||||||
if Queue.is_empty q then None
|
|
||||||
else Some (Queue.pop q)
|
|
||||||
|
|
||||||
(* returns true if it could push successfully *)
|
|
||||||
let try_push_buf t x = match t.buf with
|
|
||||||
| NoBuf -> false
|
|
||||||
| Buf (q, max_size) when Queue.length q = max_size -> false
|
|
||||||
| Buf (q, _) -> Queue.push x q; true
|
|
||||||
|
|
||||||
(* read next one *)
|
(* read next one *)
|
||||||
let read t =
|
let read t =
|
||||||
Lwt_mutex.with_lock t.lock
|
match try_next_buf t with
|
||||||
(fun () ->
|
| None when is_closed t -> ret_end (* end of stream *)
|
||||||
match try_next_buf t with
|
| None ->
|
||||||
| None when is_closed t -> ret_end (* end of stream *)
|
let fut, send = Lwt.wait () in
|
||||||
| None ->
|
Lwt_mvar.put t.box send >>= fun () ->
|
||||||
let fut, send = Lwt.wait () in
|
fut
|
||||||
Lwt_condition.signal t.cond (Send send);
|
| Some (`Item x) -> Lwt.return x
|
||||||
fut
|
| Some (`Block (x, signal_done)) ->
|
||||||
| Some x ->
|
Lwt.wakeup signal_done (); (* signal the writer it's done *)
|
||||||
Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *)
|
Lwt.return x
|
||||||
Lwt.return x
|
|
||||||
)
|
(* TODO: signal writers when their value has less than max_size
|
||||||
|
steps before being read *)
|
||||||
|
|
||||||
(* write a value *)
|
(* write a value *)
|
||||||
let write t x =
|
let write t x =
|
||||||
let rec try_write () =
|
if is_closed t then Lwt.fail Closed
|
||||||
if is_closed t then Lwt.fail Closed
|
else if Queue.length t.buf < t.max_size
|
||||||
else if try_push_buf t x
|
then (
|
||||||
then Lwt.return_unit (* into buffer, do not wait *)
|
Queue.push (`Item x) t.buf;
|
||||||
else (
|
Lwt.return_unit (* into buffer, do not wait *)
|
||||||
(* wait for readers to consume the queue *)
|
) else (
|
||||||
Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg ->
|
let is_done, signal_done = Lwt.wait () in
|
||||||
match msg with
|
Queue.push (`Block (x, signal_done)) t.buf;
|
||||||
| Send s ->
|
is_done
|
||||||
Lwt.wakeup s x; (* sync with reader *)
|
)
|
||||||
Lwt.return_unit
|
|
||||||
| SendQueue -> try_write () (* try again! *)
|
|
||||||
| Close -> Lwt.fail Closed
|
|
||||||
)
|
|
||||||
in
|
|
||||||
Lwt_mutex.with_lock t.lock try_write
|
|
||||||
|
|
||||||
let rec connect_rec r w =
|
let rec connect_rec r w =
|
||||||
read r >>= function
|
read r >>= function
|
||||||
| `End -> Lwt.return_unit
|
| `End -> Lwt.return_unit
|
||||||
| (`Error _ | `Ok _) as step ->
|
| `Error _ as step -> write w step
|
||||||
|
| `Ok _ as step ->
|
||||||
write w step >>= fun () ->
|
write w step >>= fun () ->
|
||||||
connect_rec r w
|
connect_rec r w
|
||||||
|
|
||||||
|
|
@ -170,20 +145,20 @@ module Pipe = struct
|
||||||
keep b fut
|
keep b fut
|
||||||
|
|
||||||
(* close a when b closes *)
|
(* close a when b closes *)
|
||||||
let close_when_closed a b =
|
let link_close p ~after =
|
||||||
Lwt.on_success b.closed
|
Lwt.on_termination after.closed
|
||||||
(fun () -> close_async a)
|
(fun _ -> close_async p)
|
||||||
|
|
||||||
(* close a when every member of l closes *)
|
(* close a when every member of after closes *)
|
||||||
let close_when_all_closed a l =
|
let link_close_l p ~after =
|
||||||
let n = ref (List.length l) in
|
let n = ref (List.length after) in
|
||||||
List.iter
|
List.iter
|
||||||
(fun p -> Lwt.on_success p.closed
|
(fun p' -> Lwt.on_termination p'.closed
|
||||||
(fun () ->
|
(fun _ ->
|
||||||
decr n;
|
decr n;
|
||||||
if !n = 0 then close_async a
|
if !n = 0 then close_async p
|
||||||
)
|
)
|
||||||
) l
|
) after
|
||||||
end
|
end
|
||||||
|
|
||||||
module Writer = struct
|
module Writer = struct
|
||||||
|
|
@ -203,12 +178,12 @@ module Writer = struct
|
||||||
let rec fwd () =
|
let rec fwd () =
|
||||||
Pipe.read b >>= function
|
Pipe.read b >>= function
|
||||||
| `Ok x -> write a (f x) >>= fwd
|
| `Ok x -> write a (f x) >>= fwd
|
||||||
| `Error msg -> write_error a msg >>= fwd
|
| `Error msg -> write_error a msg >>= fun _ -> Pipe.close a
|
||||||
| `End -> Lwt.return_unit
|
| `End -> Lwt.return_unit
|
||||||
in
|
in
|
||||||
Pipe.keep b (fwd());
|
Pipe.keep b (fwd());
|
||||||
(* when a gets closed, close b too *)
|
(* when a gets closed, close b too *)
|
||||||
Lwt.on_success (Pipe.on_close a) (fun () -> Pipe.close_async b);
|
Pipe.link_close b ~after:a;
|
||||||
b
|
b
|
||||||
|
|
||||||
let send_all l =
|
let send_all l =
|
||||||
|
|
@ -222,7 +197,7 @@ module Writer = struct
|
||||||
in
|
in
|
||||||
(* do not GC before res dies; close res when any outputx is closed *)
|
(* do not GC before res dies; close res when any outputx is closed *)
|
||||||
Pipe.keep res (fwd ());
|
Pipe.keep res (fwd ());
|
||||||
List.iter (Pipe.close_when_closed res) l;
|
List.iter (fun out -> Pipe.link_close res ~after:out) l;
|
||||||
res
|
res
|
||||||
|
|
||||||
let send_both a b = send_all [a; b]
|
let send_both a b = send_all [a; b]
|
||||||
|
|
@ -238,7 +213,7 @@ module Reader = struct
|
||||||
let rec fwd () =
|
let rec fwd () =
|
||||||
Pipe.read a >>= function
|
Pipe.read a >>= function
|
||||||
| `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd
|
| `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd
|
||||||
| (`Error _) as e -> Pipe.write b e >>= fwd
|
| (`Error _) as e -> Pipe.write b e >>= fun _ -> Pipe.close b
|
||||||
| `End -> Pipe.close b
|
| `End -> Pipe.close b
|
||||||
in
|
in
|
||||||
Pipe.keep b (fwd());
|
Pipe.keep b (fwd());
|
||||||
|
|
@ -253,7 +228,7 @@ module Reader = struct
|
||||||
| None -> fwd()
|
| None -> fwd()
|
||||||
| Some y -> Pipe.write b (`Ok y) >>= fwd
|
| Some y -> Pipe.write b (`Ok y) >>= fwd
|
||||||
end
|
end
|
||||||
| (`Error _) as e -> Pipe.write b e >>= fwd
|
| (`Error _) as e -> Pipe.write b e >>= fun _ -> Pipe.close b
|
||||||
| `End -> Pipe.close b
|
| `End -> Pipe.close b
|
||||||
in
|
in
|
||||||
Pipe.keep b (fwd());
|
Pipe.keep b (fwd());
|
||||||
|
|
@ -289,12 +264,24 @@ module Reader = struct
|
||||||
let res = Pipe.create () in
|
let res = Pipe.create () in
|
||||||
List.iter (fun p -> Pipe.connect p res) l;
|
List.iter (fun p -> Pipe.connect p res) l;
|
||||||
(* connect res' input to all members of l; close res when they all close *)
|
(* connect res' input to all members of l; close res when they all close *)
|
||||||
Pipe.close_when_all_closed res l;
|
Pipe.link_close_l res ~after:l;
|
||||||
res
|
res
|
||||||
|
|
||||||
let merge_both a b = merge_all [a; b]
|
let merge_both a b = merge_all [a; b]
|
||||||
|
|
||||||
|
let append a b =
|
||||||
|
let c = Pipe.create () in
|
||||||
|
Pipe.connect a c;
|
||||||
|
Lwt.on_success (Pipe.wait a)
|
||||||
|
(fun () ->
|
||||||
|
Pipe.connect b c;
|
||||||
|
Pipe.link_close c ~after:b (* once a and b finished, c is too *)
|
||||||
|
);
|
||||||
|
c
|
||||||
end
|
end
|
||||||
|
|
||||||
|
let connect = Pipe.connect
|
||||||
|
|
||||||
(** {2 Conversions} *)
|
(** {2 Conversions} *)
|
||||||
|
|
||||||
let of_list l : _ Reader.t =
|
let of_list l : _ Reader.t =
|
||||||
|
|
@ -326,10 +313,10 @@ let of_string a =
|
||||||
Pipe.keep p (send 0);
|
Pipe.keep p (send 0);
|
||||||
p
|
p
|
||||||
|
|
||||||
let to_rev_list r =
|
let to_list_rev r =
|
||||||
Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r
|
Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r
|
||||||
|
|
||||||
let to_list r = to_rev_list r >>|= List.rev
|
let to_list r = to_list_rev r >>|= List.rev
|
||||||
|
|
||||||
let to_list_exn r =
|
let to_list_exn r =
|
||||||
to_list r >>= function
|
to_list r >>= function
|
||||||
|
|
@ -381,6 +368,7 @@ module IO = struct
|
||||||
let p = Pipe.create () in
|
let p = Pipe.create () in
|
||||||
Pipe.keep p (
|
Pipe.keep p (
|
||||||
Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ ->
|
Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ ->
|
||||||
|
Lwt_io.flush oc >>= fun () ->
|
||||||
Pipe.close p
|
Pipe.close p
|
||||||
);
|
);
|
||||||
p
|
p
|
||||||
|
|
@ -389,6 +377,7 @@ module IO = struct
|
||||||
let p = Pipe.create () in
|
let p = Pipe.create () in
|
||||||
Pipe.keep p (
|
Pipe.keep p (
|
||||||
Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ ->
|
Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ ->
|
||||||
|
Lwt_io.flush oc >>= fun () ->
|
||||||
Pipe.close p
|
Pipe.close p
|
||||||
);
|
);
|
||||||
p
|
p
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
{- Pipe: a possibly buffered channel through which readers and writer communicate}
|
{- Pipe: a possibly buffered channel through which readers and writer communicate}
|
||||||
{- Reader: accepts values, produces effects}
|
{- Reader: accepts values, produces effects}
|
||||||
{- Writer: yield values}
|
{- Writer: yield values}
|
||||||
|
|
||||||
|
@since NEXT_RELEASE
|
||||||
*)
|
*)
|
||||||
|
|
||||||
type 'a or_error = [`Ok of 'a | `Error of string]
|
type 'a or_error = [`Ok of 'a | `Error of string]
|
||||||
|
|
@ -64,7 +66,7 @@ module Pipe : sig
|
||||||
val close_async : _ t -> unit
|
val close_async : _ t -> unit
|
||||||
(** Same as {!close} but closes in the background *)
|
(** Same as {!close} but closes in the background *)
|
||||||
|
|
||||||
val on_close : _ t -> unit Lwt.t
|
val wait : _ t -> unit Lwt.t
|
||||||
(** Evaluates once the pipe closes *)
|
(** Evaluates once the pipe closes *)
|
||||||
|
|
||||||
val create : ?max_size:int -> unit -> ('a, 'perm) t
|
val create : ?max_size:int -> unit -> ('a, 'perm) t
|
||||||
|
|
@ -124,8 +126,13 @@ module Reader : sig
|
||||||
val merge_all : 'a t list -> 'a t
|
val merge_all : 'a t list -> 'a t
|
||||||
(** Merge all the input streams
|
(** Merge all the input streams
|
||||||
@raise Invalid_argument if the list is empty *)
|
@raise Invalid_argument if the list is empty *)
|
||||||
|
|
||||||
|
val append : 'a t -> 'a t -> 'a t
|
||||||
end
|
end
|
||||||
|
|
||||||
|
val connect : 'a Reader.t -> 'a Writer.t -> unit
|
||||||
|
(** Handy synonym to {!Pipe.connect} *)
|
||||||
|
|
||||||
(** {2 Conversions} *)
|
(** {2 Conversions} *)
|
||||||
|
|
||||||
val of_list : 'a list -> 'a Reader.t
|
val of_list : 'a list -> 'a Reader.t
|
||||||
|
|
@ -134,7 +141,7 @@ val of_array : 'a array -> 'a Reader.t
|
||||||
|
|
||||||
val of_string : string -> char Reader.t
|
val of_string : string -> char Reader.t
|
||||||
|
|
||||||
val to_rev_list : 'a Reader.t -> 'a list LwtErr.t
|
val to_list_rev : 'a Reader.t -> 'a list LwtErr.t
|
||||||
|
|
||||||
val to_list : 'a Reader.t -> 'a list LwtErr.t
|
val to_list : 'a Reader.t -> 'a list LwtErr.t
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue