Merge pull request #78 from c-cube/wip-ws

add a websocket library
This commit is contained in:
Simon Cruanes 2024-02-07 15:28:34 -05:00 committed by GitHub
commit 89e3fb91dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1059 additions and 103 deletions

View file

@ -3,7 +3,7 @@ name: github pages
on: on:
push: push:
branches: branches:
- master - main
jobs: jobs:
deploy: deploy:

View file

@ -3,9 +3,8 @@ name: build
on: on:
pull_request: pull_request:
push: push:
schedule: branches:
# Prime the caches every Monday - main
- cron: 0 1 * * MON
jobs: jobs:
build: build:
@ -32,9 +31,6 @@ jobs:
with: with:
ocaml-compiler: ${{ matrix.ocaml-compiler }} ocaml-compiler: ${{ matrix.ocaml-compiler }}
allow-prerelease-opam: true allow-prerelease-opam: true
opam-local-packages: |
./tiny_httpd.opam
./tiny_httpd_camlzip.opam
opam-depext-flags: --with-test opam-depext-flags: --with-test
- run: opam install ./tiny_httpd.opam ./tiny_httpd_camlzip.opam --deps-only --with-test - run: opam install ./tiny_httpd.opam ./tiny_httpd_camlzip.opam --deps-only --with-test

View file

@ -1,49 +0,0 @@
name: build (ocaml 5)
on:
pull_request:
push:
schedule:
# Prime the caches every Monday
- cron: 0 1 * * MON
jobs:
build:
strategy:
fail-fast: true
matrix:
os:
- ubuntu-latest
#- macos-latest
#- windows-latest
ocaml-compiler:
- 5.1.x
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Use OCaml ${{ matrix.ocaml-compiler }}
uses: ocaml/setup-ocaml@v2
with:
ocaml-compiler: ${{ matrix.ocaml-compiler }}
opam-depext-flags: --with-test
allow-prerelease-opam: true
- run: opam install . --deps-only --with-test
- run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip
- run: opam exec -- dune build @src/runtest @examples/runtest @tests/runtest -p tiny_httpd
if: ${{ matrix.os == 'ubuntu-latest' }}
- run: opam install tiny_httpd
- run: opam exec -- dune build @src/runtest @examples/runtest @tests/runtest -p tiny_httpd_camlzip
if: ${{ matrix.os == 'ubuntu-latest' }}
- run: opam install logs -y
- run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip

2
echo_ws.sh Executable file
View file

@ -0,0 +1,2 @@
#!/bin/sh
exec dune exec --display=quiet --profile=release "examples/echo_ws.exe" -- $@

View file

@ -20,6 +20,12 @@
(modules writer) (modules writer)
(libraries tiny_httpd logs)) (libraries tiny_httpd logs))
(executable
(name echo_ws)
(flags :standard -warn-error -a+8)
(modules echo_ws)
(libraries tiny_httpd tiny_httpd.ws logs))
(rule (rule
(targets test_output.txt) (targets test_output.txt)
(deps (deps

67
examples/echo_ws.ml Normal file
View file

@ -0,0 +1,67 @@
module S = Tiny_httpd
module Log = Tiny_httpd.Log
module IO = Tiny_httpd_io
let setup_logging ~debug () =
Logs.set_reporter @@ Logs.format_reporter ();
Logs.set_level ~all:true
@@ Some
(if debug then
Logs.Debug
else
Logs.Info)
let handle_ws _client_addr ic oc =
Log.info (fun k ->
k "new client connection from %s"
(Tiny_httpd_util.show_sockaddr _client_addr));
let (_ : Thread.t) =
Thread.create
(fun () ->
while true do
Thread.delay 3.;
IO.Output.output_string oc "(special ping!)";
IO.Output.flush oc
done)
()
in
let buf = Bytes.create 32 in
let continue = ref true in
while !continue do
let n = IO.Input.input ic buf 0 (Bytes.length buf) in
Log.debug (fun k ->
k "echo %d bytes from websocket: %S" n (Bytes.sub_string buf 0 n));
if n = 0 then continue := false;
IO.Output.output oc buf 0 n;
IO.Output.flush oc
done;
Log.info (fun k -> k "client exiting")
let () =
let port_ = ref 8080 in
let j = ref 32 in
let debug = ref false in
Arg.parse
(Arg.align
[
"--port", Arg.Set_int port_, " set port";
"-p", Arg.Set_int port_, " set port";
"--debug", Arg.Set debug, " enable debug";
"-j", Arg.Set_int j, " maximum number of connections";
])
(fun _ -> raise (Arg.Bad ""))
"echo [option]*";
setup_logging ~debug:!debug ();
let server = S.create ~port:!port_ ~max_connections:!j () in
Tiny_httpd_ws.add_route_handler server
S.Route.(exact "echo" @/ return)
handle_ws;
Printf.printf "listening on http://%s:%d\n%!" (S.addr server) (S.port server);
match S.run server with
| Ok () -> ()
| Error e -> raise e

View file

@ -45,12 +45,60 @@ module Input = struct
Unix.close fd); Unix.close fd);
} }
let of_slice (i_bs : bytes) (i_off : int) (i_len : int) : t =
let i_off = ref i_off in
let i_len = ref i_len in
{
input =
(fun buf i len ->
let n = min len !i_len in
Bytes.blit i_bs !i_off buf i n;
i_off := !i_off + n;
i_len := !i_len - n;
n);
close = ignore;
}
(** Read into the given slice. (** Read into the given slice.
@return the number of bytes read, [0] means end of input. *) @return the number of bytes read, [0] means end of input. *)
let[@inline] input (self : t) buf i len = self.input buf i len let[@inline] input (self : t) buf i len = self.input buf i len
(** Read exactly [len] bytes.
@raise End_of_file if the input did not contain enough data. *)
let really_input (self : t) buf i len : unit =
let i = ref i in
let len = ref len in
while !len > 0 do
let n = input self buf !i !len in
if n = 0 then raise End_of_file;
i := !i + n;
len := !len - n
done
(** Close the channel. *) (** Close the channel. *)
let[@inline] close self : unit = self.close () let[@inline] close self : unit = self.close ()
let append (i1 : t) (i2 : t) : t =
let use_i1 = ref true in
let rec input buf i len : int =
if !use_i1 then (
let n = i1.input buf i len in
if n = 0 then (
use_i1 := false;
input buf i len
) else
n
) else
i2.input buf i len
in
{
input;
close =
(fun () ->
close i1;
close i2);
}
end end
(** Output channel (byte sink) *) (** Output channel (byte sink) *)

