diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 02a132ec..1f11a2b8 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -3,7 +3,7 @@ name: github pages on: push: branches: - - master + - main jobs: deploy: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8d714d10..ddeacbab 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -3,9 +3,8 @@ name: build on: pull_request: push: - schedule: - # Prime the caches every Monday - - cron: 0 1 * * MON + branches: + - main jobs: build: @@ -32,9 +31,6 @@ jobs: with: ocaml-compiler: ${{ matrix.ocaml-compiler }} allow-prerelease-opam: true - opam-local-packages: | - ./tiny_httpd.opam - ./tiny_httpd_camlzip.opam opam-depext-flags: --with-test - run: opam install ./tiny_httpd.opam ./tiny_httpd_camlzip.opam --deps-only --with-test diff --git a/.github/workflows/main5.yml b/.github/workflows/main5.yml deleted file mode 100644 index 04effbab..00000000 --- a/.github/workflows/main5.yml +++ /dev/null @@ -1,49 +0,0 @@ -name: build (ocaml 5) - -on: - pull_request: - push: - schedule: - # Prime the caches every Monday - - cron: 0 1 * * MON - -jobs: - build: - strategy: - fail-fast: true - matrix: - os: - - ubuntu-latest - #- macos-latest - #- windows-latest - ocaml-compiler: - - 5.1.x - - runs-on: ${{ matrix.os }} - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Use OCaml ${{ matrix.ocaml-compiler }} - uses: ocaml/setup-ocaml@v2 - with: - ocaml-compiler: ${{ matrix.ocaml-compiler }} - opam-depext-flags: --with-test - allow-prerelease-opam: true - - - run: opam install . --deps-only --with-test - - - run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip - - - run: opam exec -- dune build @src/runtest @examples/runtest @tests/runtest -p tiny_httpd - if: ${{ matrix.os == 'ubuntu-latest' }} - - - run: opam install tiny_httpd - - - run: opam exec -- dune build @src/runtest @examples/runtest @tests/runtest -p tiny_httpd_camlzip - if: ${{ matrix.os == 'ubuntu-latest' }} - - - run: opam install logs -y - - - run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip diff --git a/echo_ws.sh b/echo_ws.sh new file mode 100755 index 00000000..e087c3d5 --- /dev/null +++ b/echo_ws.sh @@ -0,0 +1,2 @@ +#!/bin/sh +exec dune exec --display=quiet --profile=release "examples/echo_ws.exe" -- $@ diff --git a/examples/dune b/examples/dune index b6f4728a..d2c19915 100644 --- a/examples/dune +++ b/examples/dune @@ -20,6 +20,12 @@ (modules writer) (libraries tiny_httpd logs)) +(executable + (name echo_ws) + (flags :standard -warn-error -a+8) + (modules echo_ws) + (libraries tiny_httpd tiny_httpd.ws logs)) + (rule (targets test_output.txt) (deps diff --git a/examples/echo_ws.ml b/examples/echo_ws.ml new file mode 100644 index 00000000..5a616d3f --- /dev/null +++ b/examples/echo_ws.ml @@ -0,0 +1,67 @@ +module S = Tiny_httpd +module Log = Tiny_httpd.Log +module IO = Tiny_httpd_io + +let setup_logging ~debug () = + Logs.set_reporter @@ Logs.format_reporter (); + Logs.set_level ~all:true + @@ Some + (if debug then + Logs.Debug + else + Logs.Info) + +let handle_ws _client_addr ic oc = + Log.info (fun k -> + k "new client connection from %s" + (Tiny_httpd_util.show_sockaddr _client_addr)); + + let (_ : Thread.t) = + Thread.create + (fun () -> + while true do + Thread.delay 3.; + IO.Output.output_string oc "(special ping!)"; + IO.Output.flush oc + done) + () + in + + let buf = Bytes.create 32 in + let continue = ref true in + while !continue do + let n = IO.Input.input ic buf 0 (Bytes.length buf) in + Log.debug (fun k -> + k "echo %d bytes from websocket: %S" n (Bytes.sub_string buf 0 n)); + + if n = 0 then continue := false; + IO.Output.output oc buf 0 n; + IO.Output.flush oc + done; + Log.info (fun k -> k "client exiting") + +let () = + let port_ = ref 8080 in + let j = ref 32 in + let debug = ref false in + Arg.parse + (Arg.align + [ + "--port", Arg.Set_int port_, " set port"; + "-p", Arg.Set_int port_, " set port"; + "--debug", Arg.Set debug, " enable debug"; + "-j", Arg.Set_int j, " maximum number of connections"; + ]) + (fun _ -> raise (Arg.Bad "")) + "echo [option]*"; + setup_logging ~debug:!debug (); + + let server = S.create ~port:!port_ ~max_connections:!j () in + Tiny_httpd_ws.add_route_handler server + S.Route.(exact "echo" @/ return) + handle_ws; + + Printf.printf "listening on http://%s:%d\n%!" (S.addr server) (S.port server); + match S.run server with + | Ok () -> () + | Error e -> raise e diff --git a/src/Tiny_httpd_io.ml b/src/Tiny_httpd_io.ml index bd829e49..207ba9a5 100644 --- a/src/Tiny_httpd_io.ml +++ b/src/Tiny_httpd_io.ml @@ -45,12 +45,60 @@ module Input = struct Unix.close fd); } + let of_slice (i_bs : bytes) (i_off : int) (i_len : int) : t = + let i_off = ref i_off in + let i_len = ref i_len in + { + input = + (fun buf i len -> + let n = min len !i_len in + Bytes.blit i_bs !i_off buf i n; + i_off := !i_off + n; + i_len := !i_len - n; + n); + close = ignore; + } + (** Read into the given slice. @return the number of bytes read, [0] means end of input. *) let[@inline] input (self : t) buf i len = self.input buf i len + (** Read exactly [len] bytes. + @raise End_of_file if the input did not contain enough data. *) + let really_input (self : t) buf i len : unit = + let i = ref i in + let len = ref len in + while !len > 0 do + let n = input self buf !i !len in + if n = 0 then raise End_of_file; + i := !i + n; + len := !len - n + done + (** Close the channel. *) let[@inline] close self : unit = self.close () + + let append (i1 : t) (i2 : t) : t = + let use_i1 = ref true in + let rec input buf i len : int = + if !use_i1 then ( + let n = i1.input buf i len in + if n = 0 then ( + use_i1 := false; + input buf i len + ) else + n + ) else + i2.input buf i len + in + + { + input; + close = + (fun () -> + close i1; + close i2); + } end (** Output channel (byte sink) *) diff --git a/src/Tiny_httpd_server.ml b/src/Tiny_httpd_server.ml index ade65cdc..7dcd466e 100644 --- a/src/Tiny_httpd_server.ml +++ b/src/Tiny_httpd_server.ml @@ -46,7 +46,8 @@ module Response_code = struct let[@inline] is_success n = n >= 200 && n < 400 end -type 'a resp_result = ('a, Response_code.t * string) result +type resp_error = Response_code.t * string +type 'a resp_result = ('a, resp_error) result let unwrap_resp_result = function | Ok x -> x @@ -633,6 +634,27 @@ end type server_sent_generator = (module SERVER_SENT_GENERATOR) +(** Handler that upgrades to another protocol *) +module type UPGRADE_HANDLER = sig + type handshake_state + (** Some specific state returned after handshake *) + + val name : string + (** Name in the "upgrade" header *) + + val handshake : unit Request.t -> (Headers.t * handshake_state, string) result + (** Perform the handshake and upgrade the connection. The returned + code is [101] alongside these headers. *) + + val handle_connection : + Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit + (** Take control of the connection and take it from there *) +end + +type upgrade_handler = (module UPGRADE_HANDLER) + +exception Upgrade of unit Request.t * upgrade_handler + module type IO_BACKEND = sig val init_addr : unit -> string val init_port : unit -> int @@ -644,6 +666,16 @@ module type IO_BACKEND = sig (** Server that can listen on a port and handle clients. *) end +type handler_result = + | Handle of cb_path_handler + | Fail of resp_error + | Upgrade of upgrade_handler + +let unwrap_handler_result req = function + | Handle x -> x + | Fail (c, s) -> raise (Bad_req (c, s)) + | Upgrade up -> raise (Upgrade (req, up)) + type t = { backend: (module IO_BACKEND); mutable tcp_server: IO.TCP_server.t option; @@ -653,8 +685,7 @@ type t = { mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) mutable middlewares_sorted: (int * Middleware.t) list lazy_t; (** sorted version of {!middlewares} *) - mutable path_handlers: - (unit Request.t -> cb_path_handler resp_result option) list; + mutable path_handlers: (unit Request.t -> handler_result option) list; (** path handlers *) buf_pool: Buf.t Pool.t; } @@ -726,7 +757,7 @@ let set_top_handler self f = self.handler <- f and makes it into a handler. *) let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth ~tr_req self (route : _ Route.t) f = - let ph req : cb_path_handler resp_result option = + let ph req : handler_result option = match meth with | Some m when m <> req.Request.meth -> None (* ignore *) | _ -> @@ -736,11 +767,11 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth (match accept req with | Ok () -> Some - (Ok + (Handle (fun oc -> Middleware.apply_l middlewares @@ fun req ~resp -> tr_req oc req ~resp handler)) - | Error _ as e -> Some e) + | Error err -> Some (Fail err)) | None -> None (* path didn't match *)) in self.path_handlers <- ph :: self.path_handlers @@ -821,6 +852,22 @@ let add_route_server_sent_handler ?accept self route f = in add_route_handler_ self ?accept ~meth:`GET route ~tr_req f +let add_upgrade_handler ?(accept = fun _ -> Ok ()) ?(middlewares = []) + (self : t) route f : unit = + let ph req : handler_result option = + if req.Request.meth <> `GET then + None + else ( + match accept req with + | Ok () -> + (match Route.eval req.Request.path_components route f with + | Some up -> Some (Upgrade up) + | None -> None (* path didn't match *)) + | Error err -> Some (Fail err) + ) + in + self.path_handlers <- ph :: self.path_handlers + let get_max_connection_ ?(max_connections = 64) () : int = let max_connections = max 4 max_connections in max_connections @@ -847,11 +894,6 @@ let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = let is_ipv6_str addr : bool = String.contains addr ':' -let str_of_sockaddr = function - | Unix.ADDR_UNIX f -> f - | Unix.ADDR_INET (inet, port) -> - Printf.sprintf "%s:%d" (Unix.string_of_inet_addr inet) port - module Unix_tcp_server_ = struct type t = { addr: string; @@ -918,7 +960,8 @@ module Unix_tcp_server_ = struct let handle_client_unix_ (client_sock : Unix.file_descr) (client_addr : Unix.sockaddr) : unit = Log.info (fun k -> - k "serving new client on %s" (str_of_sockaddr client_addr)); + k "serving new client on %s" + (Tiny_httpd_util.show_sockaddr client_addr)); Unix.(setsockopt_float client_sock SO_RCVTIMEO self.timeout); Unix.(setsockopt_float client_sock SO_SNDTIMEO self.timeout); let oc = @@ -928,12 +971,14 @@ module Unix_tcp_server_ = struct handle.handle ~client_addr ic oc; Log.info (fun k -> k "done with client on %s, exiting" - @@ str_of_sockaddr client_addr); - (try Unix.close client_sock + @@ Tiny_httpd_util.show_sockaddr client_addr); + (try + Unix.shutdown client_sock Unix.SHUTDOWN_ALL; + Unix.close client_sock with e -> Log.error (fun k -> k "error when closing sock for client %s: %s" - (str_of_sockaddr client_addr) + (Tiny_httpd_util.show_sockaddr client_addr) (Printexc.to_string e))); () in @@ -962,7 +1007,7 @@ module Unix_tcp_server_ = struct k "@[Handler: uncaught exception for client %s:@ \ %s@ %s@]" - (str_of_sockaddr client_addr) + (Tiny_httpd_util.show_sockaddr client_addr) (Printexc.to_string e) (Printexc.raw_backtrace_to_string bt))); ignore Unix.(sigprocmask SIG_UNBLOCK Sys.[ sigint; sighup ]) @@ -1030,15 +1075,103 @@ let find_map f l = in aux f l +let string_as_list_contains_ (s : string) (sub : string) : bool = + let fragments = String.split_on_char ',' s in + List.exists (fun fragment -> String.trim fragment = sub) fragments + (* handle client on [ic] and [oc] *) let client_handle_for (self : t) ~client_addr ic oc : unit = Pool.with_resource self.buf_pool @@ fun buf -> Pool.with_resource self.buf_pool @@ fun buf_res -> let is = Byte_stream.of_input ~buf_size:self.buf_size ic in + let (module B) = self.backend in + + (* how to log the response to this query *) + let log_response (req : _ Request.t) (resp : Response.t) = + if not Log.dummy then ( + let msgf k = + let elapsed = B.get_time_s () -. req.start_time in + k + ("response to=%s code=%d time=%.3fs path=%S" : _ format4) + (Tiny_httpd_util.show_sockaddr client_addr) + resp.code elapsed req.path + in + if Response_code.is_success resp.code then + Log.info msgf + else + Log.error msgf + ) + in + + (* handle generic exception *) + let handle_exn e = + let resp = + Response.fail ~code:500 "server error: %s" (Printexc.to_string e) + in + if not Log.dummy then + Log.error (fun k -> + k "response to %s code=%d" + (Tiny_httpd_util.show_sockaddr client_addr) + resp.code); + Response.output_ ~buf:buf_res oc resp + in + + let handle_bad_req req e = + let resp = + Response.fail ~code:500 "server error: %s" (Printexc.to_string e) + in + log_response req resp; + Response.output_ ~buf:buf_res oc resp + in + + let handle_upgrade req (module UP : UPGRADE_HANDLER) : unit = + Log.debug (fun k -> k "upgrade connection"); + try + (* check headers *) + (match Request.get_header req "connection" with + | Some str when string_as_list_contains_ str "Upgrade" -> () + | _ -> bad_reqf 426 "connection header must contain 'Upgrade'"); + (match Request.get_header req "upgrade" with + | Some u when u = UP.name -> () + | Some u -> bad_reqf 426 "expected upgrade to be '%s', got '%s'" UP.name u + | None -> bad_reqf 426 "expected 'connection: upgrade' header"); + + (* ok, this is the upgrade we expected *) + match UP.handshake req with + | Error msg -> + (* fail the upgrade *) + Log.error (fun k -> k "upgrade failed: %s" msg); + let resp = Response.make_raw ~code:429 "upgrade required" in + log_response req resp; + Response.output_ ~buf:buf_res oc resp + | Ok (headers, handshake_st) -> + (* send the upgrade reply *) + let headers = + [ "connection", "upgrade"; "upgrade", UP.name ] @ headers + in + let resp = Response.make_string ~code:101 ~headers (Ok "") in + log_response req resp; + Response.output_ ~buf:buf_res oc resp; + + (* now, give the whole connection over to the upgraded connection. + Make sure to give the leftovers from [is] as well, if any. + There might not be any because the first message doesn't normally come + directly in the same packet as the handshake, but still. *) + let ic = + if is.len > 0 then ( + Log.debug (fun k -> k "LEFTOVERS! %d B" is.len); + IO.Input.append (IO.Input.of_slice is.bs is.off is.len) ic + ) else + ic + in + + UP.handle_connection client_addr handshake_st ic oc + with e -> handle_bad_req req e + in + let continue = ref true in - while !continue && running self do - Log.debug (fun k -> k "read next request"); - let (module B) = self.backend in + + let handle_one_req () = match Request.parse_req_start ~client_addr ~get_time_s:B.get_time_s ~buf is with @@ -1054,28 +1187,11 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = if Request.close_after_req req then continue := false; - (* how to log the response to this query *) - let log_response (resp : Response.t) = - if not Log.dummy then ( - let msgf k = - let elapsed = B.get_time_s () -. req.start_time in - k - ("response to=%s code=%d time=%.3fs path=%S" : _ format4) - (str_of_sockaddr client_addr) - resp.code elapsed req.path - in - if Response_code.is_success resp.code then - Log.info msgf - else - Log.error msgf - ) - in - (try (* is there a handler for this path? *) let base_handler = match find_map (fun ph -> ph req) self.path_handlers with - | Some f -> unwrap_resp_result f + | Some f -> unwrap_handler_result req f | None -> fun _oc req ~resp -> resp (self.handler req) in @@ -1108,7 +1224,7 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = try if Headers.get "connection" r.Response.headers = Some "close" then continue := false; - log_response r; + log_response req r; Response.output_ ~buf:buf_res oc r with Sys_error _ -> continue := false in @@ -1123,16 +1239,23 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = | Bad_req (code, s) -> continue := false; let resp = Response.make_raw ~code s in - log_response resp; + log_response req resp; Response.output_ ~buf:buf_res oc resp - | e -> - continue := false; - let resp = - Response.fail ~code:500 "server error: %s" (Printexc.to_string e) - in - log_response resp; - Response.output_ ~buf:buf_res oc resp) - done + | Upgrade _ as e -> raise e + | e -> handle_bad_req req e) + in + + try + while !continue && running self do + Log.debug (fun k -> k "read next request"); + handle_one_req () + done + with + | Upgrade (req, up) -> + (* upgrades take over the whole connection, we won't process + any further request *) + handle_upgrade req up + | e -> handle_exn e let client_handler (self : t) : IO.TCP_server.conn_handler = { IO.TCP_server.handle = client_handle_for self } diff --git a/src/Tiny_httpd_server.mli b/src/Tiny_httpd_server.mli index 67270cdf..591fff94 100644 --- a/src/Tiny_httpd_server.mli +++ b/src/Tiny_httpd_server.mli @@ -645,6 +645,46 @@ val add_route_server_sent_handler : @since 0.9 *) +(** {2 Upgrade handlers} + + These handlers upgrade the connection to another protocol. + @since NEXT_RELEASE *) + +(** Handler that upgrades to another protocol. + @since NEXT_RELEASE *) +module type UPGRADE_HANDLER = sig + type handshake_state + (** Some specific state returned after handshake *) + + val name : string + (** Name in the "upgrade" header *) + + val handshake : unit Request.t -> (Headers.t * handshake_state, string) result + (** Perform the handshake and upgrade the connection. The returned + code is [101] alongside these headers. + In case the handshake fails, this only returns [Error log_msg]. + The connection is closed without further ado. *) + + val handle_connection : + Unix.sockaddr -> + handshake_state -> + Tiny_httpd_io.Input.t -> + Tiny_httpd_io.Output.t -> + unit + (** Take control of the connection and take it from ther.e *) +end + +type upgrade_handler = (module UPGRADE_HANDLER) +(** @since NEXT_RELEASE *) + +val add_upgrade_handler : + ?accept:(unit Request.t -> (unit, Response_code.t * string) result) -> + ?middlewares:Middleware.t list -> + t -> + ('a, upgrade_handler) Route.t -> + 'a -> + unit + (** {2 Run the server} *) val running : t -> bool diff --git a/src/Tiny_httpd_util.ml b/src/Tiny_httpd_util.ml index c87adbe4..9ec935ae 100644 --- a/src/Tiny_httpd_util.ml +++ b/src/Tiny_httpd_util.ml @@ -107,3 +107,8 @@ let parse_query s : (_ list, string) result = | Invalid_argument _ | Not_found | Failure _ -> Error (Printf.sprintf "error in parse_query for %S: i=%d,j=%d" s !i !j) | Invalid_query -> Error ("invalid query string: " ^ s) + +let show_sockaddr = function + | Unix.ADDR_UNIX f -> f + | Unix.ADDR_INET (inet, port) -> + Printf.sprintf "%s:%d" (Unix.string_of_inet_addr inet) port diff --git a/src/Tiny_httpd_util.mli b/src/Tiny_httpd_util.mli index f29209ce..ac996855 100644 --- a/src/Tiny_httpd_util.mli +++ b/src/Tiny_httpd_util.mli @@ -34,3 +34,7 @@ val parse_query : string -> ((string * string) list, string) result The order might not be preserved. @since 0.3 *) + +val show_sockaddr : Unix.sockaddr -> string +(** Simple printer for socket addresses. + @since NEXT_RELEASE *) diff --git a/src/ws/common_.ml b/src/ws/common_.ml new file mode 100644 index 00000000..699d6c9f --- /dev/null +++ b/src/ws/common_.ml @@ -0,0 +1,2 @@ +let spf = Printf.sprintf +let ( let@ ) = ( @@ ) diff --git a/src/ws/dune b/src/ws/dune new file mode 100644 index 00000000..f2aab877 --- /dev/null +++ b/src/ws/dune @@ -0,0 +1,11 @@ + +(library + (name tiny_httpd_ws) + (public_name tiny_httpd.ws) + (synopsis "Websockets for tiny_httpd") + (private_modules common_ utils_) + (foreign_stubs + (language c) + (names tiny_httpd_ws_stubs) + (flags :standard -std=c99 -fPIC -O2)) + (libraries tiny_httpd threads)) diff --git a/src/ws/tiny_httpd_ws.ml b/src/ws/tiny_httpd_ws.ml new file mode 100644 index 00000000..80867d95 --- /dev/null +++ b/src/ws/tiny_httpd_ws.ml @@ -0,0 +1,463 @@ +open Common_ +open Tiny_httpd_server +module Log = Tiny_httpd_log +module IO = Tiny_httpd_io + +type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit + +module Frame_type = struct + type t = int + + let continuation : t = 0 + let text : t = 1 + let binary : t = 2 + let close : t = 8 + let ping : t = 9 + let pong : t = 10 + + let show = function + | 0 -> "continuation" + | 1 -> "text" + | 2 -> "binary" + | 8 -> "close" + | 9 -> "ping" + | 10 -> "pong" + | _ty -> spf "unknown frame type %xd" _ty +end + +module Header = struct + type t = { + mutable fin: bool; + mutable ty: Frame_type.t; + mutable payload_len: int; + mutable mask: bool; + mask_key: bytes; (** len = 4 *) + } + + let create () : t = + { + fin = false; + ty = 0; + payload_len = 0; + mask = false; + mask_key = Bytes.create 4; + } +end + +exception Close_connection +(** Raised to close the connection. *) + +module Writer = struct + type t = { + header: Header.t; + header_buf: bytes; + buf: bytes; (** bufferize writes *) + mutable offset: int; (** number of bytes already in [buf] *) + oc: IO.Output.t; + mutable closed: bool; + mutex: Mutex.t; + } + + let create ?(buf_size = 16 * 1024) ~oc () : t = + { + header = Header.create (); + header_buf = Bytes.create 16; + buf = Bytes.create buf_size; + offset = 0; + oc; + closed = false; + mutex = Mutex.create (); + } + + let[@inline] with_mutex_ (self : t) f = + Mutex.lock self.mutex; + try + let x = f () in + Mutex.unlock self.mutex; + x + with e -> + Mutex.unlock self.mutex; + raise e + + let close self = + if not self.closed then ( + self.closed <- true; + raise Close_connection + ) + + let int_of_bool : bool -> int = Obj.magic + + (** Write the frame header to [self.oc] *) + let write_header_ (self : t) : unit = + let header_len = ref 2 in + let b0 = + Char.chr ((int_of_bool self.header.fin lsl 7) lor self.header.ty) + in + Bytes.unsafe_set self.header_buf 0 b0; + + (* we don't mask *) + let payload_len = self.header.payload_len in + let payload_first_byte = + if payload_len < 126 then + payload_len + else if payload_len < 1 lsl 16 then ( + Bytes.set_int16_be self.header_buf 2 payload_len; + header_len := 4; + 126 + ) else ( + Bytes.set_int64_be self.header_buf 2 (Int64.of_int payload_len); + header_len := 10; + 127 + ) + in + + let b1 = + Char.chr @@ ((int_of_bool self.header.mask lsl 7) lor payload_first_byte) + in + Bytes.unsafe_set self.header_buf 1 b1; + + if self.header.mask then ( + Bytes.blit self.header_buf !header_len self.header.mask_key 0 4; + header_len := !header_len + 4 + ); + + (*Log.debug (fun k -> + k "websocket: write header ty=%s (%d B)" + (Frame_type.show self.header.ty) + !header_len);*) + IO.Output.output self.oc self.header_buf 0 !header_len; + () + + (** Max fragment size: send 16 kB at a time *) + let max_fragment_size = 16 * 1024 + + let[@inline never] really_output_buf_ (self : t) = + self.header.fin <- true; + self.header.ty <- Frame_type.binary; + self.header.payload_len <- self.offset; + self.header.mask <- false; + write_header_ self; + + IO.Output.output self.oc self.buf 0 self.offset; + self.offset <- 0 + + let flush_ (self : t) = + if self.closed then raise Close_connection; + if self.offset > 0 then really_output_buf_ self + + let[@inline] flush_if_full (self : t) : unit = + if self.offset = Bytes.length self.buf then really_output_buf_ self + + let send_pong (self : t) : unit = + let@ () = with_mutex_ self in + self.header.fin <- true; + self.header.ty <- Frame_type.pong; + self.header.payload_len <- 0; + self.header.mask <- false; + (* only write a header, we don't send a payload at all *) + write_header_ self + + let output_char (self : t) c : unit = + let@ () = with_mutex_ self in + let cap = Bytes.length self.buf - self.offset in + (* make room for [c] *) + if cap = 0 then really_output_buf_ self; + Bytes.set self.buf self.offset c; + self.offset <- self.offset + 1; + (* if [c] made the buffer full, then flush it *) + if cap = 1 then really_output_buf_ self + + let output (self : t) buf i len : unit = + let@ () = with_mutex_ self in + let i = ref i in + let len = ref len in + while !len > 0 do + flush_if_full self; + + let n = min !len (Bytes.length self.buf - self.offset) in + assert (n > 0); + + Bytes.blit buf !i self.buf self.offset n; + self.offset <- self.offset + n; + + i := !i + n; + len := !len - n + done; + flush_if_full self + + let flush self : unit = + let@ () = with_mutex_ self in + flush_ self +end + +module Reader = struct + type state = + | Begin (** At the beginning of a frame *) + | Reading_frame of { mutable remaining_bytes: int } + (** Currently reading the payload of a frame with [remaining_bytes] left to read *) + | Close + + type t = { + ic: IO.Input.t; + writer: Writer.t; (** Writer, to send "pong" *) + header_buf: bytes; (** small buffer to read frame headers *) + small_buf: bytes; (** Used for control frames *) + header: Header.t; + last_ty: Frame_type.t; (** Last frame's type, used for continuation *) + mutable state: state; + } + + let create ~ic ~(writer : Writer.t) () : t = + { + ic; + header_buf = Bytes.create 8; + small_buf = Bytes.create 128; + writer; + state = Begin; + last_ty = 0; + header = Header.create (); + } + + (** limitation: we only accept frames that are 2^30 bytes long or less *) + let max_fragment_size = 1 lsl 30 + + (** Read next frame header into [self.header] *) + let read_frame_header (self : t) : unit = + (* read header *) + IO.Input.really_input self.ic self.header_buf 0 2; + + let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in + let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in + + self.header.fin <- b0 land 1 == 1; + let ext = (b0 lsr 4) land 0b0111 in + if ext <> 0 then ( + Log.error (fun k -> k "websocket: unknown extension %d, closing" ext); + raise Close_connection + ); + + self.header.ty <- b0 land 0b0000_1111; + self.header.mask <- b1 land 0b1000_0000 != 0; + + let payload_len : int = + let len = b1 land 0b0111_1111 in + if len = 126 then ( + IO.Input.really_input self.ic self.header_buf 0 2; + Bytes.get_int16_be self.header_buf 0 + ) else if len = 127 then ( + IO.Input.really_input self.ic self.header_buf 0 8; + let len64 = Bytes.get_int64_be self.header_buf 0 in + if compare len64 (Int64.of_int max_fragment_size) > 0 then ( + Log.error (fun k -> + k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64 + max_fragment_size); + raise Close_connection + ); + + Int64.to_int len64 + ) else + len + in + self.header.payload_len <- payload_len; + + if self.header.mask then + IO.Input.really_input self.ic self.header.mask_key 0 4; + + (*Log.debug (fun k -> + k "websocket: read frame header type=%s payload_len=%d mask=%b" + (Frame_type.show self.header.ty) + self.header.payload_len self.header.mask);*) + () + + external apply_masking_ : bytes -> bytes -> int -> int -> unit + = "tiny_httpd_ws_apply_masking" + [@@noalloc] + (** Apply masking to the parsed data *) + + let[@inline] apply_masking ~mask_key (buf : bytes) off len : unit = + assert (off >= 0 && off + len <= Bytes.length buf); + apply_masking_ mask_key buf off len + + let read_body_to_string (self : t) : string = + let len = self.header.payload_len in + let buf = Bytes.create len in + IO.Input.really_input self.ic buf 0 len; + if self.header.mask then + apply_masking ~mask_key:self.header.mask_key buf 0 len; + Bytes.unsafe_to_string buf + + (** Skip bytes of the body *) + let skip_body (self : t) : unit = + let len = ref self.header.payload_len in + while !len > 0 do + let n = min !len (Bytes.length self.small_buf) in + IO.Input.really_input self.ic self.small_buf 0 n; + len := !len - n + done + + (** State machine that reads [len] bytes into [buf] *) + let rec read_rec (self : t) buf i len : int = + match self.state with + | Close -> 0 + | Reading_frame r -> + let len = min len r.remaining_bytes in + let n = IO.Input.input self.ic buf i len in + + (* update state *) + r.remaining_bytes <- r.remaining_bytes - n; + if r.remaining_bytes = 0 then self.state <- Begin; + + if self.header.mask then + apply_masking ~mask_key:self.header.mask_key buf i n + else ( + Log.error (fun k -> k "websocket: client's frames must be masked"); + raise Close_connection + ); + n + | Begin -> + read_frame_header self; + (*Log.debug (fun k -> + k "websocket: read frame of type=%s payload_len=%d" + (Frame_type.show self.header.ty) + self.header.payload_len);*) + (match self.header.ty with + | 0 -> + (* continuation *) + if self.last_ty = 1 || self.last_ty = 2 then + self.state <- + Reading_frame { remaining_bytes = self.header.payload_len } + else ( + Log.error (fun k -> + k "continuation frame coming after frame of type %s" + (Frame_type.show self.last_ty)); + raise Close_connection + ); + read_rec self buf i len + | 1 -> + self.state <- + Reading_frame { remaining_bytes = self.header.payload_len }; + read_rec self buf i len + | 2 -> + self.state <- + Reading_frame { remaining_bytes = self.header.payload_len }; + read_rec self buf i len + | 8 -> + (* close frame *) + self.state <- Close; + let body = read_body_to_string self in + if String.length body >= 2 then ( + let errcode = Bytes.get_int16_be (Bytes.unsafe_of_string body) 0 in + Log.info (fun k -> + k "client send 'close' with errcode=%d, message=%S" errcode + (String.sub body 2 (String.length body - 2))) + ); + 0 + | 9 -> + (* pong, just ignore *) + skip_body self; + Writer.send_pong self.writer; + read_rec self buf i len + | 10 -> + (* pong, just ignore *) + skip_body self; + read_rec self buf i len + | ty -> + Log.error (fun k -> k "unknown frame type: %xd" ty); + raise Close_connection) + + let read self buf i len = + try read_rec self buf i len + with Close_connection -> + self.state <- Close; + 0 + + let close self : unit = + if self.state != Close then ( + Log.debug (fun k -> k "websocket: close connection from server side"); + self.state <- Close + ) +end + +let upgrade ic oc : _ * _ = + let writer = Writer.create ~oc () in + let reader = Reader.create ~ic ~writer () in + let ws_ic : IO.Input.t = + { + input = (fun buf i len -> Reader.read reader buf i len); + close = (fun () -> Reader.close reader); + } + in + let ws_oc : IO.Output.t = + { + flush = + (fun () -> + Writer.flush writer; + IO.Output.flush oc); + output_char = Writer.output_char writer; + output = Writer.output writer; + close = (fun () -> Writer.close writer); + } + in + ws_ic, ws_oc + +(** Turn a regular connection handler (provided by the user) into a websocket upgrade handler *) +module Make_upgrade_handler (X : sig + val accept_ws_protocol : string -> bool + val handler : handler +end) : UPGRADE_HANDLER = struct + type handshake_state = unit + + let name = "websocket" + + open struct + exception Bad_req of string + + let bad_req msg = raise (Bad_req msg) + let bad_reqf fmt = Printf.ksprintf bad_req fmt + end + + let handshake_ (req : unit Request.t) = + (match Request.get_header req "sec-websocket-protocol" with + | None -> () + | Some proto when not (X.accept_ws_protocol proto) -> + bad_reqf "handler rejected websocket protocol %S" proto + | Some _proto -> ()); + let key = + match Request.get_header req "sec-websocket-key" with + | None -> bad_req "need sec-websocket-key" + | Some k -> k + in + + (* TODO: "origin" header *) + + (* produce the accept key *) + let accept = + (* yes, SHA1 is broken. It's also part of the spec for websockets. *) + Utils_.sha_1 (key ^ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + |> Utils_.B64.encode ~url:false + in + + let headers = [ "sec-websocket-accept", accept ] in + Log.debug (fun k -> + k "websocket: upgrade successful, accept key is %S" accept); + headers, () + + let handshake req : _ result = + try Ok (handshake_ req) with Bad_req s -> Error s + + let handle_connection addr () ic oc = + let ws_ic, ws_oc = upgrade ic oc in + try X.handler addr ws_ic ws_oc + with Close_connection -> + Log.debug (fun k -> k "websocket: requested to close the connection"); + () +end + +let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true) + (server : Tiny_httpd_server.t) route (f : handler) : unit = + let module M = Make_upgrade_handler (struct + let handler = f + let accept_ws_protocol = accept_ws_protocol + end) in + let up : upgrade_handler = (module M) in + Tiny_httpd_server.add_upgrade_handler ?accept server route up diff --git a/src/ws/tiny_httpd_ws.mli b/src/ws/tiny_httpd_ws.mli new file mode 100644 index 00000000..b3440559 --- /dev/null +++ b/src/ws/tiny_httpd_ws.mli @@ -0,0 +1,19 @@ +open Tiny_httpd_server +module IO = Tiny_httpd_io + +type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit +(** Websocket handler *) + +val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t +(** Upgrade a byte stream to the websocket framing protocol. *) + +val add_route_handler : + ?accept:(unit Request.t -> (unit, int * string) result) -> + ?accept_ws_protocol:(string -> bool) -> + Tiny_httpd_server.t -> + (upgrade_handler, upgrade_handler) Route.t -> + handler -> + unit +(** Add a route handler for a websocket endpoint. + @param accept_ws_protocol decides whether this endpoint accepts the websocket protocol + sent by the client. Default accepts everything. *) diff --git a/src/ws/tiny_httpd_ws_stubs.c b/src/ws/tiny_httpd_ws_stubs.c new file mode 100644 index 00000000..779e255a --- /dev/null +++ b/src/ws/tiny_httpd_ws_stubs.c @@ -0,0 +1,21 @@ + +#include +#include +#include + +CAMLprim value tiny_httpd_ws_apply_masking(value _mask_key, value _buf, + value _offset, value _len) { + CAMLparam4(_mask_key, _buf, _offset, _len); + + char const *mask_key = String_val(_mask_key); + char *buf = Bytes_val(_buf); + intnat offset = Int_val(_offset); + intnat len = Int_val(_len); + + for (intnat i = 0; i < len; ++i) { + char c = buf[offset + i]; + char c_m = mask_key[i & 0x3]; + buf[offset + i] = c ^ c_m; + } + CAMLreturn(Val_unit); +} diff --git a/src/ws/utils_.ml b/src/ws/utils_.ml new file mode 100644 index 00000000..0ac04652 --- /dev/null +++ b/src/ws/utils_.ml @@ -0,0 +1,198 @@ +(* To keep the library lightweight, we vendor base64 and sha1 + from Daniel Bünzli's excellent libraries. Both of these functions + are used only for the websocket handshake, on tiny data + (one header's worth). + + vendored from https://github.com/dbuenzli/uuidm + and https://github.com/dbuenzli/webs . *) + +module B64 = struct + let alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + + let alpha_url = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + + let encode ~url s = + let rec loop alpha len e ei s i = + if i >= len then + Bytes.unsafe_to_string e + else ( + let i0 = i and i1 = i + 1 and i2 = i + 2 in + let b0 = Char.code s.[i0] in + let b1 = + if i1 >= len then + 0 + else + Char.code s.[i1] + in + let b2 = + if i2 >= len then + 0 + else + Char.code s.[i2] + in + let u = (b0 lsl 16) lor (b1 lsl 8) lor b2 in + let c0 = alpha.[u lsr 18] in + let c1 = alpha.[(u lsr 12) land 63] in + let c2 = + if i1 >= len then + '=' + else + alpha.[(u lsr 6) land 63] + in + let c3 = + if i2 >= len then + '=' + else + alpha.[u land 63] + in + Bytes.set e ei c0; + Bytes.set e (ei + 1) c1; + Bytes.set e (ei + 2) c2; + Bytes.set e (ei + 3) c3; + loop alpha len e (ei + 4) s (i2 + 1) + ) + in + match String.length s with + | 0 -> "" + | len -> + let alpha = + if url then + alpha_url + else + alpha + in + loop alpha len (Bytes.create ((len + 2) / 3 * 4)) 0 s 0 +end + +let sha_1 s = + (* Based on pseudo-code of RFC 3174. Slow and ugly but does the job. *) + let sha_1_pad s = + let len = String.length s in + let blen = 8 * len in + let rem = len mod 64 in + let mlen = + if rem > 55 then + len + 128 - rem + else + len + 64 - rem + in + let m = Bytes.create mlen in + Bytes.blit_string s 0 m 0 len; + Bytes.fill m len (mlen - len) '\x00'; + Bytes.set m len '\x80'; + if Sys.word_size > 32 then ( + Bytes.set m (mlen - 8) (Char.unsafe_chr ((blen lsr 56) land 0xFF)); + Bytes.set m (mlen - 7) (Char.unsafe_chr ((blen lsr 48) land 0xFF)); + Bytes.set m (mlen - 6) (Char.unsafe_chr ((blen lsr 40) land 0xFF)); + Bytes.set m (mlen - 5) (Char.unsafe_chr ((blen lsr 32) land 0xFF)) + ); + Bytes.set m (mlen - 4) (Char.unsafe_chr ((blen lsr 24) land 0xFF)); + Bytes.set m (mlen - 3) (Char.unsafe_chr ((blen lsr 16) land 0xFF)); + Bytes.set m (mlen - 2) (Char.unsafe_chr ((blen lsr 8) land 0xFF)); + Bytes.set m (mlen - 1) (Char.unsafe_chr (blen land 0xFF)); + m + in + (* Operations on int32 *) + let ( &&& ) = ( land ) in + let ( lor ) = Int32.logor in + let ( lxor ) = Int32.logxor in + let ( land ) = Int32.logand in + let ( ++ ) = Int32.add in + let lnot = Int32.lognot in + let sr = Int32.shift_right in + let sl = Int32.shift_left in + let cls n x = sl x n lor Int32.shift_right_logical x (32 - n) in + (* Start *) + let m = sha_1_pad s in + let w = Array.make 16 0l in + let h0 = ref 0x67452301l in + let h1 = ref 0xEFCDAB89l in + let h2 = ref 0x98BADCFEl in + let h3 = ref 0x10325476l in + let h4 = ref 0xC3D2E1F0l in + let a = ref 0l in + let b = ref 0l in + let c = ref 0l in + let d = ref 0l in + let e = ref 0l in + for i = 0 to (Bytes.length m / 64) - 1 do + (* For each block *) + (* Fill w *) + let base = i * 64 in + for j = 0 to 15 do + let k = base + (j * 4) in + w.(j) <- + sl (Int32.of_int (Char.code @@ Bytes.get m k)) 24 + lor sl (Int32.of_int (Char.code @@ Bytes.get m (k + 1))) 16 + lor sl (Int32.of_int (Char.code @@ Bytes.get m (k + 2))) 8 + lor Int32.of_int (Char.code @@ Bytes.get m (k + 3)) + done; + (* Loop *) + a := !h0; + b := !h1; + c := !h2; + d := !h3; + e := !h4; + for t = 0 to 79 do + let f, k = + if t <= 19 then + !b land !c lor (lnot !b land !d), 0x5A827999l + else if t <= 39 then + !b lxor !c lxor !d, 0x6ED9EBA1l + else if t <= 59 then + !b land !c lor (!b land !d) lor (!c land !d), 0x8F1BBCDCl + else + !b lxor !c lxor !d, 0xCA62C1D6l + in + let s = t &&& 0xF in + if t >= 16 then + w.(s) <- + cls 1 + (w.(s + 13 &&& 0xF) + lxor w.(s + 8 &&& 0xF) + lxor w.(s + 2 &&& 0xF) + lxor w.(s)); + let temp = cls 5 !a ++ f ++ !e ++ w.(s) ++ k in + e := !d; + d := !c; + c := cls 30 !b; + b := !a; + a := temp + done; + (* Update *) + h0 := !h0 ++ !a; + h1 := !h1 ++ !b; + h2 := !h2 ++ !c; + h3 := !h3 ++ !d; + h4 := !h4 ++ !e + done; + let h = Bytes.create 20 in + let i2s h k i = + Bytes.set h k (Char.unsafe_chr (Int32.to_int (sr i 24) &&& 0xFF)); + Bytes.set h (k + 1) (Char.unsafe_chr (Int32.to_int (sr i 16) &&& 0xFF)); + Bytes.set h (k + 2) (Char.unsafe_chr (Int32.to_int (sr i 8) &&& 0xFF)); + Bytes.set h (k + 3) (Char.unsafe_chr (Int32.to_int i &&& 0xFF)) + in + i2s h 0 !h0; + i2s h 4 !h1; + i2s h 8 !h2; + i2s h 12 !h3; + i2s h 16 !h4; + Bytes.unsafe_to_string h + +(*--------------------------------------------------------------------------- + Copyright (c) 2008 The uuidm programmers + + Permission to use, copy, modify, and/or distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + ---------------------------------------------------------------------------*)