ocaml-containers/src/lwt/lwt_pipe.ml

395 lines
10 KiB
OCaml

(*
copyright (c) 2013-2014, simon cruanes
all rights reserved.
redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer. redistributions in binary
form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with
the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*)
type 'a or_error = [`Ok of 'a | `Error of string]
type 'a step = ['a or_error | `End]
let (>>=) = Lwt.(>>=)
module LwtErr = struct
type 'a t = 'a or_error Lwt.t
let return x = Lwt.return (`Ok x)
let return_unit = Lwt.return (`Ok ())
let fail msg = Lwt.return (`Error msg)
let (>>=) x f =
Lwt.bind x
(function
| `Error msg -> fail msg
| `Ok y -> f y
)
let (>|=) x f =
Lwt.map
(function
| `Error _ as e -> e
| `Ok x -> `Ok (f x)
) x
end
let (>>|=) = LwtErr.(>|=)
let ret_end = Lwt.return `End
exception Closed
module Pipe = struct
(* messages given to writers through the condition *)
type 'a msg =
| Send of 'a step Lwt.u (* send directly to reader *)
| SendQueue (* push into queue *)
| Close (* close *)
type 'a inner_buf =
| Buf of 'a step Queue.t * int (* buf, max size *)
| NoBuf
type ('a, +'perm) t = {
close : unit Lwt.u;
closed : unit Lwt.t;
lock : Lwt_mutex.t;
buf : 'a inner_buf;
cond : 'a msg Lwt_condition.t;
mutable keep : unit Lwt.t list; (* do not GC, and wait for completion *)
} constraint 'perm = [< `r | `w]
let create ?(max_size=0) () =
let buf = match max_size with
| 0 -> NoBuf
| n when n < 0 -> invalid_arg "max_size"
| n -> Buf (Queue.create (), n)
in
let closed, close = Lwt.wait () in
{
close;
closed;
buf;
lock=Lwt_mutex.create();
cond=Lwt_condition.create();
keep=[];
}
let keep p fut = p.keep <- fut :: p.keep
let is_closed p = not (Lwt.is_sleeping p.closed)
let close p =
if is_closed p then Lwt.return_unit
else (
Lwt.wakeup p.close (); (* evaluate *)
Lwt_condition.broadcast p.cond Close;
Lwt.join p.keep;
)
let close_async p = Lwt.async (fun () -> close p)
let on_close p = p.closed
(* try to take next element from buffer *)
let try_next_buf t = match t.buf with
| NoBuf -> None
| Buf (q, _) ->
if Queue.is_empty q then None
else Some (Queue.pop q)
(* returns true if it could push successfully *)
let try_push_buf t x = match t.buf with
| NoBuf -> false
| Buf (q, max_size) when Queue.length q = max_size -> false
| Buf (q, _) -> Queue.push x q; true
(* read next one *)
let read t =
Lwt_mutex.with_lock t.lock
(fun () ->
match try_next_buf t with
| None when is_closed t -> ret_end (* end of stream *)
| None ->
let fut, send = Lwt.wait () in
Lwt_condition.signal t.cond (Send send);
fut
| Some x ->
Lwt_condition.signal t.cond SendQueue; (* queue isn't full anymore *)
Lwt.return x
)
(* write a value *)
let write t x =
let rec try_write () =
if is_closed t then Lwt.fail Closed
else if try_push_buf t x
then Lwt.return_unit (* into buffer, do not wait *)
else (
(* wait for readers to consume the queue *)
Lwt_condition.wait ~mutex:t.lock t.cond >>= fun msg ->
match msg with
| Send s ->
Lwt.wakeup s x; (* sync with reader *)
Lwt.return_unit
| SendQueue -> try_write () (* try again! *)
| Close -> Lwt.fail Closed
)
in
Lwt_mutex.with_lock t.lock try_write
let rec connect_rec r w =
read r >>= function
| `End -> Lwt.return_unit
| (`Error _ | `Ok _) as step ->
write w step >>= fun () ->
connect_rec r w
let connect a b =
let fut = connect_rec a b in
keep b fut
(* close a when b closes *)
let close_when_closed a b =
Lwt.on_success b.closed
(fun () -> close_async a)
(* close a when every member of l closes *)
let close_when_all_closed a l =
let n = ref (List.length l) in
List.iter
(fun p -> Lwt.on_success p.closed
(fun () ->
decr n;
if !n = 0 then close_async a
)
) l
end
module Writer = struct
type 'a t = ('a, [`w]) Pipe.t
let write t x = Pipe.write t (`Ok x)
let write_error t msg = Pipe.write t (`Error msg)
let rec write_list t l = match l with
| [] -> Lwt.return_unit
| x :: tail ->
write t x >>= fun () -> write_list t tail
let map ~f a =
let b = Pipe.create() in
let rec fwd () =
Pipe.read b >>= function
| `Ok x -> write a (f x) >>= fwd
| `Error msg -> write_error a msg >>= fwd
| `End -> Lwt.return_unit
in
Pipe.keep b (fwd());
(* when a gets closed, close b too *)
Lwt.on_success (Pipe.on_close a) (fun () -> Pipe.close_async b);
b
let send_all l =
if l = [] then invalid_arg "send_all";
let res = Pipe.create () in
let rec fwd () =
Pipe.read res >>= function
| `End -> Lwt.return_unit
| `Ok x -> Lwt_list.iter_p (fun p -> write p x) l >>= fwd
| `Error msg -> Lwt_list.iter_p (fun p -> write_error p msg) l >>= fwd
in
(* do not GC before res dies; close res when any outputx is closed *)
Pipe.keep res (fwd ());
List.iter (Pipe.close_when_closed res) l;
res
let send_both a b = send_all [a; b]
end
module Reader = struct
type 'a t = ('a, [`r]) Pipe.t
let read = Pipe.read
let map ~f a =
let b = Pipe.create () in
let rec fwd () =
Pipe.read a >>= function
| `Ok x -> Pipe.write b (`Ok (f x)) >>= fwd
| (`Error _) as e -> Pipe.write b e >>= fwd
| `End -> Pipe.close b
in
Pipe.keep b (fwd());
b
let filter_map ~f a =
let b = Pipe.create () in
let rec fwd () =
Pipe.read a >>= function
| `Ok x ->
begin match f x with
| None -> fwd()
| Some y -> Pipe.write b (`Ok y) >>= fwd
end
| (`Error _) as e -> Pipe.write b e >>= fwd
| `End -> Pipe.close b
in
Pipe.keep b (fwd());
b
let rec fold ~f ~x t =
read t >>= function
| `End -> LwtErr.return x
| `Error msg -> LwtErr.fail msg
| `Ok y -> fold ~f ~x:(f x y) t
let rec fold_s ~f ~x t =
read t >>= function
| `End -> LwtErr.return x
| `Error msg -> LwtErr.fail msg
| `Ok y ->
f x y >>= fun x -> fold_s ~f ~x t
let rec iter ~f t =
read t >>= function
| `End -> LwtErr.return_unit
| `Error msg -> LwtErr.fail msg
| `Ok x -> f x; iter ~f t
let rec iter_s ~f t =
read t >>= function
| `End -> LwtErr.return_unit
| `Error msg -> LwtErr.fail msg
| `Ok x -> f x >>= fun () -> iter_s ~f t
let merge_all l =
if l = [] then invalid_arg "merge_all";
let res = Pipe.create () in
List.iter (fun p -> Pipe.connect p res) l;
(* connect res' input to all members of l; close res when they all close *)
Pipe.close_when_all_closed res l;
res
let merge_both a b = merge_all [a; b]
end
(** {2 Conversions} *)
let of_list l : _ Reader.t =
let p = Pipe.create ~max_size:0 () in
Pipe.keep p (Lwt_list.iter_s (Writer.write p) l >>= fun () -> Pipe.close p);
p
let of_array a =
let p = Pipe.create ~max_size:0 () in
let rec send i =
if i = Array.length a then Pipe.close p
else (
Writer.write p a.(i) >>= fun () ->
send (i+1)
)
in
Pipe.keep p (send 0);
p
let of_string a =
let p = Pipe.create ~max_size:0 () in
let rec send i =
if i = String.length a then Pipe.close p
else (
Writer.write p (String.get a i) >>= fun () ->
send (i+1)
)
in
Pipe.keep p (send 0);
p
let to_rev_list r =
Reader.fold ~f:(fun acc x -> x :: acc) ~x:[] r
let to_list r = to_rev_list r >>|= List.rev
let to_list_exn r =
to_list r >>= function
| `Error msg -> Lwt.fail (Failure msg)
| `Ok x -> Lwt.return x
let to_buffer buf =
let p = Pipe.create () in
Pipe.keep p (
Reader.iter ~f:(fun c -> Buffer.add_char buf c) p >>= fun _ ->
Lwt.return_unit
);
p
let to_buffer_str buf =
let p = Pipe.create () in
Pipe.keep p (
Reader.iter ~f:(fun s -> Buffer.add_string buf s) p >>= fun _ ->
Lwt.return_unit
);
p
(** {2 Basic IO wrappers} *)
module IO = struct
let read ?(bufsize=4096) ic : _ Reader.t =
let buf = Bytes.make bufsize ' ' in
let p = Pipe.create ~max_size:0 () in
let rec send() =
Lwt_io.read_into ic buf 0 bufsize >>= fun n ->
if n = 0 then Pipe.close p
else
Writer.write p (Bytes.sub_string buf 0 n) >>= fun () ->
send ()
in Lwt.async send;
p
let read_lines ic =
let p = Pipe.create () in
let rec send () =
Lwt_io.read_line_opt ic >>= function
| None -> Pipe.close p
| Some line -> Writer.write p line >>= fun () -> send ()
in
Lwt.async send;
p
let write oc =
let p = Pipe.create () in
Pipe.keep p (
Reader.iter_s ~f:(Lwt_io.write oc) p >>= fun _ ->
Pipe.close p
);
p
let write_lines oc =
let p = Pipe.create () in
Pipe.keep p (
Reader.iter_s ~f:(Lwt_io.write_line oc) p >>= fun _ ->
Pipe.close p
);
p
end