mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-05 19:00:33 -05:00
312 lines
8.9 KiB
OCaml
312 lines
8.9 KiB
OCaml
open Common_
|
|
module Slice = Iostream_types.Slice
|
|
module Fut = Moonpool.Fut
|
|
|
|
(* Suspend the current fiber until [fd] becomes readable.
   The event-loop registration uses a fresh cancel handle. That handle is
   cancelled if the fiber itself is cancelled, and again once the callback has
   resumed us. The [with_on_self_cancel] scope ends when this function
   returns, so callers that retry in a loop do not nest handlers. *)
let wait_readable_ (fd : Fd.t) : unit =
  let cancel = Cancel_handle.create () in
  let@ () =
    Fiber.with_on_self_cancel (fun _ -> Cancel_handle.cancel cancel)
  in
  Moonpool.Private.Suspend_.suspend
    {
      handle =
        (fun ~run:_ ~resume sus ->
          Ev_loop.wait_readable fd cancel (fun cancel ->
              resume sus @@ Ok ();
              Cancel_handle.cancel cancel));
    }

(** [read fd buf i len] reads up to [len] bytes from [fd] into [buf],
    starting at offset [i].

    If no data is available yet ([EAGAIN]/[EWOULDBLOCK]), the current fiber
    is suspended until [fd] is readable, and the read is retried. Interrupted
    reads ([EINTR]) are retried immediately.

    @return the number of bytes read. This is [0] when [len = 0], when [fd]
    is closed, or when [Unix.read] reports end of input.
    @raise Unix.Unix_error for any other error reported by [Unix.read]. *)