View file

@ -46,7 +46,8 @@ module Response_code = struct
let[@inline] is_success n = n >= 200 && n < 400 let[@inline] is_success n = n >= 200 && n < 400
end end
type 'a resp_result = ('a, Response_code.t * string) result type resp_error = Response_code.t * string
type 'a resp_result = ('a, resp_error) result
let unwrap_resp_result = function let unwrap_resp_result = function
| Ok x -> x | Ok x -> x
@ -633,6 +634,27 @@ end
type server_sent_generator = (module SERVER_SENT_GENERATOR) type server_sent_generator = (module SERVER_SENT_GENERATOR)
(** Handler that upgrades to another protocol *)
module type UPGRADE_HANDLER = sig
type handshake_state
(** Some specific state returned after handshake *)
val name : string
(** Name in the "upgrade" header *)
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result
(** Perform the handshake and upgrade the connection. The returned
code is [101] alongside these headers. *)
val handle_connection :
Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit
(** Take control of the connection and take it from there *)
end
type upgrade_handler = (module UPGRADE_HANDLER)
exception Upgrade of unit Request.t * upgrade_handler
module type IO_BACKEND = sig module type IO_BACKEND = sig
val init_addr : unit -> string val init_addr : unit -> string
val init_port : unit -> int val init_port : unit -> int
@ -644,6 +666,16 @@ module type IO_BACKEND = sig
(** Server that can listen on a port and handle clients. *) (** Server that can listen on a port and handle clients. *)
end end
type handler_result =
| Handle of cb_path_handler
| Fail of resp_error
| Upgrade of upgrade_handler
let unwrap_handler_result req = function
| Handle x -> x
| Fail (c, s) -> raise (Bad_req (c, s))
| Upgrade up -> raise (Upgrade (req, up))
type t = { type t = {
backend: (module IO_BACKEND); backend: (module IO_BACKEND);
mutable tcp_server: IO.TCP_server.t option; mutable tcp_server: IO.TCP_server.t option;
@ -653,8 +685,7 @@ type t = {
mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) mutable middlewares: (int * Middleware.t) list; (** Global middlewares *)
mutable middlewares_sorted: (int * Middleware.t) list lazy_t; mutable middlewares_sorted: (int * Middleware.t) list lazy_t;
(** sorted version of {!middlewares} *) (** sorted version of {!middlewares} *)
mutable path_handlers: mutable path_handlers: (unit Request.t -> handler_result option) list;
(unit Request.t -> cb_path_handler resp_result option) list;
(** path handlers *) (** path handlers *)
buf_pool: Buf.t Pool.t; buf_pool: Buf.t Pool.t;
} }
@ -726,7 +757,7 @@ let set_top_handler self f = self.handler <- f
and makes it into a handler. *) and makes it into a handler. *)
let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
~tr_req self (route : _ Route.t) f = ~tr_req self (route : _ Route.t) f =
let ph req : cb_path_handler resp_result option = let ph req : handler_result option =
match meth with match meth with
| Some m when m <> req.Request.meth -> None (* ignore *) | Some m when m <> req.Request.meth -> None (* ignore *)
| _ -> | _ ->
@ -736,11 +767,11 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth
(match accept req with (match accept req with
| Ok () -> | Ok () ->
Some Some
(Ok (Handle
(fun oc -> (fun oc ->
Middleware.apply_l middlewares @@ fun req ~resp -> Middleware.apply_l middlewares @@ fun req ~resp ->
tr_req oc req ~resp handler)) tr_req oc req ~resp handler))
| Error _ as e -> Some e) | Error err -> Some (Fail err))
| None -> None (* path didn't match *)) | None -> None (* path didn't match *))
in in
self.path_handlers <- ph :: self.path_handlers self.path_handlers <- ph :: self.path_handlers
@ -821,6 +852,22 @@ let add_route_server_sent_handler ?accept self route f =
in in
add_route_handler_ self ?accept ~meth:`GET route ~tr_req f add_route_handler_ self ?accept ~meth:`GET route ~tr_req f
let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = [])
(self : t) route f : unit =
let ph req : handler_result option =
if req.Request.meth <> `GET then
None
else (
match accept req with
| Ok () ->
(match Route.eval req.Request.path_components route f with
| Some up -> Some (Upgrade up)
| None -> None (* path didn't match *))
| Error err -> Some (Fail err)
)
in
self.path_handlers <- ph :: self.path_handlers
let get_max_connection_ ?(max_connections = 64) () : int = let get_max_connection_ ?(max_connections = 64) () : int =
let max_connections = max 4 max_connections in let max_connections = max 4 max_connections in
max_connections max_connections
@ -847,11 +894,6 @@ let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t =
let is_ipv6_str addr : bool = String.contains addr ':' let is_ipv6_str addr : bool = String.contains addr ':'
let str_of_sockaddr = function
| Unix.ADDR_UNIX f -> f
| Unix.ADDR_INET (inet, port) ->
Printf.sprintf "%s:%d" (Unix.string_of_inet_addr inet) port
module Unix_tcp_server_ = struct module Unix_tcp_server_ = struct
type t = { type t = {
addr: string; addr: string;
@ -918,7 +960,8 @@ module Unix_tcp_server_ = struct
let handle_client_unix_ (client_sock : Unix.file_descr) let handle_client_unix_ (client_sock : Unix.file_descr)
(client_addr : Unix.sockaddr) : unit = (client_addr : Unix.sockaddr) : unit =
Log.info (fun k -> Log.info (fun k ->
k "serving new client on %s" (str_of_sockaddr client_addr)); k "serving new client on %s"
(Tiny_httpd_util.show_sockaddr client_addr));
Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout);
Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout);
let oc = let oc =
@ -928,12 +971,14 @@ module Unix_tcp_server_ = struct
handle.handle ~client_addr ic oc; handle.handle ~client_addr ic oc;
Log.info (fun k -> Log.info (fun k ->
k "done with client on %s, exiting" k "done with client on %s, exiting"
@@ str_of_sockaddr client_addr); @@ Tiny_httpd_util.show_sockaddr client_addr);
(try Unix.close client_sock (try
Unix.shutdown client_sock Unix.SHUTDOWN_ALL;
Unix.close client_sock
with e -> with e ->
Log.error (fun k -> Log.error (fun k ->
k "error when closing sock for client %s: %s" k "error when closing sock for client %s: %s"
(str_of_sockaddr client_addr) (Tiny_httpd_util.show_sockaddr client_addr)
(Printexc.to_string e))); (Printexc.to_string e)));
() ()
in in
@ -962,7 +1007,7 @@ module Unix_tcp_server_ = struct
k k
"@[<v>Handler: uncaught exception for client %s:@ \ "@[<v>Handler: uncaught exception for client %s:@ \
%s@ %s@]" %s@ %s@]"
(str_of_sockaddr client_addr) (Tiny_httpd_util.show_sockaddr client_addr)
(Printexc.to_string e) (Printexc.to_string e)
(Printexc.raw_backtrace_to_string bt))); (Printexc.raw_backtrace_to_string bt)));
ignore Unix.(sigprocmask SIG_UNBLOCK Sys.[ sigint; sighup ]) ignore Unix.(sigprocmask SIG_UNBLOCK Sys.[ sigint; sighup ])
@ -1030,15 +1075,103 @@ let find_map f l =
in in
aux f l aux f l
let string_as_list_contains_ (s : string) (sub : string) : bool =
let fragments = String.split_on_char ',' s in
List.exists (fun fragment -> String.trim fragment = sub) fragments
(* handle client on [ic] and [oc] *) (* handle client on [ic] and [oc] *)
let client_handle_for (self : t) ~client_addr ic oc : unit = let client_handle_for (self : t) ~client_addr ic oc : unit =
Pool.with_resource self.buf_pool @@ fun buf -> Pool.with_resource self.buf_pool @@ fun buf ->
Pool.with_resource self.buf_pool @@ fun buf_res -> Pool.with_resource self.buf_pool @@ fun buf_res ->
let is = Byte_stream.of_input ~buf_size:self.buf_size ic in let is = Byte_stream.of_input ~buf_size:self.buf_size ic in
let (module B) = self.backend in
(* how to log the response to this query *)
let log_response (req : _ Request.t) (resp : Response.t) =
if not Log.dummy then (
let msgf k =
let elapsed = B.get_time_s () -. req.start_time in
k
("response to=%s code=%d time=%.3fs path=%S" : _ format4)
(Tiny_httpd_util.show_sockaddr client_addr)
resp.code elapsed req.path
in
if Response_code.is_success resp.code then
Log.info msgf
else
Log.error msgf
)
in
(* handle generic exception *)
let handle_exn e =
let resp =
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
in
if not Log.dummy then
Log.error (fun k ->
k "response to %s code=%d"
(Tiny_httpd_util.show_sockaddr client_addr)
resp.code);
Response.output_ ~buf:buf_res oc resp
in
let handle_bad_req req e =
let resp =
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
in
log_response req resp;
Response.output_ ~buf:buf_res oc resp
in
let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit =
Log.debug (fun k -> k "upgrade connection");
try
(* check headers *)
(match Request.get_header req "connection" with
| Some str when string_as_list_contains_ str "Upgrade" -> ()
| _ -> bad_reqf 426 "connection header must contain 'Upgrade'");
(match Request.get_header req "upgrade" with
| Some u when u = UP.name -> ()
| Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u
| None -> bad_reqf 426 "expected 'connection: upgrade' header");
(* ok, this is the upgrade we expected *)
match UP.handshake req with
| Error msg ->
(* fail the upgrade *)
Log.error (fun k -> k "upgrade failed: %s" msg);
let resp = Response.make_raw ~code:429 "upgrade required" in
log_response req resp;
Response.output_ ~buf:buf_res oc resp
| Ok (headers, handshake_st) ->
(* send the upgrade reply *)
let headers =
[ "connection", "upgrade"; "upgrade", UP.name ] @ headers
in
let resp = Response.make_string ~code:101 ~headers (Ok "") in
log_response req resp;
Response.output_ ~buf:buf_res oc resp;
(* now, give the whole connection over to the upgraded connection.
Make sure to give the leftovers from [is] as well, if any.
There might not be any because the first message doesn't normally come
directly in the same packet as the handshake, but still. *)
let ic =
if is.len > 0 then (
Log.debug (fun k -> k "LEFTOVERS! %d B" is.len);
IO.Input.append (IO.Input.of_slice is.bs is.off is.len) ic
) else
ic
in
UP.handle_connection client_addr handshake_st ic oc
with e -> handle_bad_req req e
in
let continue = ref true in let continue = ref true in
while !continue && running self do
Log.debug (fun k -> k "read next request"); let handle_one_req () =
let (module B) = self.backend in
match match
Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is
with with
@ -1054,28 +1187,11 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
if Request.close_after_req req then continue := false; if Request.close_after_req req then continue := false;
(* how to log the response to this query *)
let log_response (resp : Response.t) =
if not Log.dummy then (
let msgf k =
let elapsed = B.get_time_s () -. req.start_time in
k
("response to=%s code=%d time=%.3fs path=%S" : _ format4)
(str_of_sockaddr client_addr)
resp.code elapsed req.path
in
if Response_code.is_success resp.code then
Log.info msgf
else
Log.error msgf
)
in
(try (try
(* is there a handler for this path? *) (* is there a handler for this path? *)
let base_handler = let base_handler =
match find_map (fun ph -> ph req) self.path_handlers with match find_map (fun ph -> ph req) self.path_handlers with
| Some f -> unwrap_resp_result f | Some f -> unwrap_handler_result req f
| None -> fun _oc req ~resp -> resp (self.handler req) | None -> fun _oc req ~resp -> resp (self.handler req)
in in
@ -1108,7 +1224,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
try try
if Headers.get "connection" r.Response.headers = Some "close" then if Headers.get "connection" r.Response.headers = Some "close" then
continue := false; continue := false;
log_response r; log_response req r;
Response.output_ ~buf:buf_res oc r Response.output_ ~buf:buf_res oc r
with Sys_error _ -> continue := false with Sys_error _ -> continue := false
in in
@ -1123,16 +1239,23 @@ let client_handle_for (self : t) ~client_addr ic oc : unit =
| Bad_req (code, s) -> | Bad_req (code, s) ->
continue := false; continue := false;
let resp = Response.make_raw ~code s in let resp = Response.make_raw ~code s in
log_response resp; log_response req resp;
Response.output_ ~buf:buf_res oc resp Response.output_ ~buf:buf_res oc resp
| e -> | Upgrade _ as e -> raise e
continue := false; | e -> handle_bad_req req e)
let resp = in
Response.fail ~code:500 "server error: %s" (Printexc.to_string e)
in try
log_response resp; while !continue && running self do
Response.output_ ~buf:buf_res oc resp) Log.debug (fun k -> k "read next request");
done handle_one_req ()
done
with
| Upgrade (req, up) ->
(* upgrades take over the whole connection, we won't process
any further request *)
handle_upgrade req up
| e -> handle_exn e
let client_handler (self : t) : IO.TCP_server.conn_handler = let client_handler (self : t) : IO.TCP_server.conn_handler =
{ IO.TCP_server.handle = client_handle_for self } { IO.TCP_server.handle = client_handle_for self }

View file

@ -645,6 +645,46 @@ val add_route_server_sent_handler :
@since 0.9 *) @since 0.9 *)
(** {2 Upgrade handlers}
These handlers upgrade the connection to another protocol.
@since NEXT_RELEASE *)
(** Handler that upgrades to another protocol.
@since NEXT_RELEASE *)
module type UPGRADE_HANDLER = sig
type handshake_state
(** Some specific state returned after handshake *)
val name : string
(** Name in the "upgrade" header *)
val handshake : unit Request.t -> (Headers.t * handshake_state, string) result
(** Perform the handshake and upgrade the connection. The returned
code is [101] alongside these headers.
In case the handshake fails, this only returns [Error log_msg].
The connection is closed without further ado. *)
val handle_connection :
Unix.sockaddr ->
handshake_state ->
Tiny_httpd_io.Input.t ->
Tiny_httpd_io.Output.t ->
unit
(** Take control of the connection and take it from ther.e *)
end
type upgrade_handler = (module UPGRADE_HANDLER)
(** @since NEXT_RELEASE *)
val add_upgrade_handler :
?accept:(unit Request.t -> (unit, Response_code.t * string) result) ->
?middlewares:Middleware.t list ->
t ->
('a, upgrade_handler) Route.t ->
'a ->
unit
(** {2 Run the server} *) (** {2 Run the server} *)
val running : t -> bool val running : t -> bool

View file

@ -107,3 +107,8 @@ let parse_query s : (_ list, string) result =
| Invalid_argument _ | Not_found | Failure _ -> | Invalid_argument _ | Not_found | Failure _ ->
Error (Printf.sprintf "error in parse_query for %S: i=%d,j=%d" s !i !j) Error (Printf.sprintf "error in parse_query for %S: i=%d,j=%d" s !i !j)
| Invalid_query -> Error ("invalid query string: " ^ s) | Invalid_query -> Error ("invalid query string: " ^ s)
let show_sockaddr = function
| Unix.ADDR_UNIX f -> f
| Unix.ADDR_INET (inet, port) ->
Printf.sprintf "%s:%d" (Unix.string_of_inet_addr inet) port

View file

@ -34,3 +34,7 @@ val parse_query : string -> ((string * string) list, string) result
The order might not be preserved. The order might not be preserved.
@since 0.3 @since 0.3
*) *)
val show_sockaddr : Unix.sockaddr -> string
(** Simple printer for socket addresses.
@since NEXT_RELEASE *)

2
src/ws/common_.ml Normal file
View file

@ -0,0 +1,2 @@
let spf = Printf.sprintf
let ( let@ ) = ( @@ )

11
src/ws/dune Normal file
View file

@ -0,0 +1,11 @@
(library
(name tiny_httpd_ws)
(public_name tiny_httpd.ws)
(synopsis "Websockets for tiny_httpd")
(private_modules common_ utils_)
(foreign_stubs
(language c)
(names tiny_httpd_ws_stubs)
(flags :standard -std=c99 -fPIC -O2))
(libraries tiny_httpd threads))

463
src/ws/tiny_httpd_ws.ml Normal file
View file

@ -0,0 +1,463 @@
open Common_
open Tiny_httpd_server
module Log = Tiny_httpd_log
module IO = Tiny_httpd_io
type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit
module Frame_type = struct
type t = int
let continuation : t = 0
let text : t = 1
let binary : t = 2
let close : t = 8
let ping : t = 9
let pong : t = 10
let show = function
| 0 -> "continuation"
| 1 -> "text"
| 2 -> "binary"
| 8 -> "close"
| 9 -> "ping"
| 10 -> "pong"
| _ty -> spf "unknown frame type %xd" _ty
end
module Header = struct
type t = {
mutable fin: bool;
mutable ty: Frame_type.t;
mutable payload_len: int;
mutable mask: bool;
mask_key: bytes; (** len = 4 *)
}
let create () : t =
{
fin = false;
ty = 0;
payload_len = 0;
mask = false;
mask_key = Bytes.create 4;
}
end
exception Close_connection
(** Raised to close the connection. *)
module Writer = struct
type t = {
header: Header.t;
header_buf: bytes;
buf: bytes; (** bufferize writes *)
mutable offset: int; (** number of bytes already in [buf] *)
oc: IO.Output.t;
mutable closed: bool;
mutex: Mutex.t;
}
let create ?(buf_size = 16 * 1024) ~oc () : t =
{
header = Header.create ();
header_buf = Bytes.create 16;
buf = Bytes.create buf_size;
offset = 0;
oc;
closed = false;
mutex = Mutex.create ();
}
let[@inline] with_mutex_ (self : t) f =
Mutex.lock self.mutex;
try
let x = f () in
Mutex.unlock self.mutex;
x
with e ->
Mutex.unlock self.mutex;
raise e
let close self =
if not self.closed then (
self.closed <- true;
raise Close_connection
)
let int_of_bool : bool -> int = Obj.magic
(** Write the frame header to [self.oc] *)
let write_header_ (self : t) : unit =
let header_len = ref 2 in
let b0 =
Char.chr ((int_of_bool self.header.fin lsl 7) lor self.header.ty)
in
Bytes.unsafe_set self.header_buf 0 b0;
(* we don't mask *)
let payload_len = self.header.payload_len in
let payload_first_byte =
if payload_len < 126 then
payload_len
else if payload_len < 1 lsl 16 then (
Bytes.set_int16_be self.header_buf 2 payload_len;
header_len := 4;
126
) else (
Bytes.set_int64_be self.header_buf 2 (Int64.of_int payload_len);
header_len := 10;
127
)
in
let b1 =
Char.chr @@ ((int_of_bool self.header.mask lsl 7) lor payload_first_byte)
in
Bytes.unsafe_set self.header_buf 1 b1;
if self.header.mask then (
Bytes.blit self.header_buf !header_len self.header.mask_key 0 4;
header_len := !header_len + 4
);
(*Log.debug (fun k ->
k "websocket: write header ty=%s (%d B)"
(Frame_type.show self.header.ty)
!header_len);*)
IO.Output.output self.oc self.header_buf 0 !header_len;
()
(** Max fragment size: send 16 kB at a time *)
let max_fragment_size = 16 * 1024
let[@inline never] really_output_buf_ (self : t) =
self.header.fin <- true;
self.header.ty <- Frame_type.binary;
self.header.payload_len <- self.offset;
self.header.mask <- false;
write_header_ self;
IO.Output.output self.oc self.buf 0 self.offset;
self.offset <- 0
let flush_ (self : t) =
if self.closed then raise Close_connection;
if self.offset > 0 then really_output_buf_ self
let[@inline] flush_if_full (self : t) : unit =
if self.offset = Bytes.length self.buf then really_output_buf_ self
let send_pong (self : t) : unit =
let@ () = with_mutex_ self in
self.header.fin <- true;
self.header.ty <- Frame_type.pong;
self.header.payload_len <- 0;
self.header.mask <- false;
(* only write a header, we don't send a payload at all *)
write_header_ self
let output_char (self : t) c : unit =
let@ () = with_mutex_ self in
let cap = Bytes.length self.buf - self.offset in
(* make room for [c] *)
if cap = 0 then really_output_buf_ self;
Bytes.set self.buf self.offset c;
self.offset <- self.offset + 1;
(* if [c] made the buffer full, then flush it *)
if cap = 1 then really_output_buf_ self
let output (self : t) buf i len : unit =
let@ () = with_mutex_ self in
let i = ref i in
let len = ref len in
while !len > 0 do
flush_if_full self;
let n = min !len (Bytes.length self.buf - self.offset) in
assert (n > 0);
Bytes.blit buf !i self.buf self.offset n;
self.offset <- self.offset + n;
i := !i + n;
len := !len - n
done;
flush_if_full self
let flush self : unit =
let@ () = with_mutex_ self in
flush_ self
end
module Reader = struct
type state =
| Begin (** At the beginning of a frame *)
| Reading_frame of { mutable remaining_bytes: int }
(** Currently reading the payload of a frame with [remaining_bytes] left to read *)
| Close
type t = {
ic: IO.Input.t;
writer: Writer.t; (** Writer, to send "pong" *)
header_buf: bytes; (** small buffer to read frame headers *)
small_buf: bytes; (** Used for control frames *)
header: Header.t;
last_ty: Frame_type.t; (** Last frame's type, used for continuation *)
mutable state: state;
}
let create ~ic ~(writer : Writer.t) () : t =
{
ic;
header_buf = Bytes.create 8;
small_buf = Bytes.create 128;
writer;
state = Begin;
last_ty = 0;
header = Header.create ();
}
(** limitation: we only accept frames that are 2^30 bytes long or less *)
let max_fragment_size = 1 lsl 30
(** Read next frame header into [self.header] *)
let read_frame_header (self : t) : unit =
(* read header *)
IO.Input.really_input self.ic self.header_buf 0 2;
let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in
let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in
self.header.fin <- b0 land 1 == 1;
let ext = (b0 lsr 4) land 0b0111 in
if ext <> 0 then (
Log.error (fun k -> k "websocket: unknown extension %d, closing" ext);
raise Close_connection
);
self.header.ty <- b0 land 0b0000_1111;
self.header.mask <- b1 land 0b1000_0000 != 0;
let payload_len : int =
let len = b1 land 0b0111_1111 in
if len = 126 then (
IO.Input.really_input self.ic self.header_buf 0 2;
Bytes.get_int16_be self.header_buf 0
) else if len = 127 then (
IO.Input.really_input self.ic self.header_buf 0 8;
let len64 = Bytes.get_int64_be self.header_buf 0 in
if compare len64 (Int64.of_int max_fragment_size) > 0 then (
Log.error (fun k ->
k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64
max_fragment_size);
raise Close_connection
);
Int64.to_int len64
) else
len
in
self.header.payload_len <- payload_len;
if self.header.mask then
IO.Input.really_input self.ic self.header.mask_key 0 4;
(*Log.debug (fun k ->
k "websocket: read frame header type=%s payload_len=%d mask=%b"
(Frame_type.show self.header.ty)
self.header.payload_len self.header.mask);*)
()
external apply_masking_ : bytes -> bytes -> int -> int -> unit
= "tiny_httpd_ws_apply_masking"
[@@noalloc]
(** Apply masking to the parsed data *)
let[@inline] apply_masking ~mask_key (buf : bytes) off len : unit =
assert (off >= 0 && off + len <= Bytes.length buf);
apply_masking_ mask_key buf off len
let read_body_to_string (self : t) : string =
let len = self.header.payload_len in
let buf = Bytes.create len in
IO.Input.really_input self.ic buf 0 len;
if self.header.mask then
apply_masking ~mask_key:self.header.mask_key buf 0 len;
Bytes.unsafe_to_string buf
(** Skip bytes of the body *)
let skip_body (self : t) : unit =
let len = ref self.header.payload_len in
while !len > 0 do
let n = min !len (Bytes.length self.small_buf) in
IO.Input.really_input self.ic self.small_buf 0 n;
len := !len - n
done
(** State machine that reads [len] bytes into [buf] *)
let rec read_rec (self : t) buf i len : int =
match self.state with
| Close -> 0
| Reading_frame r ->
let len = min len r.remaining_bytes in
let n = IO.Input.input self.ic buf i len in
(* update state *)
r.remaining_bytes <- r.remaining_bytes - n;
if r.remaining_bytes = 0 then self.state <- Begin;
if self.header.mask then
apply_masking ~mask_key:self.header.mask_key buf i n
else (
Log.error (fun k -> k "websocket: client's frames must be masked");
raise Close_connection
);
n
| Begin ->
read_frame_header self;
(*Log.debug (fun k ->
k "websocket: read frame of type=%s payload_len=%d"
(Frame_type.show self.header.ty)
self.header.payload_len);*)
(match self.header.ty with
| 0 ->
(* continuation *)
if self.last_ty = 1 || self.last_ty = 2 then
self.state <-
Reading_frame { remaining_bytes = self.header.payload_len }
else (
Log.error (fun k ->
k "continuation frame coming after frame of type %s"
(Frame_type.show self.last_ty));
raise Close_connection
);
read_rec self buf i len
| 1 ->
self.state <-
Reading_frame { remaining_bytes = self.header.payload_len };
read_rec self buf i len
| 2 ->
self.state <-
Reading_frame { remaining_bytes = self.header.payload_len };
read_rec self buf i len
| 8 ->
(* close frame *)
self.state <- Close;
let body = read_body_to_string self in
if String.length body >= 2 then (
let errcode = Bytes.get_int16_be (Bytes.unsafe_of_string body) 0 in
Log.info (fun k ->
k "client send 'close' with errcode=%d, message=%S" errcode
(String.sub body 2 (String.length body - 2)))
);
0
| 9 ->
(* pong, just ignore *)
skip_body self;
Writer.send_pong self.writer;
read_rec self buf i len
| 10 ->
(* pong, just ignore *)
skip_body self;
read_rec self buf i len
| ty ->
Log.error (fun k -> k "unknown frame type: %xd" ty);
raise Close_connection)
let read self buf i len =
try read_rec self buf i len
with Close_connection ->
self.state <- Close;
0
let close self : unit =
if self.state != Close then (
Log.debug (fun k -> k "websocket: close connection from server side");
self.state <- Close
)
end
let upgrade ic oc : _ * _ =
let writer = Writer.create ~oc () in
let reader = Reader.create ~ic ~writer () in
let ws_ic : IO.Input.t =
{
input = (fun buf i len -> Reader.read reader buf i len);
close = (fun () -> Reader.close reader);
}
in
let ws_oc : IO.Output.t =
{
flush =
(fun () ->
Writer.flush writer;
IO.Output.flush oc);
output_char = Writer.output_char writer;
output = Writer.output writer;
close = (fun () -> Writer.close writer);
}
in
ws_ic, ws_oc
(** Turn a regular connection handler (provided by the user) into a websocket upgrade handler *)
module Make_upgrade_handler (X : sig
val accept_ws_protocol : string -> bool
val handler : handler
end) : UPGRADE_HANDLER = struct
type handshake_state = unit
let name = "websocket"
open struct
exception Bad_req of string
let bad_req msg = raise (Bad_req msg)
let bad_reqf fmt = Printf.ksprintf bad_req fmt
end
let handshake_ (req : unit Request.t) =
(match Request.get_header req "sec-websocket-protocol" with
| None -> ()
| Some proto when not (X.accept_ws_protocol proto) ->
bad_reqf "handler rejected websocket protocol %S" proto
| Some _proto -> ());
let key =
match Request.get_header req "sec-websocket-key" with
| None -> bad_req "need sec-websocket-key"
| Some k -> k
in
(* TODO: "origin" header *)
(* produce the accept key *)
let accept =
(* yes, SHA1 is broken. It's also part of the spec for websockets. *)
Utils_.sha_1 (key ^ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|> Utils_.B64.encode ~url:false
in
let headers = [ "sec-websocket-accept", accept ] in
Log.debug (fun k ->
k "websocket: upgrade successful, accept key is %S" accept);
headers, ()
let handshake req : _ result =
try Ok (handshake_ req) with Bad_req s -> Error s
let handle_connection addr () ic oc =
let ws_ic, ws_oc = upgrade ic oc in
try X.handler addr ws_ic ws_oc
with Close_connection ->
Log.debug (fun k -> k "websocket: requested to close the connection");
()
end
let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true)
(server : Tiny_httpd_server.t) route (f : handler) : unit =
let module M = Make_upgrade_handler (struct
let handler = f
let accept_ws_protocol = accept_ws_protocol
end) in
let up : upgrade_handler = (module M) in
Tiny_httpd_server.add_upgrade_handler ?accept server route up

19
src/ws/tiny_httpd_ws.mli Normal file
View file

@ -0,0 +1,19 @@
open Tiny_httpd_server
module IO = Tiny_httpd_io
type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit
(** Websocket handler *)
val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t
(** Upgrade a byte stream to the websocket framing protocol. *)
val add_route_handler :
?accept:(unit Request.t -> (unit, int * string) result) ->
?accept_ws_protocol:(string -> bool) ->
Tiny_httpd_server.t ->
(upgrade_handler, upgrade_handler) Route.t ->
handler ->
unit
(** Add a route handler for a websocket endpoint.
@param accept_ws_protocol decides whether this endpoint accepts the websocket protocol
sent by the client. Default accepts everything. *)

View file

@ -0,0 +1,21 @@
#include <caml/alloc.h>
#include <caml/memory.h>
#include <caml/mlvalues.h>
CAMLprim value tiny_httpd_ws_apply_masking(value _mask_key, value _buf,
value _offset, value _len) {
CAMLparam4(_mask_key, _buf, _offset, _len);
char const *mask_key = String_val(_mask_key);
char *buf = Bytes_val(_buf);
intnat offset = Int_val(_offset);
intnat len = Int_val(_len);
for (intnat i = 0; i < len; ++i) {
char c = buf[offset + i];
char c_m = mask_key[i & 0x3];
buf[offset + i] = c ^ c_m;
}
CAMLreturn(Val_unit);
}

198
src/ws/utils_.ml Normal file
View file

@ -0,0 +1,198 @@
(* To keep the library lightweight, we vendor base64 and sha1
from Daniel Bünzli's excellent libraries. Both of these functions
are used only for the websocket handshake, on tiny data
(one header's worth).
vendored from https://github.com/dbuenzli/uuidm
and https://github.com/dbuenzli/webs . *)
module B64 = struct
let alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
let alpha_url =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
let encode ~url s =
let rec loop alpha len e ei s i =
if i >= len then
Bytes.unsafe_to_string e
else (
let i0 = i and i1 = i + 1 and i2 = i + 2 in
let b0 = Char.code s.[i0] in
let b1 =
if i1 >= len then
0
else
Char.code s.[i1]
in
let b2 =
if i2 >= len then
0
else
Char.code s.[i2]
in
let u = (b0 lsl 16) lor (b1 lsl 8) lor b2 in
let c0 = alpha.[u lsr 18] in
let c1 = alpha.[(u lsr 12) land 63] in
let c2 =
if i1 >= len then
'='
else
alpha.[(u lsr 6) land 63]
in
let c3 =
if i2 >= len then
'='
else
alpha.[u land 63]
in
Bytes.set e ei c0;
Bytes.set e (ei + 1) c1;
Bytes.set e (ei + 2) c2;
Bytes.set e (ei + 3) c3;
loop alpha len e (ei + 4) s (i2 + 1)
)
in
match String.length s with
| 0 -> ""
| len ->
let alpha =
if url then
alpha_url
else
alpha
in
loop alpha len (Bytes.create ((len + 2) / 3 * 4)) 0 s 0
end
let sha_1 s =
(* Based on pseudo-code of RFC 3174. Slow and ugly but does the job. *)
let sha_1_pad s =
let len = String.length s in
let blen = 8 * len in
let rem = len mod 64 in
let mlen =
if rem > 55 then
len + 128 - rem
else
len + 64 - rem
in
let m = Bytes.create mlen in
Bytes.blit_string s 0 m 0 len;
Bytes.fill m len (mlen - len) '\x00';
Bytes.set m len '\x80';
if Sys.word_size > 32 then (
Bytes.set m (mlen - 8) (Char.unsafe_chr ((blen lsr 56) land 0xFF));
Bytes.set m (mlen - 7) (Char.unsafe_chr ((blen lsr 48) land 0xFF));
Bytes.set m (mlen - 6) (Char.unsafe_chr ((blen lsr 40) land 0xFF));
Bytes.set m (mlen - 5) (Char.unsafe_chr ((blen lsr 32) land 0xFF))
);
Bytes.set m (mlen - 4) (Char.unsafe_chr ((blen lsr 24) land 0xFF));
Bytes.set m (mlen - 3) (Char.unsafe_chr ((blen lsr 16) land 0xFF));
Bytes.set m (mlen - 2) (Char.unsafe_chr ((blen lsr 8) land 0xFF));
Bytes.set m (mlen - 1) (Char.unsafe_chr (blen land 0xFF));
m
in
(* Operations on int32 *)
let ( &&& ) = ( land ) in
let ( lor ) = Int32.logor in
let ( lxor ) = Int32.logxor in
let ( land ) = Int32.logand in
let ( ++ ) = Int32.add in
let lnot = Int32.lognot in
let sr = Int32.shift_right in
let sl = Int32.shift_left in
let cls n x = sl x n lor Int32.shift_right_logical x (32 - n) in
(* Start *)
let m = sha_1_pad s in
let w = Array.make 16 0l in
let h0 = ref 0x67452301l in
let h1 = ref 0xEFCDAB89l in
let h2 = ref 0x98BADCFEl in
let h3 = ref 0x10325476l in
let h4 = ref 0xC3D2E1F0l in
let a = ref 0l in
let b = ref 0l in
let c = ref 0l in
let d = ref 0l in
let e = ref 0l in
for i = 0 to (Bytes.length m / 64) - 1 do
(* For each block *)
(* Fill w *)
let base = i * 64 in
for j = 0 to 15 do
let k = base + (j * 4) in
w.(j) <-
sl (Int32.of_int (Char.code @@ Bytes.get m k)) 24
lor sl (Int32.of_int (Char.code @@ Bytes.get m (k + 1))) 16
lor sl (Int32.of_int (Char.code @@ Bytes.get m (k + 2))) 8
lor Int32.of_int (Char.code @@ Bytes.get m (k + 3))
done;
(* Loop *)
a := !h0;
b := !h1;
c := !h2;
d := !h3;
e := !h4;
for t = 0 to 79 do
let f, k =
if t <= 19 then
!b land !c lor (lnot !b land !d), 0x5A827999l
else if t <= 39 then
!b lxor !c lxor !d, 0x6ED9EBA1l
else if t <= 59 then
!b land !c lor (!b land !d) lor (!c land !d), 0x8F1BBCDCl
else
!b lxor !c lxor !d, 0xCA62C1D6l
in
let s = t &&& 0xF in
if t >= 16 then
w.(s) <-
cls 1
(w.(s + 13 &&& 0xF)
lxor w.(s + 8 &&& 0xF)
lxor w.(s + 2 &&& 0xF)
lxor w.(s));
let temp = cls 5 !a ++ f ++ !e ++ w.(s) ++ k in
e := !d;
d := !c;
c := cls 30 !b;
b := !a;
a := temp
done;
(* Update *)
h0 := !h0 ++ !a;
h1 := !h1 ++ !b;
h2 := !h2 ++ !c;
h3 := !h3 ++ !d;
h4 := !h4 ++ !e
done;
let h = Bytes.create 20 in
let i2s h k i =
Bytes.set h k (Char.unsafe_chr (Int32.to_int (sr i 24) &&& 0xFF));
Bytes.set h (k + 1) (Char.unsafe_chr (Int32.to_int (sr i 16) &&& 0xFF));
Bytes.set h (k + 2) (Char.unsafe_chr (Int32.to_int (sr i 8) &&& 0xFF));
Bytes.set h (k + 3) (Char.unsafe_chr (Int32.to_int i &&& 0xFF))
in
i2s h 0 !h0;
i2s h 4 !h1;
i2s h 8 !h2;
i2s h 12 !h3;
i2s h 16 !h4;
Bytes.unsafe_to_string h
(*---------------------------------------------------------------------------
Copyright (c) 2008 The uuidm programmers
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
---------------------------------------------------------------------------*)