diff --git a/src/ws/tiny_httpd_ws.ml b/src/ws/tiny_httpd_ws.ml index 0c247c6c..80867d95 100644 --- a/src/ws/tiny_httpd_ws.ml +++ b/src/ws/tiny_httpd_ws.ml @@ -378,6 +378,28 @@ module Reader = struct ) end +let upgrade ic oc : _ * _ = + let writer = Writer.create ~oc () in + let reader = Reader.create ~ic ~writer () in + let ws_ic : IO.Input.t = + { + input = (fun buf i len -> Reader.read reader buf i len); + close = (fun () -> Reader.close reader); + } + in + let ws_oc : IO.Output.t = + { + flush = + (fun () -> + Writer.flush writer; + IO.Output.flush oc); + output_char = Writer.output_char writer; + output = Writer.output writer; + close = (fun () -> Writer.close writer); + } + in + ws_ic, ws_oc + (** Turn a regular connection handler (provided by the user) into a websocket upgrade handler *) module Make_upgrade_handler (X : sig val accept_ws_protocol : string -> bool @@ -424,25 +446,7 @@ end) : UPGRADE_HANDLER = struct try Ok (handshake_ req) with Bad_req s -> Error s let handle_connection addr () ic oc = - let writer = Writer.create ~oc () in - let reader = Reader.create ~ic ~writer () in - let ws_ic : IO.Input.t = - { - input = (fun buf i len -> Reader.read reader buf i len); - close = (fun () -> Reader.close reader); - } - in - let ws_oc : IO.Output.t = - { - flush = - (fun () -> - Writer.flush writer; - IO.Output.flush oc); - output_char = Writer.output_char writer; - output = Writer.output writer; - close = (fun () -> Writer.close writer); - } - in + let ws_ic, ws_oc = upgrade ic oc in try X.handler addr ws_ic ws_oc with Close_connection -> Log.debug (fun k -> k "websocket: requested to close the connection"); diff --git a/src/ws/tiny_httpd_ws.mli b/src/ws/tiny_httpd_ws.mli index 0a44803b..f3e063fc 100644 --- a/src/ws/tiny_httpd_ws.mli +++ b/src/ws/tiny_httpd_ws.mli @@ -4,6 +4,8 @@ module IO = Tiny_httpd_io type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit (** Websocket handler *) +val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t + val add_route_handler : ?accept:(unit Request.t -> (unit, int * string) result) -> ?accept_ws_protocol:(string -> bool) ->