let rec read (fd : Fd.t) buf i len : int =
  if len = 0 || Fd.closed fd then
    0
  else (
    match Unix.read fd.fd buf i len with
    | n -> n
    | exception Unix.Unix_error (Unix.EINTR, _, _) ->
      (read [@tailcall]) fd buf i len
    | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
      (* wait for FD to be ready; the wait's scope is closed before we retry *)
      wait_readable_ fd;
      (read [@tailcall]) fd buf i len
  )
|
|
|
|
(* Suspend the current fiber until [fd] becomes writable.
   The event-loop registration uses a fresh cancel handle. That handle is
   cancelled if the fiber itself is cancelled, and again once the callback has
   resumed us. The [with_on_self_cancel] scope ends when this function
   returns, so callers that retry in a loop do not nest handlers. *)
let wait_writable_ (fd : Fd.t) : unit =
  let cancel = Cancel_handle.create () in
  let@ () =
    Fiber.with_on_self_cancel (fun _ -> Cancel_handle.cancel cancel)
  in
  Moonpool.Private.Suspend_.suspend
    {
      handle =
        (fun ~run:_ ~resume sus ->
          Ev_loop.wait_writable fd cancel (fun cancel ->
              resume sus @@ Ok ();
              Cancel_handle.cancel cancel));
    }

(** [write_once fd buf i len] writes at most [len] bytes of [buf], starting
    at offset [i], to [fd].

    If [fd] cannot accept data yet ([EAGAIN]/[EWOULDBLOCK]), the current
    fiber is suspended until [fd] is writable, and the write is retried.
    Interrupted writes ([EINTR]) are retried immediately.

    @return the number of bytes written, possibly fewer than [len]. This is
    [0] when [len = 0] or when [fd] is closed.
    @raise Unix.Unix_error for any other error reported by [Unix.write]. *)
let rec write_once (fd : Fd.t) buf i len : int =
  if len = 0 || Fd.closed fd then
    0
  else (
    match Unix.write fd.fd buf i len with
    | n -> n
    | exception Unix.Unix_error (Unix.EINTR, _, _) ->
      (write_once [@tailcall]) fd buf i len
    | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
      (* wait for FD to be ready; the wait's scope is closed before we retry *)
      wait_writable_ fd;
      (write_once [@tailcall]) fd buf i len
  )
|
|
|
|
(** [write fd buf i len] writes all [len] bytes of [buf], starting at offset
    [i], to [fd]. It calls [write_once] repeatedly until everything has been
    written.

    [write_once] returns [0] once [fd] is closed. The previous loop treated
    that as progress and spun forever; it is now reported as an error.

    @raise Unix.Unix_error with [EBADF] if [fd] is closed before all bytes
    were written. *)
let write fd buf i len : unit =
  let rec loop i len =
    if len > 0 then (
      match write_once fd buf i len with
      | 0 -> raise (Unix.Unix_error (Unix.EBADF, "write", ""))
      | n -> loop (i + n) (len - n)
    )
  in
  loop i len
|
|
|
|
module Reader = struct
  include Iostream_types.In

  (** [of_fd fd] is an unbuffered input stream that reads from [fd] with
      [read]. Its [close] method closes [fd]: with [Fd.close_noerr] when
      [~close_noerr:true] is given, with [Fd.close] otherwise. *)
  let of_fd ?(close_noerr = false) (fd : Fd.t) : t =
    object
      method input bytes off n = read fd bytes off n

      method close () =
        match close_noerr with
        | true -> Fd.close_noerr fd
        | false -> Fd.close fd
    end

  (** [of_slice s] is an input stream over the [s.len] bytes of [s.bytes]
      that start at offset [s.off]. *)
  let of_slice (slice : Slice.t) : t =
    let { Slice.bytes; off; len; _ } = slice in
    (Iostream.In.of_bytes ~off ~len bytes :> t)
end
|
|
|
|
module Buf_reader = struct
  include Iostream_types.In_buf

  (** [of_fd ~buf fd] is a buffered input stream over [fd] that uses [buf]
      as its internal buffer.

      Each refill resets the slice offset to [0] and then reads as many bytes
      as fit in the slice's byte buffer. Once a read returns [0] bytes, later
      refills leave the slice untouched.

      [close] closes [fd]: with [Fd.close_noerr] if [~close_noerr:true],
      with [Fd.close] otherwise. *)
  let of_fd ?(close_noerr = false) ~(buf : bytes) (fd : Fd.t) : t =
    let at_eof = ref false in
    object
      inherit Iostream.In_buf.t_from_refill ~bytes:buf ()

      method private refill (sl : Slice.t) =
        if !at_eof then
          ()
        else (
          sl.off <- 0;
          let n = read fd sl.bytes 0 (Bytes.length sl.bytes) in
          sl.len <- n;
          if n = 0 then at_eof := true
        )

      method close () =
        match close_noerr with
        | true -> Fd.close_noerr fd
        | false -> Fd.close fd
    end
end
|
|
|
|
module Writer = struct
  include Iostream_types.Out

  (** [of_fd fd] is an unbuffered output stream. Every [output] call is
      delegated to [write] on [fd]. [close] closes [fd]: with
      [Fd.close_noerr] when [~close_noerr:true], with [Fd.close]
      otherwise. *)
  let of_fd ?(close_noerr = false) (fd : Fd.t) : t =
    object
      method output bytes off n = write fd bytes off n

      method close () =
        match close_noerr with
        | true -> Fd.close_noerr fd
        | false -> Fd.close fd
    end
end
|
|
|
|
module Buf_writer = struct
  include Iostream_types.Out_buf

  (** [of_fd ~buf fd] is a buffered output stream over [fd] that uses [buf]
      as its internal buffer.

      The underlying output writes to [fd] with [write]. The underlying
      close closes [fd]: with [Fd.close_noerr] if [~close_noerr:true], with
      [Fd.close] otherwise. *)
  let of_fd ?(close_noerr = false) ~(buf : bytes) (fd : Fd.t) : t =
    object
      inherit Iostream.Out_buf.t_from_output ~bytes:buf ()

      method private output_underlying bytes off n = write fd bytes off n

      method private close_underlying () =
        match close_noerr with
        | true -> Fd.close_noerr fd
        | false -> Fd.close fd
    end
end
|
|
|
|
module Buf_pool = struct
  (** A provider of scratch byte buffers. [with_buf size k] calls [k] with a
      buffer of [size] bytes and returns whatever [k] returns. *)
  type t = { with_buf: 'a. int -> (bytes -> 'a) -> 'a } [@@unboxed]

  (** A pool that does no pooling: each request allocates a fresh buffer
      with [Bytes.create]. *)
  let dummy : t =
    let with_buf size k =
      let fresh = Bytes.create size in
      k fresh
    in
    { with_buf }
end
|
|
|
|
(** A TCP server abstraction. *)
module TCP_server = struct
  (** Callback invoked for each accepted client, in a fresh fiber on the
      server's runner. It receives the client's address, plus a buffered
      reader and a buffered writer over the client socket. When the handler
      returns or raises, the socket is closed and the active-connection
      count is decremented. *)
  type conn_handler =
    client_addr:Sockaddr.t -> Buf_reader.t -> Buf_writer.t -> unit

  class type t = object
    method endpoint : unit -> Sockaddr.t
    method active_connections : unit -> int
    method running : unit -> bool
    method run : unit -> unit
    method stop : unit -> unit
    method await : unit -> unit
  end

  let[@inline] run (self : #t) = self#run ()
  let[@inline] stop (self : #t) = self#stop ()
  let[@inline] endpoint (self : #t) = self#endpoint ()
  let[@inline] await (self : #t) = self#await ()

  type state =
    | Created
    | Running
    | Stopped

  (* [TCP_NODELAY] is an IP-level option, and [setsockopt] rejects it on
     Unix-domain sockets. Only set it for PF_INET/PF_INET6 addresses. *)
  let set_nodelay_if_ip_ (sock : Unix.file_descr) (addr : Sockaddr.t) : unit =
    match Sockaddr.domain addr with
    | Unix.PF_INET | Unix.PF_INET6 ->
      Unix.setsockopt sock Unix.TCP_NODELAY true
    | Unix.PF_UNIX -> ()

  (* Close a raw descriptor on a cleanup path, ignoring errors. *)
  let close_raw_noerr_ (fd : Unix.file_descr) : unit =
    try Unix.close fd with Unix.Unix_error _ -> ()

  (* Suspend the current fiber until [sock] is readable.
     The cancel handle is cancelled if the fiber is cancelled, and once the
     callback has resumed us. This is the same protocol as [read]. *)
  let wait_readable_ (sock : Fd.t) : unit =
    let cancel = Cancel_handle.create () in
    let@ () =
      Fiber.with_on_self_cancel (fun _ -> Cancel_handle.cancel cancel)
    in
    Moonpool.Private.Suspend_.suspend
      {
        handle =
          (fun ~run:_ ~resume sus ->
            Ev_loop.wait_readable sock cancel (fun cancel ->
                resume sus @@ Ok ();
                Cancel_handle.cancel cancel));
      }

  (* Accept one connection on the non-blocking listening socket [sock].
     While no connection is pending, the fiber is suspended.
     [EINTR] and [ECONNABORTED] (client gone before we accepted it) are
     retried, so they do not bring the whole server down. *)
  let rec accept_ (sock : Fd.t) =
    match Unix.accept sock.fd with
    | csock, addr -> csock, addr
    | exception Unix.Unix_error ((Unix.EINTR | Unix.ECONNABORTED), _, _) ->
      (accept_ [@tailcall]) sock
    | exception Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) ->
      Tracing_.message "accept: suspend";
      wait_readable_ sock;
      (accept_ [@tailcall]) sock

  (** Server bound to [addr].

      [run] blocks the calling fiber. It accepts connections until [stop] is
      called, and handles each one with [handle] in a new fiber on [runner].
      The accept loop only observes [stop] after its current [accept]
      returns.

      [on_ready] is called once by [run]: with [Ok ()] when the socket is
      listening, or with [Error _] if setting up the socket failed.

      @param listen backlog passed to [Unix.listen] (default 32)
      @param buf_size size of each per-connection I/O buffer (default 4096) *)
  class base_server ?(listen = 32) ?(buf_pool = Buf_pool.dummy)
    ?(buf_size = 4096)
    ?(on_ready : (unit, Exn_bt.t) result -> unit = ignore) ~runner
    ~(handle : conn_handler) (addr : Sockaddr.t) : t =
    let n_active_ = A.make 0 in
    let st = A.make Created in
    let fut, promise = Fut.make () in

    (* Create, configure, bind and listen. If any step fails, the raw
       socket is closed so that it does not leak. *)
    let open_listen_socket () : Fd.t =
      let raw = Unix.socket (Sockaddr.domain addr) Unix.SOCK_STREAM 0 in
      match
        (set_nodelay_if_ip_ raw addr;
         Unix.set_nonblock raw;
         Unix.bind raw addr;
         Unix.listen raw listen)
      with
      | () -> Fd.create raw
      | exception e ->
        let bt = Printexc.get_raw_backtrace () in
        close_raw_noerr_ raw;
        Printexc.raise_with_backtrace e bt
    in

    (* Hand an accepted socket to [handle] in a new fiber on [runner].
       On Linux, accepted sockets do not inherit O_NONBLOCK from the
       listening socket (see accept(2)). [read]/[write_once] only suspend
       on EAGAIN, so the client socket must be made non-blocking here.
       Otherwise reads would block the runner's thread. *)
    let serve_client (client_sock : Unix.file_descr) client_addr : unit =
      match Unix.set_nonblock client_sock with
      | exception Unix.Unix_error _ ->
        Tracing_.message "accept: could not configure client socket";
        close_raw_noerr_ client_sock
      | () ->
        let client_fd = Fd.create client_sock in
        let _ : _ Fiber.t =
          Fiber.spawn_top ~on:runner (fun () ->
              A.incr n_active_;
              let@ () =
                Fun.protect ~finally:(fun () ->
                    A.decr n_active_;
                    Fd.close_noerr client_fd)
              in
              let@ buf_in = buf_pool.with_buf buf_size in
              let@ buf_out = buf_pool.with_buf buf_size in
              let ic = Buf_reader.of_fd client_fd ~buf:buf_in in
              let oc = Buf_writer.of_fd client_fd ~buf:buf_out in
              handle ~client_addr ic oc)
        in
        ()
    in

    (* Accept clients while the server is [Running]. A connection that
       arrives after [stop] is dropped instead of being served. *)
    let accept_loop (sock : Fd.t) : unit =
      while A.get st = Running do
        let client_sock, client_addr = accept_ sock in
        if A.get st = Running then
          serve_client client_sock client_addr
        else
          close_raw_noerr_ client_sock
      done
    in

    object
      method endpoint () = addr
      method active_connections () = A.get n_active_
      method running () = A.get st = Running

      method stop () =
        match A.exchange st Stopped with
        | Stopped -> ()
        | Created | Running -> Fut.fulfill_idempotent promise @@ Ok ()

      method await () = Fut.await fut

      method run () =
        (* move Created -> Running; only one caller can win *)
        let can_start =
          let rec loop () =
            match A.get st with
            | Created -> A.compare_and_set st Created Running || loop ()
            | Running | Stopped -> false
          in
          loop ()
        in

        if can_start then (
          let@ _sp = Tracing_.with_span "moonpool-unix.tcp-server.run" in

          let sock =
            try open_listen_socket ()
            with e ->
              let bt = Printexc.get_raw_backtrace () in
              let ebt = Exn_bt.make e bt in
              A.set st Stopped;
              Fut.fulfill_idempotent promise @@ Error ebt;
              on_ready (Error ebt);
              Printexc.raise_with_backtrace e bt
          in
          on_ready (Ok ());

          (* Always close the listening socket. If the loop fails, also
             resolve the server's future with the error, so [await] does
             not hang forever. *)
          match accept_loop sock with
          | () -> Fd.close_noerr sock
          | exception e ->
            let bt = Printexc.get_raw_backtrace () in
            Fd.close_noerr sock;
            A.set st Stopped;
            Fut.fulfill_idempotent promise @@ Error (Exn_bt.make e bt);
            Printexc.raise_with_backtrace e bt
        )
    end

  (** Create a server (not yet running) and call [after_init] on it before
      returning it. *)
  let create ?(after_init = ignore) ?listen ?buf_pool ?buf_size ~runner
      ~(handle : conn_handler) (addr : Sockaddr.t) : t =
    let self =
      new base_server ?listen ?buf_pool ?buf_size ~runner ~handle addr
    in
    after_init self;
    self

  (** [with_server ... addr f] runs a server in the background for the
      duration of [f].

      The accept loop runs in its own fiber on [runner]. [f] is called only
      once the socket is listening, and the server is stopped when [f]
      returns or raises.

      Previously, [run] (which blocks until the server stops) was called
      before [f], so [f] never ran while the server was up.

      @raise Unix.Unix_error (or other exceptions) if the listening socket
      cannot be set up. *)
  let with_server ?listen ?buf_pool ?buf_size ~runner ~handle addr
      (f : _ -> 'a) : 'a =
    let ready, ready_promise = Fut.make () in
    let server =
      new base_server
        ?listen ?buf_pool ?buf_size
        ~on_ready:(fun res -> Fut.fulfill_idempotent ready_promise res)
        ~runner ~handle addr
    in
    let _ : _ Fiber.t = Fiber.spawn_top ~on:runner (fun () -> run server) in
    let@ () = Fun.protect ~finally:(fun () -> stop server) in
    (* wait until bound and listening; re-raises a setup failure *)
    Fut.await ready;
    f server
end
|
|
|
|
module TCP_client = struct
  (* Suspend the current fiber until [sock] is writable.
     A fresh cancel handle is used. It is cancelled if the fiber is
     cancelled, and once the callback has resumed us; this is the same
     protocol as the waits used by [read]/[write_once]. *)
  let wait_writable_ (sock : Fd.t) : unit =
    let cancel = Cancel_handle.create () in
    let@ () =
      Fiber.with_on_self_cancel (fun _ -> Cancel_handle.cancel cancel)
    in
    Moonpool.Private.Suspend_.suspend
      {
        handle =
          (fun ~run:_ ~resume sus ->
            Ev_loop.wait_writable sock cancel (fun cancel ->
                resume sus @@ Ok ();
                Cancel_handle.cancel cancel));
      }

  (** Connect asynchronously.

      [sock] is expected to be non-blocking, as set up by [with_connect'].
      While the connection is in progress, the current fiber is suspended
      until [sock] becomes writable.

      After each wake-up, a pending socket error ([SO_ERROR]) is raised.
      Otherwise the connect is retried: it succeeds ([EISCONN] also counts
      as connected), or reports that the connection is still in progress
      ([EALREADY]), in which case we wait again. *)
  let rec connect_ (sock : Fd.t) addr : unit =
    match Unix.connect sock.fd addr with
    | () -> ()
    | exception Unix.Unix_error (Unix.EISCONN, _, _) -> ()
    | exception
        Unix.Unix_error
          ( ( Unix.EWOULDBLOCK | Unix.EINPROGRESS | Unix.EAGAIN | Unix.EALREADY
            | Unix.EINTR ),
            _,
            _ ) ->
      wait_writable_ sock;
      (match Unix.getsockopt_error sock.fd with
      | Some err -> raise (Unix.Unix_error (err, "connect", ""))
      | None -> ());
      (connect_ [@tailcall]) sock addr

  (** [with_connect' addr f] opens a non-blocking stream socket connected to
      [addr] and calls [f] with it.

      [TCP_NODELAY] is set for IP addresses only, because Unix-domain
      sockets reject it. The socket is closed when [f] returns or raises,
      and also if configuring or connecting the socket fails. *)
  let with_connect' addr (f : Fd.t -> 'a) : 'a =
    let raw = Unix.socket (Sockaddr.domain addr) Unix.SOCK_STREAM 0 in
    let sock =
      match
        (Unix.set_nonblock raw;
         match Sockaddr.domain addr with
         | Unix.PF_INET | Unix.PF_INET6 ->
           Unix.setsockopt raw Unix.TCP_NODELAY true
         | Unix.PF_UNIX -> ())
      with
      | () -> Fd.create raw
      | exception e ->
        let bt = Printexc.get_raw_backtrace () in
        (try Unix.close raw with Unix.Unix_error _ -> ());
        Printexc.raise_with_backtrace e bt
    in
    (* protect the connect as well, so a refused connection does not leak *)
    let@ () = Fun.protect ~finally:(fun () -> Fd.close_noerr sock) in
    connect_ sock addr;
    f sock

  (** [with_connect addr f] connects to [addr] and calls [f] with a buffered
      reader and a buffered writer over the connection. Their buffers are
      [buf_size] bytes (default 4096) and are obtained from [buf_pool]. *)
  let with_connect ?(buf_pool = Buf_pool.dummy) ?(buf_size = 4096) addr
      (f : _ -> _ -> 'a) : 'a =
    with_connect' addr (fun sock ->
        let@ buf_in = buf_pool.with_buf buf_size in
        let@ buf_out = buf_pool.with_buf buf_size in
        let ic = Buf_reader.of_fd ~buf:buf_in sock in
        let oc = Buf_writer.of_fd ~buf:buf_out sock in
        f ic oc)
end
|