diff --git a/src/ws/tiny_httpd_ws.ml b/src/ws/tiny_httpd_ws.ml index 41a8cc67..a0cf01e1 100644 --- a/src/ws/tiny_httpd_ws.ml +++ b/src/ws/tiny_httpd_ws.ml @@ -1,5 +1,28 @@ open Common_ws_ +module With_lock = struct + type t = { with_lock: 'a. (unit -> 'a) -> 'a } + type builder = unit -> t + + let default_builder : builder = + fun () -> + let mutex = Mutex.create () in + { + with_lock = + (fun f -> + Mutex.lock mutex; + try + let x = f () in + Mutex.unlock mutex; + x + with e -> + Mutex.unlock mutex; + raise e); + } + + let builder : builder ref = ref default_builder +end + type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit module Frame_type = struct @@ -52,7 +75,7 @@ module Writer = struct mutable offset: int; (** number of bytes already in [buf] *) oc: IO.Output.t; mutable closed: bool; - mutex: Mutex.t; + mutex: With_lock.t; } let create ?(buf_size = 16 * 1024) ~oc () : t = @@ -63,19 +86,9 @@ module Writer = struct offset = 0; oc; closed = false; - mutex = Mutex.create (); + mutex = !With_lock.builder (); } - let[@inline] with_mutex_ (self : t) f = - Mutex.lock self.mutex; - try - let x = f () in - Mutex.unlock self.mutex; - x - with e -> - Mutex.unlock self.mutex; - raise e - let[@inline] close self = self.closed <- true let int_of_bool : bool -> int = Obj.magic @@ -142,7 +155,7 @@ module Writer = struct if self.offset = Bytes.length self.buf then really_output_buf_ self let send_pong (self : t) : unit = - let@ () = with_mutex_ self in + let@ () = self.mutex.with_lock in self.header.fin <- true; self.header.ty <- Frame_type.pong; self.header.payload_len <- 0; @@ -151,7 +164,7 @@ module Writer = struct write_header_ self let output_char (self : t) c : unit = - let@ () = with_mutex_ self in + let@ () = self.mutex.with_lock in let cap = Bytes.length self.buf - self.offset in (* make room for [c] *) if cap = 0 then really_output_buf_ self; @@ -161,7 +174,7 @@ module Writer = struct if cap = 1 then really_output_buf_ self let output (self : t) buf i len : unit = - let@ () = with_mutex_ self in + let@ () = self.mutex.with_lock in let i = ref i in let len = ref len in while !len > 0 do @@ -179,7 +192,7 @@ module Writer = struct flush_if_full self let flush self : unit = - let@ () = with_mutex_ self in + let@ () = self.mutex.with_lock in flush_ self end @@ -187,8 +200,8 @@ module Reader = struct type state = | Begin (** At the beginning of a frame *) | Reading_frame of { mutable remaining_bytes: int; mutable num_read: int } - (** Currently reading the payload of a frame with [remaining_bytes] - left to read from the underlying [ic] *) + (** Currently reading the payload of a frame with [remaining_bytes] left + to read from the underlying [ic] *) | Close type t = { @@ -266,7 +279,7 @@ module Reader = struct external apply_masking_ : key:bytes -> key_offset:int -> buf:bytes -> int -> int -> unit = "tiny_httpd_ws_apply_masking" - [@@noalloc] + [@@noalloc] (** Apply masking to the parsed data *) let[@inline] apply_masking ~mask_key ~mask_offset (buf : bytes) off len : unit @@ -414,7 +427,8 @@ let upgrade ic oc : _ * _ = in ws_ic, ws_oc -(** Turn a regular connection handler (provided by the user) into a websocket upgrade handler *) +(** Turn a regular connection handler (provided by the user) into a websocket + upgrade handler *) module Make_upgrade_handler (X : sig val accept_ws_protocol : string -> bool val handler : handler diff --git a/src/ws/tiny_httpd_ws.mli b/src/ws/tiny_httpd_ws.mli index 10ce1fee..4066b2b4 100644 --- a/src/ws/tiny_httpd_ws.mli +++ b/src/ws/tiny_httpd_ws.mli @@ -1,8 +1,7 @@ (** Websockets for Tiny_httpd. - This sub-library ([tiny_httpd.ws]) exports a small implementation - for a websocket server. It has no additional dependencies. - *) + This sub-library ([tiny_httpd.ws]) exports a small implementation for a + websocket server. It has no additional dependencies. *) type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit (** Websocket handler *) @@ -11,8 +10,8 @@ val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t (** Upgrade a byte stream to the websocket framing protocol. *) exception Close_connection -(** Exception that can be raised from IOs inside the handler, - when the connection is closed from underneath. *) +(** Exception that can be raised from IOs inside the handler, when the + connection is closed from underneath. *) val add_route_handler : ?accept:(unit Request.t -> (unit, int * string) result) -> @@ -23,8 +22,9 @@ val add_route_handler : handler -> unit (** Add a route handler for a websocket endpoint. - @param accept_ws_protocol decides whether this endpoint accepts the websocket protocol - sent by the client. Default accepts everything. *) + @param accept_ws_protocol + decides whether this endpoint accepts the websocket protocol sent by the + client. Default accepts everything. *) (**/**) @@ -33,4 +33,15 @@ module Private_ : sig mask_key:bytes -> mask_offset:int -> bytes -> int -> int -> unit end +(** @since NEXT_RELEASE *) +module With_lock : sig + type t = { with_lock: 'a. (unit -> 'a) -> 'a } + type builder = unit -> t + + val default_builder : builder + (** Lock using [Mutex]. *) + + val builder : builder ref +end + (**/**)