feat ws: pass the whole request to the handler

This commit is contained in:
Simon Cruanes 2024-04-02 14:35:57 -04:00
parent 4b845bf019
commit d8ff243e8d
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
5 changed files with 29 additions and 22 deletions

View file

@ -10,9 +10,9 @@ let setup_logging ~debug () =
else else
Logs.Info) Logs.Info)
let handle_ws _client_addr ic oc = let handle_ws (req : unit Request.t) ic oc =
Log.info (fun k -> Log.info (fun k ->
k "new client connection from %s" (Util.show_sockaddr _client_addr)); k "new client connection from %s" (Util.show_sockaddr req.client_addr));
let (_ : Thread.t) = let (_ : Thread.t) =
Thread.create Thread.create

View file

@ -35,12 +35,14 @@ module type UPGRADE_HANDLER = sig
val name : string val name : string
(** Name in the "upgrade" header *) (** Name in the "upgrade" header *)
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result val handshake :
Unix.sockaddr ->
unit Request.t ->
(Headers.t * handshake_state, string) result
(** Perform the handshake and upgrade the connection. The returned (** Perform the handshake and upgrade the connection. The returned
code is [101] alongside these headers. *) code is [101] alongside these headers. *)
val handle_connection : val handle_connection : handshake_state -> IO.Input.t -> IO.Output.t -> unit
Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit
(** Take control of the connection and take it from there *) (** Take control of the connection and take it from there *)
end end
@ -362,7 +364,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
| None -> bad_reqf 426 "expected 'connection: upgrade' header"); | None -> bad_reqf 426 "expected 'connection: upgrade' header");
(* ok, this is the upgrade we expected *) (* ok, this is the upgrade we expected *)
match UP.handshake req with match UP.handshake client_addr req with
| Error msg -> | Error msg ->
(* fail the upgrade *) (* fail the upgrade *)
Log.error (fun k -> k "upgrade failed: %s" msg); Log.error (fun k -> k "upgrade failed: %s" msg);
@ -378,7 +380,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
log_response req resp; log_response req resp;
Response.Private_.output_ ~bytes:bytes_res oc resp; Response.Private_.output_ ~bytes:bytes_res oc resp;
UP.handle_connection client_addr handshake_st ic oc UP.handle_connection handshake_st ic oc
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
handle_bad_req req e bt handle_bad_req req e bt

View file

@ -250,14 +250,18 @@ module type UPGRADE_HANDLER = sig
val name : string val name : string
(** Name in the "upgrade" header *) (** Name in the "upgrade" header *)
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result val handshake :
(** Perform the handshake and upgrade the connection. The returned Unix.sockaddr ->
code is [101] alongside these headers. unit Request.t ->
In case the handshake fails, this only returns [Error log_msg]. (Headers.t * handshake_state, string) result
The connection is closed without further ado. *) (** Perform the handshake and upgrade the connection. This returns either
[Ok (resp_headers, state)] in case of success, in which case the
server sends a [101] response with [resp_headers];
or it returns [Error log_msg] if the the handshake fails, in which case
the connection is closed without further ado and [log_msg] is logged
locally (but not returned to the client). *)
val handle_connection : val handle_connection : handshake_state -> IO.Input.t -> IO.Output.t -> unit
Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit
(** Take control of the connection and take it from ther.e *) (** Take control of the connection and take it from ther.e *)
end end

View file

@ -1,6 +1,6 @@
open Common_ws_ open Common_ws_
type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit
module Frame_type = struct module Frame_type = struct
type t = int type t = int
@ -407,8 +407,9 @@ let upgrade ic oc : _ * _ =
module Make_upgrade_handler (X : sig module Make_upgrade_handler (X : sig
val accept_ws_protocol : string -> bool val accept_ws_protocol : string -> bool
val handler : handler val handler : handler
end) : Server.UPGRADE_HANDLER = struct end) : Server.UPGRADE_HANDLER with type handshake_state = unit Request.t =
type handshake_state = unit struct
type handshake_state = unit Request.t
let name = "websocket" let name = "websocket"
@ -443,14 +444,14 @@ end) : Server.UPGRADE_HANDLER = struct
let headers = [ "sec-websocket-accept", accept ] in let headers = [ "sec-websocket-accept", accept ] in
Log.debug (fun k -> Log.debug (fun k ->
k "websocket: upgrade successful, accept key is %S" accept); k "websocket: upgrade successful, accept key is %S" accept);
headers, () headers, req
let handshake req : _ result = let handshake _addr req : _ result =
try Ok (handshake_ req) with Bad_req s -> Error s try Ok (handshake_ req) with Bad_req s -> Error s
let handle_connection addr () ic oc = let handle_connection req ic oc =
let ws_ic, ws_oc = upgrade ic oc in let ws_ic, ws_oc = upgrade ic oc in
try X.handler addr ws_ic ws_oc try X.handler req ws_ic ws_oc
with Close_connection -> with Close_connection ->
Log.debug (fun k -> k "websocket: requested to close the connection"); Log.debug (fun k -> k "websocket: requested to close the connection");
() ()

View file

@ -4,7 +4,7 @@
for a websocket server. It has no additional dependencies. for a websocket server. It has no additional dependencies.
*) *)
type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit
(** Websocket handler *) (** Websocket handler *)
val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t