mirror of
https://github.com/c-cube/tiny_httpd.git
synced 2025-12-05 19:00:32 -05:00
feat: first draft of the websocket library
This commit is contained in:
parent
d3a4dbc5b0
commit
e1f2edb0ab
3 changed files with 483 additions and 1 deletions
|
|
@ -4,4 +4,4 @@
|
|||
(public_name tiny_httpd_ws)
|
||||
(synopsis "Websockets for tiny_httpd")
|
||||
(private_modules common_)
|
||||
(libraries tiny_httpd base64 cryptokit))
|
||||
(libraries tiny_httpd base64 cryptokit threads))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,463 @@
|
|||
open Common_
|
||||
open Tiny_httpd_server
|
||||
module Log = Tiny_httpd_log
|
||||
module IO = Tiny_httpd_io
|
||||
|
||||
let spf = Printf.sprintf
|
||||
let ( let@ ) = ( @@ )
|
||||
|
||||
type handler = IO.Input.t -> IO.Output.t -> unit
|
||||
|
||||
module Frame_type = struct
|
||||
type t = int
|
||||
|
||||
let continuation : t = 0
|
||||
let text : t = 1
|
||||
let binary : t = 2
|
||||
let close : t = 8
|
||||
let ping : t = 9
|
||||
let pong : t = 10
|
||||
|
||||
let show = function
|
||||
| 0 -> "continuation"
|
||||
| 1 -> "text"
|
||||
| 2 -> "binary"
|
||||
| 8 -> "close"
|
||||
| 9 -> "ping"
|
||||
| 10 -> "pong"
|
||||
| _ty -> spf "unknown frame type %xd" _ty
|
||||
end
|
||||
|
||||
module Header = struct
|
||||
type t = {
|
||||
mutable fin: bool;
|
||||
mutable ty: Frame_type.t;
|
||||
mutable payload_len: int;
|
||||
mutable mask: bool;
|
||||
mutable mask_key: bytes; (** len = 4 *)
|
||||
}
|
||||
|
||||
let create () : t =
|
||||
{
|
||||
fin = false;
|
||||
ty = 0;
|
||||
payload_len = 0;
|
||||
mask = false;
|
||||
mask_key = Bytes.create 4;
|
||||
}
|
||||
end
|
||||
|
||||
exception Close_connection
|
||||
(** Raised to close the connection. *)
|
||||
|
||||
module Writer = struct
|
||||
type t = {
|
||||
header: Header.t;
|
||||
header_buf: bytes;
|
||||
buf: bytes; (** bufferize writes *)
|
||||
mutable offset: int; (** number of bytes already in [buf] *)
|
||||
oc: IO.Output.t;
|
||||
mutable closed: bool;
|
||||
mutex: Mutex.t;
|
||||
}
|
||||
|
||||
let create ?(buf_size = 16 * 1024) ~oc () : t =
|
||||
{
|
||||
header = Header.create ();
|
||||
header_buf = Bytes.create 16;
|
||||
buf = Bytes.create buf_size;
|
||||
offset = 0;
|
||||
oc;
|
||||
closed = false;
|
||||
mutex = Mutex.create ();
|
||||
}
|
||||
|
||||
let[@inline] with_mutex_ (self : t) f =
|
||||
Mutex.lock self.mutex;
|
||||
try
|
||||
let x = f () in
|
||||
Mutex.unlock self.mutex;
|
||||
x
|
||||
with e ->
|
||||
Mutex.unlock self.mutex;
|
||||
raise e
|
||||
|
||||
let close self =
|
||||
if not self.closed then (
|
||||
self.closed <- true;
|
||||
raise Close_connection
|
||||
)
|
||||
|
||||
let int_of_bool : bool -> int = Obj.magic
|
||||
|
||||
(** Write the frame header to [self.oc] *)
|
||||
let write_header_ (self : t) : unit =
|
||||
let header_len = ref 2 in
|
||||
let b0 =
|
||||
Char.chr ((int_of_bool self.header.fin lsl 7) lor self.header.ty)
|
||||
in
|
||||
Bytes.unsafe_set self.header_buf 0 b0;
|
||||
|
||||
(* we don't mask *)
|
||||
let payload_len = self.header.payload_len in
|
||||
let payload_first_byte =
|
||||
if payload_len < 126 then
|
||||
payload_len
|
||||
else if payload_len < 1 lsl 16 then (
|
||||
Bytes.set_int16_be self.header_buf 2 payload_len;
|
||||
header_len := 4;
|
||||
126
|
||||
) else (
|
||||
Bytes.set_int64_be self.header_buf 2 (Int64.of_int payload_len);
|
||||
header_len := 10;
|
||||
127
|
||||
)
|
||||
in
|
||||
|
||||
let b1 =
|
||||
Char.chr @@ ((int_of_bool self.header.mask lsl 7) lor payload_first_byte)
|
||||
in
|
||||
Bytes.unsafe_set self.header_buf 1 b1;
|
||||
|
||||
if self.header.mask then (
|
||||
Bytes.blit self.header_buf !header_len self.header.mask_key 0 4;
|
||||
header_len := !header_len + 4
|
||||
);
|
||||
|
||||
(*Log.debug (fun k ->
|
||||
k "websocket: write header ty=%s (%d B)"
|
||||
(Frame_type.show self.header.ty)
|
||||
!header_len);*)
|
||||
IO.Output.output self.oc self.header_buf 0 !header_len;
|
||||
()
|
||||
|
||||
(** Max fragment size: send 16 kB at a time *)
|
||||
let max_fragment_size = 16 * 1024
|
||||
|
||||
let[@inline never] really_output_buf_ (self : t) =
|
||||
self.header.fin <- true;
|
||||
self.header.ty <- Frame_type.binary;
|
||||
self.header.payload_len <- self.offset;
|
||||
self.header.mask <- false;
|
||||
write_header_ self;
|
||||
|
||||
IO.Output.output self.oc self.buf 0 self.offset;
|
||||
self.offset <- 0
|
||||
|
||||
let flush_ (self : t) =
|
||||
if self.closed then raise Close_connection;
|
||||
if self.offset > 0 then really_output_buf_ self
|
||||
|
||||
let[@inline] flush_if_full (self : t) : unit =
|
||||
if self.offset = Bytes.length self.buf then really_output_buf_ self
|
||||
|
||||
let send_pong (self : t) : unit =
|
||||
let@ () = with_mutex_ self in
|
||||
self.header.fin <- true;
|
||||
self.header.ty <- Frame_type.pong;
|
||||
self.header.payload_len <- 0;
|
||||
self.header.mask <- false;
|
||||
(* only write a header, we don't send a payload at all *)
|
||||
write_header_ self
|
||||
|
||||
let output_char (self : t) c : unit =
|
||||
let@ () = with_mutex_ self in
|
||||
let cap = Bytes.length self.buf - self.offset in
|
||||
(* make room for [c] *)
|
||||
if cap = 0 then really_output_buf_ self;
|
||||
Bytes.set self.buf self.offset c;
|
||||
self.offset <- self.offset + 1;
|
||||
(* if [c] made the buffer full, then flush it *)
|
||||
if cap = 1 then really_output_buf_ self
|
||||
|
||||
let output (self : t) buf i len : unit =
|
||||
let@ () = with_mutex_ self in
|
||||
let i = ref i in
|
||||
let len = ref len in
|
||||
while !len > 0 do
|
||||
flush_if_full self;
|
||||
|
||||
let n = min !len (Bytes.length self.buf - self.offset) in
|
||||
assert (n > 0);
|
||||
|
||||
Bytes.blit buf !i self.buf self.offset n;
|
||||
self.offset <- self.offset + n;
|
||||
|
||||
i := !i + n;
|
||||
len := !len - n
|
||||
done;
|
||||
flush_if_full self
|
||||
|
||||
let flush self : unit =
|
||||
let@ () = with_mutex_ self in
|
||||
flush_ self
|
||||
end
|
||||
|
||||
module Reader = struct
|
||||
type state =
|
||||
| Begin (** At the beginning of a frame *)
|
||||
| Reading_frame of { mutable remaining_bytes: int }
|
||||
(** Currently reading the payload of a frame with [remaining_bytes] left to read *)
|
||||
| Close
|
||||
|
||||
type t = {
|
||||
ic: IO.Input.t;
|
||||
writer: Writer.t; (** Writer, to send "pong" *)
|
||||
header_buf: bytes; (** small buffer to read frame headers *)
|
||||
small_buf: bytes; (** Used for control frames *)
|
||||
header: Header.t;
|
||||
last_ty: Frame_type.t; (** Last frame's type, used for continuation *)
|
||||
mutable state: state;
|
||||
}
|
||||
|
||||
let create ~ic ~(writer : Writer.t) () : t =
|
||||
{
|
||||
ic;
|
||||
header_buf = Bytes.create 8;
|
||||
small_buf = Bytes.create 128;
|
||||
writer;
|
||||
state = Begin;
|
||||
last_ty = 0;
|
||||
header = Header.create ();
|
||||
}
|
||||
|
||||
(** limitation: we only accept frames that are 2^30 bytes long or less *)
|
||||
let max_fragment_size = 1 lsl 30
|
||||
|
||||
(** Read next frame header into [self.header] *)
|
||||
let read_frame_header (self : t) : unit =
|
||||
(* read header *)
|
||||
IO.Input.really_input self.ic self.header_buf 0 2;
|
||||
|
||||
let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in
|
||||
let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in
|
||||
|
||||
self.header.fin <- b0 land 1 == 1;
|
||||
let ext = (b0 lsr 4) land 0b0111 in
|
||||
if ext <> 0 then (
|
||||
Log.error (fun k -> k "websocket: unknown extension %d, closing" ext);
|
||||
raise Close_connection
|
||||
);
|
||||
|
||||
self.header.ty <- b0 land 0b0000_1111;
|
||||
self.header.mask <- b1 land 0b1000_0000 != 0;
|
||||
|
||||
let payload_len : int =
|
||||
let len = b1 land 0b0111_1111 in
|
||||
if len = 126 then (
|
||||
IO.Input.really_input self.ic self.header_buf 0 2;
|
||||
Bytes.get_int16_be self.header_buf 0
|
||||
) else if len = 127 then (
|
||||
IO.Input.really_input self.ic self.header_buf 0 8;
|
||||
let len64 = Bytes.get_int64_be self.header_buf 0 in
|
||||
if compare len64 (Int64.of_int max_fragment_size) > 0 then (
|
||||
Log.error (fun k ->
|
||||
k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64
|
||||
max_fragment_size);
|
||||
raise Close_connection
|
||||
);
|
||||
|
||||
Int64.to_int len64
|
||||
) else
|
||||
len
|
||||
in
|
||||
self.header.payload_len <- payload_len;
|
||||
|
||||
if self.header.mask then
|
||||
IO.Input.really_input self.ic self.header.mask_key 0 4;
|
||||
|
||||
(*Log.debug (fun k ->
|
||||
k "websocket: read frame header type=%s payload_len=%d mask=%b"
|
||||
(Frame_type.show self.header.ty)
|
||||
self.header.payload_len self.header.mask);*)
|
||||
()
|
||||
|
||||
(** Apply masking to the parsed data *)
|
||||
let apply_masking ~mask_key (buf : bytes) off len : unit =
|
||||
for i = 0 to len - 1 do
|
||||
let c = Bytes.get buf (off + i) in
|
||||
let c_m = Bytes.unsafe_get mask_key (i land 0b11) in
|
||||
let c_xor = Char.chr (Char.code c lxor Char.code c_m) in
|
||||
Bytes.set buf (off + i) c_xor
|
||||
done
|
||||
|
||||
let read_body_to_string (self : t) : string =
|
||||
let len = self.header.payload_len in
|
||||
let buf = Bytes.create len in
|
||||
IO.Input.really_input self.ic buf 0 len;
|
||||
if self.header.mask then
|
||||
apply_masking ~mask_key:self.header.mask_key buf 0 len;
|
||||
Bytes.unsafe_to_string buf
|
||||
|
||||
(** Skip bytes of the body *)
|
||||
let skip_body (self : t) : unit =
|
||||
let len = ref self.header.payload_len in
|
||||
while !len > 0 do
|
||||
let n = min !len (Bytes.length self.small_buf) in
|
||||
IO.Input.really_input self.ic self.small_buf 0 n;
|
||||
len := !len - n
|
||||
done
|
||||
|
||||
(** State machine that reads [len] bytes into [buf] *)
|
||||
let rec read_rec (self : t) buf i len : int =
|
||||
match self.state with
|
||||
| Close -> 0
|
||||
| Reading_frame r ->
|
||||
let len = min len r.remaining_bytes in
|
||||
let n = IO.Input.input self.ic buf i len in
|
||||
|
||||
(* update state *)
|
||||
r.remaining_bytes <- r.remaining_bytes - n;
|
||||
if r.remaining_bytes = 0 then self.state <- Begin;
|
||||
|
||||
if self.header.mask then
|
||||
apply_masking ~mask_key:self.header.mask_key buf i n
|
||||
else (
|
||||
Log.error (fun k -> k "websocket: client's frames must be masked");
|
||||
raise Close_connection
|
||||
);
|
||||
n
|
||||
| Begin ->
|
||||
read_frame_header self;
|
||||
(*Log.debug (fun k ->
|
||||
k "websocket: read frame of type=%s payload_len=%d"
|
||||
(Frame_type.show self.header.ty)
|
||||
self.header.payload_len);*)
|
||||
(match self.header.ty with
|
||||
| 0 ->
|
||||
(* continuation *)
|
||||
if self.last_ty = 1 || self.last_ty = 2 then
|
||||
self.state <-
|
||||
Reading_frame { remaining_bytes = self.header.payload_len }
|
||||
else (
|
||||
Log.error (fun k ->
|
||||
k "continuation frame coming after frame of type %s"
|
||||
(Frame_type.show self.last_ty));
|
||||
raise Close_connection
|
||||
);
|
||||
read_rec self buf i len
|
||||
| 1 ->
|
||||
self.state <-
|
||||
Reading_frame { remaining_bytes = self.header.payload_len };
|
||||
read_rec self buf i len
|
||||
| 2 ->
|
||||
self.state <-
|
||||
Reading_frame { remaining_bytes = self.header.payload_len };
|
||||
read_rec self buf i len
|
||||
| 8 ->
|
||||
(* close frame *)
|
||||
self.state <- Close;
|
||||
let body = read_body_to_string self in
|
||||
if String.length body >= 2 then (
|
||||
let errcode = String.get_int16_be body 0 in
|
||||
Log.info (fun k ->
|
||||
k "client send 'close' with errcode=%d, message=%S" errcode
|
||||
(String.sub body 2 (String.length body - 2)))
|
||||
);
|
||||
0
|
||||
| 9 ->
|
||||
(* pong, just ignore *)
|
||||
skip_body self;
|
||||
Writer.send_pong self.writer;
|
||||
read_rec self buf i len
|
||||
| 10 ->
|
||||
(* pong, just ignore *)
|
||||
skip_body self;
|
||||
read_rec self buf i len
|
||||
| ty ->
|
||||
Log.error (fun k -> k "unknown frame type: %xd" ty);
|
||||
raise Close_connection)
|
||||
|
||||
let read self buf i len =
|
||||
try read_rec self buf i len
|
||||
with Close_connection ->
|
||||
self.state <- Close;
|
||||
0
|
||||
|
||||
let close self : unit =
|
||||
if self.state != Close then (
|
||||
Log.debug (fun k -> k "websocket: close connection from server side");
|
||||
self.state <- Close
|
||||
)
|
||||
end
|
||||
|
||||
(** Turn a regular connection handler (provided by the user) into a websocket upgrade handler *)
|
||||
module Make_upgrade_handler (X : sig
|
||||
val accept_ws_protocol : string -> bool
|
||||
val handler : handler
|
||||
end) : UPGRADE_HANDLER = struct
|
||||
type handshake_state = unit
|
||||
|
||||
let name = "websocket"
|
||||
|
||||
open struct
|
||||
exception Bad_req of string
|
||||
|
||||
let bad_req msg = raise (Bad_req msg)
|
||||
let bad_reqf fmt = Printf.ksprintf bad_req fmt
|
||||
end
|
||||
|
||||
let handshake_ (req : unit Request.t) =
|
||||
(match Request.get_header req "sec-websocket-protocol" with
|
||||
| None -> ()
|
||||
| Some proto when not (X.accept_ws_protocol proto) ->
|
||||
bad_reqf "handler rejected websocket protocol %S" proto
|
||||
| Some _proto -> ());
|
||||
let key =
|
||||
match Request.get_header req "sec-websocket-key" with
|
||||
| None -> bad_req "need sec-websocket-key"
|
||||
| Some k -> k
|
||||
in
|
||||
|
||||
(* TODO: "origin" header *)
|
||||
|
||||
(* produce the accept key *)
|
||||
let accept =
|
||||
(* yes, SHA1 is broken. It's also part of the spec for websockets. *)
|
||||
let hash = (Cryptokit.Hash.sha1 () [@ocaml.alert "-crypto"]) in
|
||||
Cryptokit.hash_string hash (key ^ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|> Base64.encode_exn
|
||||
in
|
||||
|
||||
let headers = [ "sec-websocket-accept", accept ] in
|
||||
Log.debug (fun k ->
|
||||
k "websocket: upgrade successful, accept key is %S" accept);
|
||||
headers, ()
|
||||
|
||||
let handshake req : _ result =
|
||||
try Ok (handshake_ req) with Bad_req s -> Error s
|
||||
|
||||
let handle_connection () ic oc =
|
||||
let writer = Writer.create ~oc () in
|
||||
let reader = Reader.create ~ic ~writer () in
|
||||
let ws_ic : IO.Input.t =
|
||||
{
|
||||
input = (fun buf i len -> Reader.read reader buf i len);
|
||||
close = (fun () -> Reader.close reader);
|
||||
}
|
||||
in
|
||||
let ws_oc : IO.Output.t =
|
||||
{
|
||||
flush =
|
||||
(fun () ->
|
||||
Writer.flush writer;
|
||||
IO.Output.flush oc);
|
||||
output_char = Writer.output_char writer;
|
||||
output = Writer.output writer;
|
||||
close = (fun () -> Writer.close writer);
|
||||
}
|
||||
in
|
||||
try X.handler ws_ic ws_oc
|
||||
with Close_connection ->
|
||||
Log.debug (fun k -> k "websocket: requested to close the connection");
|
||||
()
|
||||
end
|
||||
|
||||
let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true)
|
||||
(server : Tiny_httpd_server.t) route (f : handler) : unit =
|
||||
let module M = Make_upgrade_handler (struct
|
||||
let handler = f
|
||||
let accept_ws_protocol = accept_ws_protocol
|
||||
end) in
|
||||
let up : upgrade_handler = (module M) in
|
||||
Tiny_httpd_server.add_upgrade_handler ?accept server route up
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
open Common_
|
||||
open Tiny_httpd_server
|
||||
module IO = Tiny_httpd_io
|
||||
|
||||
(* FIXME: also pass client address to the handler *)
|
||||
|
||||
type handler = IO.Input.t -> IO.Output.t -> unit
|
||||
(** Websocket handler *)
|
||||
|
||||
val add_route_handler :
|
||||
?accept:(unit Request.t -> (unit, int * string) result) ->
|
||||
?accept_ws_protocol:(string -> bool) ->
|
||||
Tiny_httpd_server.t ->
|
||||
(upgrade_handler, upgrade_handler) Route.t ->
|
||||
handler ->
|
||||
unit
|
||||
(** Add a route handler for a websocket endpoint.
|
||||
@param accept_ws_protocol decides whether this endpoint accepts the websocket protocol
|
||||
sent by the client. Default accepts everything. *)
|
||||
Loading…
Add table
Reference in a new issue