Compare commits

..

5 commits

Author SHA1 Message Date
Simon Cruanes
d9c0f94869
chore: bounds on eio 2025-06-06 22:26:23 -04:00
Simon Cruanes
a98dd9b767
CI 2025-06-06 22:26:23 -04:00
Simon Cruanes
f80df7f6a7
example with eio 2025-06-06 22:26:23 -04:00
Simon Cruanes
d40f87f07b
feat: tiny_httpd_eio library
provides a tiny_httpd server that relies on Eio for non-blocking
sockets and for concurrency using eio fibers.
2025-06-06 22:26:23 -04:00
Simon Cruanes
4b4fd2afe1
format code 2025-06-06 22:26:22 -04:00
23 changed files with 142 additions and 310 deletions

View file

@ -1,28 +0,0 @@
name: format
on:
pull_request:
push:
branches:
- main
jobs:
format:
name: format
strategy:
matrix:
ocaml-compiler:
- '5.3'
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@main
- name: Use OCaml ${{ matrix.ocaml-compiler }}
uses: ocaml/setup-ocaml@v3
with:
ocaml-compiler: ${{ matrix.ocaml-compiler }}
dune-cache: true
allow-prerelease-opam: true
- run: opam install ocamlformat.0.27.0
- run: opam exec -- make format-check

36
.github/workflows/gh-pages.yml vendored Normal file
View file

@ -0,0 +1,36 @@
name: github pages
on:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Use OCaml
uses: ocaml/setup-ocaml@v3
with:
ocaml-compiler: 5.03.x
dune-cache: true
allow-prerelease-opam: true
- name: Deps
run: opam install odig tiny_httpd tiny_httpd_camlzip tiny_httpd_eio
- name: Build
run: opam exec -- odig odoc --cache-dir=_doc/ tiny_httpd tiny_httpd_camlzip tiny_httpd_eio
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./_doc/html
destination_dir: .
enable_jekyll: false
#keep_files: true

View file

@ -16,7 +16,7 @@ jobs:
#- macos-latest #- macos-latest
#- windows-latest #- windows-latest
ocaml-compiler: ocaml-compiler:
- 4.13.x - 4.08.x
- 4.14.x - 4.14.x
- 5.03.x - 5.03.x
@ -38,15 +38,7 @@ jobs:
- run: opam install ./tiny_httpd.opam ./tiny_httpd_camlzip.opam --deps-only --with-test - run: opam install ./tiny_httpd.opam ./tiny_httpd_camlzip.opam --deps-only --with-test
- name: Build (OCaml 4.x) - run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip,tiny_httpd_eio
run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip
if: ${{ !startsWith(matrix.ocaml-compiler, '5.') }}
- name: Build (OCaml 5.x, includes eio)
run: |
opam install ./tiny_httpd.opam ./tiny_httpd_eio.opam --deps-only --with-test
opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip,tiny_httpd_eio
if: ${{ startsWith(matrix.ocaml-compiler, '5.') }}
- run: opam exec -- dune build @src/runtest @examples/runtest @tests/runtest -p tiny_httpd - run: opam exec -- dune build @src/runtest @examples/runtest @tests/runtest -p tiny_httpd
if: ${{ matrix.os == 'ubuntu-latest' }} if: ${{ matrix.os == 'ubuntu-latest' }}
@ -58,10 +50,4 @@ jobs:
- run: opam install logs magic-mime -y - run: opam install logs magic-mime -y
- name: Final build (OCaml 4.x) - run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip,tiny_httpd_eio
run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip
if: ${{ !startsWith(matrix.ocaml-compiler, '5.') }}
- name: Final build (OCaml 5.x, includes eio)
run: opam exec -- dune build @install -p tiny_httpd,tiny_httpd_camlzip,tiny_httpd_eio
if: ${{ startsWith(matrix.ocaml-compiler, '5.') }}

2
.gitignore vendored
View file

@ -3,5 +3,3 @@ _build
_opam _opam
*.install *.install
.merlin .merlin
todo.md
*.tmp

View file

@ -23,12 +23,12 @@
result result
hmap hmap
(iostream (>= 0.2)) (iostream (>= 0.2))
(ocaml (>= 4.13)) (ocaml (>= 4.08))
(odoc :with-doc) (odoc :with-doc)
(logs :with-test) (logs :with-test)
(conf-libcurl :with-test) (conf-libcurl :with-test)
(ptime :with-test) (ptime :with-test)
(qcheck-core (and (>= 0.91) :with-test)))) (qcheck-core (and (>= 0.9) :with-test))))
(package (package
(name tiny_httpd_camlzip) (name tiny_httpd_camlzip)
@ -46,6 +46,5 @@
(depends (depends
(tiny_httpd (= :version)) (tiny_httpd (= :version))
(eio (and (>= 1.0) (< 2.0))) (eio (and (>= 1.0) (< 2.0)))
base-unix
(logs :with-test) (logs :with-test)
(odoc :with-doc))) (odoc :with-doc)))

View file

@ -73,10 +73,10 @@
; produce an embedded FS ; produce an embedded FS
(library (library
(name echo_vfs) (name echo_vfs)
(modules vfs) (modules vfs)
(wrapped false) (wrapped false)
(libraries tiny_httpd)) (libraries tiny_httpd))
(rule (rule
(targets vfs.ml) (targets vfs.ml)

View file

@ -69,10 +69,12 @@ let middleware_stat () : Server.Middleware.t * (unit -> string) =
let middleware_trace : Server.Middleware.t = let middleware_trace : Server.Middleware.t =
fun (h : Server.Middleware.handler) req ~resp -> fun (h : Server.Middleware.handler) req ~resp ->
let _sp = Trace.enter_span ~__FILE__ ~__LINE__ "http.handle" in let _sp =
Trace.enter_manual_toplevel_span ~__FILE__ ~__LINE__ "http.handle"
in
let new_resp (r : Response.t) = let new_resp (r : Response.t) =
Trace.add_data_to_span _sp [ "http.code", `Int r.code ]; Trace.add_data_to_manual_span _sp [ "http.code", `Int r.code ];
Trace.exit_span _sp; Trace.exit_manual_span _sp;
resp r resp r
in in
h req ~resp:new_resp h req ~resp:new_resp

View file

@ -83,13 +83,8 @@ let parse_line_ (line : string) : _ result =
Ok (k, v) Ok (k, v)
with Failure msg -> Error msg with Failure msg -> Error msg
let parse_ ~(buf : Buf.t) ?(max_headers = 100) ?(max_header_size = 16 * 1024) let parse_ ~(buf : Buf.t) (bs : IO.Input.t) : t =
?(max_total_size = 256 * 1024) (bs : IO.Input.t) : t = let rec loop acc =
let rec loop acc count total_size =
if count >= max_headers then
bad_reqf 431 "too many headers (max: %d)" max_headers;
if total_size >= max_total_size then
bad_reqf 431 "headers too large (max: %d bytes)" max_total_size;
match IO.Input.read_line_using_opt ~buf bs with match IO.Input.read_line_using_opt ~buf bs with
| None -> raise End_of_file | None -> raise End_of_file
| Some "" -> assert false | Some "" -> assert false
@ -97,15 +92,12 @@ let parse_ ~(buf : Buf.t) ?(max_headers = 100) ?(max_header_size = 16 * 1024)
| Some line when line.[String.length line - 1] <> '\r' -> | Some line when line.[String.length line - 1] <> '\r' ->
bad_reqf 400 "bad header line, not ended in CRLF" bad_reqf 400 "bad header line, not ended in CRLF"
| Some line -> | Some line ->
let line_len = String.length line in
if line_len > max_header_size then
bad_reqf 431 "header too large (max: %d bytes)" max_header_size;
let k, v = let k, v =
match parse_line_ line with match parse_line_ line with
| Ok r -> r | Ok r -> r
| Error msg -> | Error msg ->
bad_reqf 400 "invalid header line: %s\nline is: %S" msg line bad_reqf 400 "invalid header line: %s\nline is: %S" msg line
in in
loop ((k, v) :: acc) (count + 1) (total_size + line_len) loop ((k, v) :: acc)
in in
loop [] 0 0 loop []

View file

@ -34,14 +34,7 @@ val pp : Format.formatter -> t -> unit
(**/*) (**/*)
val parse_ : val parse_ : buf:Buf.t -> IO.Input.t -> t
buf:Buf.t ->
?max_headers:int ->
?max_header_size:int ->
?max_total_size:int ->
IO.Input.t ->
t
val parse_line_ : string -> (string * string, string) result val parse_line_ : string -> (string * string, string) result
(**/*) (**/*)

View file

@ -25,7 +25,6 @@ let descr = function
| 411 -> "Length required" | 411 -> "Length required"
| 413 -> "Payload too large" | 413 -> "Payload too large"
| 417 -> "Expectation failed" | 417 -> "Expectation failed"
| 431 -> "Request Header Fields Too Large"
| 500 -> "Internal server error" | 500 -> "Internal server error"
| 501 -> "Not implemented" | 501 -> "Not implemented"
| 503 -> "Service unavailable" | 503 -> "Service unavailable"

View file

@ -55,9 +55,6 @@ val to_string : _ t -> string
val to_url : ('a, string) t -> 'a val to_url : ('a, string) t -> 'a
(** [to_url route args] takes a route, and turns it into a URL path.
@since NEXT_RELEASE *)
module Private_ : sig module Private_ : sig
val eval : string list -> ('a, 'b) t -> 'a -> 'b option val eval : string list -> ('a, 'b) t -> 'a -> 'b option
end end

View file

@ -1,6 +1,8 @@
(library (library
(name tiny_httpd_eio) (name tiny_httpd_eio)
(public_name tiny_httpd_eio) (public_name tiny_httpd_eio)
(synopsis "An EIO-based backend for Tiny_httpd") (synopsis "An EIO-based backend for Tiny_httpd")
(flags :standard -safe-string -warn-error -a+8) (flags :standard -safe-string -warn-error -a+8)
(libraries tiny_httpd eio eio.unix)) (libraries tiny_httpd eio eio.unix))

View file

@ -31,10 +31,8 @@ let eio_sock_addr_to_unix (a : Eio.Net.Sockaddr.stream) : Unix.sockaddr =
| `Tcp (h, p) -> Unix.ADDR_INET (eio_ipaddr_to_unix h, p) | `Tcp (h, p) -> Unix.ADDR_INET (eio_ipaddr_to_unix h, p)
| `Unix s -> Unix.ADDR_UNIX s | `Unix s -> Unix.ADDR_UNIX s
let ic_of_flow ~closed ~buf_pool:ic_pool (flow : _ Eio.Net.stream_socket) : let ic_of_flow ~buf_pool:ic_pool (flow : _ Eio.Net.stream_socket) : IO.Input.t =
IO.Input.t =
let cstruct = Pool.Raw.acquire ic_pool in let cstruct = Pool.Raw.acquire ic_pool in
let sent_shutdown = ref false in
object object
inherit Iostream.In_buf.t_from_refill () inherit Iostream.In_buf.t_from_refill ()
@ -54,22 +52,15 @@ let ic_of_flow ~closed ~buf_pool:ic_pool (flow : _ Eio.Net.stream_socket) :
sl.len <- n sl.len <- n
method close () = method close () =
if not !closed then ( Pool.Raw.release ic_pool cstruct;
closed := true; Eio.Flow.shutdown flow `Receive
Pool.Raw.release ic_pool cstruct
);
if not !sent_shutdown then (
sent_shutdown := true;
Eio.Flow.shutdown flow `Receive
)
end end
let oc_of_flow ~closed ~buf_pool:oc_pool (flow : _ Eio.Net.stream_socket) : let oc_of_flow ~buf_pool:oc_pool (flow : _ Eio.Net.stream_socket) : IO.Output.t
IO.Output.t = =
(* write buffer *) (* write buffer *)
let wbuf : Cstruct.t = Pool.Raw.acquire oc_pool in let wbuf : Cstruct.t = Pool.Raw.acquire oc_pool in
let offset = ref 0 in let offset = ref 0 in
let sent_shutdown = ref false in
object (self) object (self)
method flush () : unit = method flush () : unit =
@ -100,14 +91,8 @@ let oc_of_flow ~closed ~buf_pool:oc_pool (flow : _ Eio.Net.stream_socket) :
if !offset = Cstruct.length wbuf then self#flush () if !offset = Cstruct.length wbuf then self#flush ()
method close () = method close () =
if not !closed then ( Pool.Raw.release oc_pool wbuf;
closed := true; Eio.Flow.shutdown flow `Send
Pool.Raw.release oc_pool wbuf
);
if not !sent_shutdown then (
sent_shutdown := true;
Eio.Flow.shutdown flow `Send
)
end end
let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size
@ -133,8 +118,7 @@ let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size
let module M = struct let module M = struct
let init_addr () = addr let init_addr () = addr
let init_port () = port let init_port () = port
let clock = Eio.Stdenv.clock stdenv let get_time_s () = Unix.gettimeofday ()
let get_time_s () = Eio.Time.now clock
let max_connections = get_max_connection_ ?max_connections () let max_connections = get_max_connection_ ?max_connections ()
let pool_size = let pool_size =
@ -143,7 +127,7 @@ let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size
| None -> min 4096 (max_connections * 2) | None -> min 4096 (max_connections * 2)
let cstruct_pool = let cstruct_pool =
Pool.create ~max_size:pool_size Pool.create ~max_size:max_connections
~mk_item:(fun () -> Cstruct.create buf_size) ~mk_item:(fun () -> Cstruct.create buf_size)
() ()
@ -153,7 +137,6 @@ let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size
(fun ~after_init ~handle () : unit -> (fun ~after_init ~handle () : unit ->
let running = Atomic.make true in let running = Atomic.make true in
let active_conns = Atomic.make 0 in let active_conns = Atomic.make 0 in
let sem = Eio.Semaphore.make max_connections in
Eio.Switch.on_release sw (fun () -> Atomic.set running false); Eio.Switch.on_release sw (fun () -> Atomic.set running false);
let net = Eio.Stdenv.net stdenv in let net = Eio.Stdenv.net stdenv in
@ -165,26 +148,17 @@ let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size
sockaddr sockaddr
in in
(* Resolve actual address/port (important for port 0) *)
let actual_addr, actual_port =
match Eio.Net.listening_addr sock with
| `Tcp (_, p) -> addr, p
| `Unix s -> Printf.sprintf "unix:%s" s, 0
in
let tcp_server : IO.TCP_server.t = let tcp_server : IO.TCP_server.t =
{ {
running = (fun () -> Atomic.get running); running = (fun () -> Atomic.get running);
stop = stop =
(fun () -> (fun () ->
Atomic.set running false; Atomic.set running false;
(* Backstop: fail the switch after 60s if handlers don't complete *) Eio.Switch.fail sw Exit);
Eio.Fiber.fork_daemon ~sw (fun () -> endpoint =
Eio.Time.sleep clock 60.0; (fun () ->
if Eio.Switch.get_error sw |> Option.is_none then (* TODO: find the real port *)
Eio.Switch.fail sw Exit; addr, port);
`Stop_daemon));
endpoint = (fun () -> actual_addr, actual_port);
active_connections = (fun () -> Atomic.get active_conns); active_connections = (fun () -> Atomic.get active_conns);
} }
in in
@ -192,50 +166,33 @@ let io_backend ?addr ?port ?unix_sock ?max_connections ?max_buf_pool_size
after_init tcp_server; after_init tcp_server;
while Atomic.get running do while Atomic.get running do
match Eio.Net.accept ~sw sock with Eio.Net.accept_fork ~sw
| exception (Eio.Cancel.Cancelled _ | Eio.Io _) ~on_error:(fun exn ->
when not (Atomic.get running) -> Log.error (fun k ->
(* Socket closed or switch cancelled during shutdown; exit loop *) k "error in client handler: %s" (Printexc.to_string exn)))
() sock
| conn, client_addr -> (fun flow client_addr ->
(* Acquire semaphore BEFORE spawning a fiber so we Atomic.incr active_conns;
bound the number of in-flight fibers. *) let@ () =
Eio.Semaphore.acquire sem; Fun.protect ~finally:(fun () ->
Eio.Fiber.fork ~sw (fun () -> Log.debug (fun k ->
let@ () = k "Tiny_httpd_eio: client handler returned");
Fun.protect ~finally:(fun () -> Atomic.decr active_conns)
Log.debug (fun k -> in
k "Tiny_httpd_eio: client handler returned"); let ic = ic_of_flow ~buf_pool:cstruct_pool flow in
Atomic.decr active_conns; let oc = oc_of_flow ~buf_pool:cstruct_pool flow in
Eio.Semaphore.release sem;
try Eio.Flow.close conn with Eio.Io _ -> ())
in
(try
Eio_unix.Fd.use_exn "setsockopt" (Eio_unix.Net.fd conn)
(fun fd -> Unix.setsockopt fd Unix.TCP_NODELAY true)
with Unix.Unix_error _ -> ());
Atomic.incr active_conns;
let ic_closed = ref false in
let oc_closed = ref false in
let ic =
ic_of_flow ~closed:ic_closed ~buf_pool:cstruct_pool conn
in
let oc =
oc_of_flow ~closed:oc_closed ~buf_pool:cstruct_pool conn
in
Log.debug (fun k -> Log.debug (fun k ->
k "handling client on %a…" Eio.Net.Sockaddr.pp k "handling client on %a…" Eio.Net.Sockaddr.pp client_addr);
client_addr); let client_addr_unix = eio_sock_addr_to_unix client_addr in
let client_addr_unix = eio_sock_addr_to_unix client_addr in try handle.handle ~client_addr:client_addr_unix ic oc
try handle.handle ~client_addr:client_addr_unix ic oc with exn ->
with exn -> let bt = Printexc.get_raw_backtrace () in
let bt = Printexc.get_raw_backtrace () in Log.error (fun k ->
Log.error (fun k -> k "Client handler for %a failed with %s\n%s"
k "Client handler for %a failed with %s\n%s" Eio.Net.Sockaddr.pp client_addr
Eio.Net.Sockaddr.pp client_addr (Printexc.to_string exn)
(Printexc.to_string exn) (Printexc.raw_backtrace_to_string bt)))
(Printexc.raw_backtrace_to_string bt)))
done); done);
} }
end in end in

View file

@ -43,27 +43,6 @@ let contains_dot_dot s =
false false
with Exit -> true with Exit -> true
(* Check if string [s] starts with prefix [pre] *)
let string_prefix ~pre s =
let len_pre = String.length pre in
String.length s >= len_pre && String.sub s 0 len_pre = pre
(* Check if a path is safe (doesn't escape root directory).
Only needed for real filesystem access. *)
let is_path_safe ~root_canonical ~path =
try
let full_path = Filename.concat root_canonical path in
let path_canonical = Unix.realpath full_path in
string_prefix ~pre:root_canonical path_canonical
with Unix.Unix_error _ ->
(* If realpath fails (e.g., file doesn't exist for uploads),
check parent directory *)
(try
let parent = Filename.dirname (Filename.concat root_canonical path) in
let parent_canonical = Unix.realpath parent in
string_prefix ~pre:root_canonical parent_canonical
with Unix.Unix_error _ -> false)
(* Human readable size *) (* Human readable size *)
let human_size (x : int) : string = let human_size (x : int) : string =
if x >= 1_000_000_000 then if x >= 1_000_000_000 then
@ -227,17 +206,6 @@ let html_list_dir (module VFS : VFS) ~prefix ~parent d : Html.elt =
(* @param on_fs: if true, we assume the file exists on the FS *) (* @param on_fs: if true, we assume the file exists on the FS *)
let add_vfs_ ~on_fs ~top ~config ~vfs:((module VFS : VFS) as vfs) ~prefix server let add_vfs_ ~on_fs ~top ~config ~vfs:((module VFS : VFS) as vfs) ~prefix server
: unit = : unit =
let root_canonical =
if on_fs then (
try Some (Unix.realpath top) with _ -> None
) else
None
in
let check_path path =
match root_canonical with
| Some root -> is_path_safe ~root_canonical:root ~path
| None -> not (contains_dot_dot path)
in
let route () = let route () =
if prefix = "" then if prefix = "" then
Route.rest_of_path_urlencoded Route.rest_of_path_urlencoded
@ -246,7 +214,7 @@ let add_vfs_ ~on_fs ~top ~config ~vfs:((module VFS : VFS) as vfs) ~prefix server
in in
if config.delete then if config.delete then
S.add_route_handler server ~meth:`DELETE (route ()) (fun path _req -> S.add_route_handler server ~meth:`DELETE (route ()) (fun path _req ->
if not (check_path path) then if contains_dot_dot path then
Response.fail_raise ~code:403 "invalid path in delete" Response.fail_raise ~code:403 "invalid path in delete"
else else
Response.make_string Response.make_string
@ -265,7 +233,7 @@ let add_vfs_ ~on_fs ~top ~config ~vfs:((module VFS : VFS) as vfs) ~prefix server
| Some n when n > config.max_upload_size -> | Some n when n > config.max_upload_size ->
Error Error
(403, "max upload size is " ^ string_of_int config.max_upload_size) (403, "max upload size is " ^ string_of_int config.max_upload_size)
| Some _ when not (check_path req.Request.path) -> | Some _ when contains_dot_dot req.Request.path ->
Error (403, "invalid path (contains '..')") Error (403, "invalid path (contains '..')")
| _ -> Ok ()) | _ -> Ok ())
(fun path req -> (fun path req ->
@ -296,7 +264,7 @@ let add_vfs_ ~on_fs ~top ~config ~vfs:((module VFS : VFS) as vfs) ~prefix server
| None -> Response.fail_raise ~code:403 "Cannot access file" | None -> Response.fail_raise ~code:403 "Cannot access file"
| Some t -> Printf.sprintf "mtime: %.4f" t) | Some t -> Printf.sprintf "mtime: %.4f" t)
in in
if not (check_path path) then if contains_dot_dot path then
Response.fail ~code:403 "Path is forbidden" Response.fail ~code:403 "Path is forbidden"
else if not (VFS.contains path) then else if not (VFS.contains path) then
Response.fail ~code:404 "File not found" Response.fail ~code:404 "File not found"

View file

@ -51,7 +51,7 @@
(public_name tiny_httpd.ws) (public_name tiny_httpd.ws)
(synopsis "Websockets for tiny_httpd") (synopsis "Websockets for tiny_httpd")
(private_modules common_ws_ utils_) (private_modules common_ws_ utils_)
(flags :standard -w -32 -open Tiny_httpd_core) (flags :standard -open Tiny_httpd_core)
(foreign_stubs (foreign_stubs
(language c) (language c)
(names tiny_httpd_ws_stubs) (names tiny_httpd_ws_stubs)

View file

@ -1,36 +1,15 @@
open Common_ws_ open Common_ws_
module With_lock = struct
type t = { with_lock: 'a. (unit -> 'a) -> 'a }
type builder = unit -> t
let default_builder : builder =
fun () ->
let mutex = Mutex.create () in
{
with_lock =
(fun f ->
Mutex.lock mutex;
try
let x = f () in
Mutex.unlock mutex;
x
with e ->
Mutex.unlock mutex;
raise e);
}
end
type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit
module Frame_type = struct module Frame_type = struct
type t = int type t = int
let _continuation : t = 0 let continuation : t = 0
let _text : t = 1 let text : t = 1
let binary : t = 2 let binary : t = 2
let _close : t = 8 let close : t = 8
let _ping : t = 9 let ping : t = 9
let pong : t = 10 let pong : t = 10
let show = function let show = function
@ -73,10 +52,10 @@ module Writer = struct
mutable offset: int; (** number of bytes already in [buf] *) mutable offset: int; (** number of bytes already in [buf] *)
oc: IO.Output.t; oc: IO.Output.t;
mutable closed: bool; mutable closed: bool;
mutex: With_lock.t; mutex: Mutex.t;
} }
let create ?(buf_size = 16 * 1024) ~with_lock ~oc () : t = let create ?(buf_size = 16 * 1024) ~oc () : t =
{ {
header = Header.create (); header = Header.create ();
header_buf = Bytes.create 16; header_buf = Bytes.create 16;
@ -84,9 +63,19 @@ module Writer = struct
offset = 0; offset = 0;
oc; oc;
closed = false; closed = false;
mutex = with_lock; mutex = Mutex.create ();
} }
let[@inline] with_mutex_ (self : t) f =
Mutex.lock self.mutex;
try
let x = f () in
Mutex.unlock self.mutex;
x
with e ->
Mutex.unlock self.mutex;
raise e
let[@inline] close self = self.closed <- true let[@inline] close self = self.closed <- true
let int_of_bool : bool -> int = Obj.magic let int_of_bool : bool -> int = Obj.magic
@ -132,7 +121,7 @@ module Writer = struct
() ()
(** Max fragment size: send 16 kB at a time *) (** Max fragment size: send 16 kB at a time *)
let _max_fragment_size = 16 * 1024 let max_fragment_size = 16 * 1024
let[@inline never] really_output_buf_ (self : t) = let[@inline never] really_output_buf_ (self : t) =
self.header.fin <- true; self.header.fin <- true;
@ -153,7 +142,7 @@ module Writer = struct
if self.offset = Bytes.length self.buf then really_output_buf_ self if self.offset = Bytes.length self.buf then really_output_buf_ self
let send_pong (self : t) : unit = let send_pong (self : t) : unit =
let@ () = self.mutex.with_lock in let@ () = with_mutex_ self in
self.header.fin <- true; self.header.fin <- true;
self.header.ty <- Frame_type.pong; self.header.ty <- Frame_type.pong;
self.header.payload_len <- 0; self.header.payload_len <- 0;
@ -162,7 +151,7 @@ module Writer = struct
write_header_ self write_header_ self
let output_char (self : t) c : unit = let output_char (self : t) c : unit =
let@ () = self.mutex.with_lock in let@ () = with_mutex_ self in
let cap = Bytes.length self.buf - self.offset in let cap = Bytes.length self.buf - self.offset in
(* make room for [c] *) (* make room for [c] *)
if cap = 0 then really_output_buf_ self; if cap = 0 then really_output_buf_ self;
@ -172,7 +161,7 @@ module Writer = struct
if cap = 1 then really_output_buf_ self if cap = 1 then really_output_buf_ self
let output (self : t) buf i len : unit = let output (self : t) buf i len : unit =
let@ () = self.mutex.with_lock in let@ () = with_mutex_ self in
let i = ref i in let i = ref i in
let len = ref len in let len = ref len in
while !len > 0 do while !len > 0 do
@ -190,7 +179,7 @@ module Writer = struct
flush_if_full self flush_if_full self
let flush self : unit = let flush self : unit =
let@ () = self.mutex.with_lock in let@ () = with_mutex_ self in
flush_ self flush_ self
end end
@ -401,8 +390,8 @@ module Reader = struct
) )
end end
let upgrade ?(with_lock = With_lock.default_builder ()) ic oc : _ * _ = let upgrade ic oc : _ * _ =
let writer = Writer.create ~with_lock ~oc () in let writer = Writer.create ~oc () in
let reader = Reader.create ~ic ~writer () in let reader = Reader.create ~ic ~writer () in
let ws_ic : IO.Input.t = let ws_ic : IO.Input.t =
object object
@ -429,7 +418,6 @@ let upgrade ?(with_lock = With_lock.default_builder ()) ic oc : _ * _ =
upgrade handler *) upgrade handler *)
module Make_upgrade_handler (X : sig module Make_upgrade_handler (X : sig
val accept_ws_protocol : string -> bool val accept_ws_protocol : string -> bool
val with_lock : With_lock.builder
val handler : handler val handler : handler
end) : Server.UPGRADE_HANDLER with type handshake_state = unit Request.t = end) : Server.UPGRADE_HANDLER with type handshake_state = unit Request.t =
struct struct
@ -474,8 +462,7 @@ struct
try Ok (handshake_ req) with Bad_req s -> Error s try Ok (handshake_ req) with Bad_req s -> Error s
let handle_connection req ic oc = let handle_connection req ic oc =
let with_lock = X.with_lock () in let ws_ic, ws_oc = upgrade ic oc in
let ws_ic, ws_oc = upgrade ~with_lock ic oc in
try X.handler req ws_ic ws_oc try X.handler req ws_ic ws_oc
with Close_connection -> with Close_connection ->
Log.debug (fun k -> k "websocket: requested to close the connection"); Log.debug (fun k -> k "websocket: requested to close the connection");
@ -483,11 +470,9 @@ struct
end end
let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true) ?middlewares let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true) ?middlewares
?(with_lock = With_lock.default_builder) (server : Server.t) route (server : Server.t) route (f : handler) : unit =
(f : handler) : unit =
let module M = Make_upgrade_handler (struct let module M = Make_upgrade_handler (struct
let handler = f let handler = f
let with_lock = with_lock
let accept_ws_protocol = accept_ws_protocol let accept_ws_protocol = accept_ws_protocol
end) in end) in
let up : Server.upgrade_handler = (module M) in let up : Server.upgrade_handler = (module M) in

View file

@ -3,36 +3,11 @@
This sub-library ([tiny_httpd.ws]) exports a small implementation for a This sub-library ([tiny_httpd.ws]) exports a small implementation for a
websocket server. It has no additional dependencies. *) websocket server. It has no additional dependencies. *)
(** Synchronization primitive used to allow both the reader to reply to "ping",
and the handler to send messages, without stepping on each other's toes.
@since NEXT_RELEASE *)
module With_lock : sig
type t = { with_lock: 'a. (unit -> 'a) -> 'a }
(** A primitive to run the callback in a critical section where others cannot
run at the same time.
The default is a mutex, but that works poorly with thread pools so it's
possible to use a semaphore or a cooperative mutex instead. *)
type builder = unit -> t
val default_builder : builder
(** Lock using [Mutex]. *)
end
type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit
(** Websocket handler *) (** Websocket handler *)
val upgrade : val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t
?with_lock:With_lock.t -> (** Upgrade a byte stream to the websocket framing protocol. *)
IO.Input.t ->
IO.Output.t ->
IO.Input.t * IO.Output.t
(** Upgrade a byte stream to the websocket framing protocol.
@param with_lock
if provided, use this to prevent reader and writer to compete on sending
frames. since NEXT_RELEASE. *)
exception Close_connection exception Close_connection
(** Exception that can be raised from IOs inside the handler, when the (** Exception that can be raised from IOs inside the handler, when the
@ -42,7 +17,6 @@ val add_route_handler :
?accept:(unit Request.t -> (unit, int * string) result) -> ?accept:(unit Request.t -> (unit, int * string) result) ->
?accept_ws_protocol:(string -> bool) -> ?accept_ws_protocol:(string -> bool) ->
?middlewares:Server.Head_middleware.t list -> ?middlewares:Server.Head_middleware.t list ->
?with_lock:With_lock.builder ->
Server.t -> Server.t ->
(Server.upgrade_handler, Server.upgrade_handler) Route.t -> (Server.upgrade_handler, Server.upgrade_handler) Route.t ->
handler -> handler ->
@ -50,11 +24,7 @@ val add_route_handler :
(** Add a route handler for a websocket endpoint. (** Add a route handler for a websocket endpoint.
@param accept_ws_protocol @param accept_ws_protocol
decides whether this endpoint accepts the websocket protocol sent by the decides whether this endpoint accepts the websocket protocol sent by the
client. Default accepts everything. client. Default accepts everything. *)
@param with_lock
if provided, use this to synchronize writes between the frame reader
(replies "pong" to "ping") and the handler emitting writes. since
NEXT_RELEASE. *)
(**/**) (**/**)

View file

@ -1,4 +1,4 @@
(tests (tests
(names t_util t_buf t_server t_io t_response t_headers) (names t_util t_buf t_server t_io t_response)
(package tiny_httpd) (package tiny_httpd)
(libraries tiny_httpd.core qcheck-core qcheck-core.runner test_util)) (libraries tiny_httpd.core qcheck-core qcheck-core.runner test_util))

View file

@ -1,23 +0,0 @@
open Tiny_httpd_core
(* Test that header size limits are enforced *)
let test_header_too_large () =
(* Create a header that's larger than 16KB *)
let large_value = String.make 20000 'x' in
let q =
"GET / HTTP/1.1\r\nHost: example.com\r\nX-Large: " ^ large_value
^ "\r\n\r\n"
in
let str = IO.Input.of_string q in
let client_addr = Unix.(ADDR_INET (inet_addr_loopback, 1024)) in
let buf = Buf.create () in
try
let _ =
Request.Private_.parse_req_start_exn ~client_addr ~buf
~get_time_s:(fun _ -> 0.)
str
in
failwith "should have failed with 431"
with Tiny_httpd_core.Response.Bad_req (431, _) -> () (* expected *)
let () = test_header_too_large ()

View file

@ -40,7 +40,7 @@ let () = assert_eq (Ok [ "foo", "bar" ]) (U.parse_query "yolo#foo=bar")
let () = let () =
add_qcheck add_qcheck
@@ QCheck.Test.make ~name:__LOC__ ~long_factor:20 ~count:1_000 @@ QCheck.Test.make ~name:__LOC__ ~long_factor:20 ~count:1_000
Q.(list_small (pair string string)) Q.(small_list (pair string string))
(fun l -> (fun l ->
List.iter List.iter
(fun (a, b) -> (fun (a, b) ->

View file

@ -14,9 +14,9 @@ let () =
@@ QCheck.Test.make ~count:10_000 @@ QCheck.Test.make ~count:10_000
Q.( Q.(
triple triple
(bytes_size (Gen.return 4)) (bytes_of_size (Gen.return 4))
(option nat_small) (option small_nat)
(bytes_size Gen.(0 -- 6000)) (bytes_of_size Gen.(0 -- 6000))
(* |> Q.add_stat ("b.size", fun (_k, b) -> Bytes.length b) *) (* |> Q.add_stat ("b.size", fun (_k, b) -> Bytes.length b) *)
|> Q.add_shrink_invariant (fun (k, _, _) -> Bytes.length k = 4)) |> Q.add_shrink_invariant (fun (k, _, _) -> Bytes.length k = 4))
(fun (key, mask_offset, b) -> (fun (key, mask_offset, b) ->

View file

@ -17,12 +17,12 @@ depends: [
"result" "result"
"hmap" "hmap"
"iostream" {>= "0.2"} "iostream" {>= "0.2"}
"ocaml" {>= "4.13"} "ocaml" {>= "4.08"}
"odoc" {with-doc} "odoc" {with-doc}
"logs" {with-test} "logs" {with-test}
"conf-libcurl" {with-test} "conf-libcurl" {with-test}
"ptime" {with-test} "ptime" {with-test}
"qcheck-core" {>= "0.91" & with-test} "qcheck-core" {>= "0.9" & with-test}
] ]
depopts: [ depopts: [
"logs" "logs"

View file

@ -11,7 +11,6 @@ depends: [
"dune" {>= "3.2"} "dune" {>= "3.2"}
"tiny_httpd" {= version} "tiny_httpd" {= version}
"eio" {>= "1.0" & < "2.0"} "eio" {>= "1.0" & < "2.0"}
"base-unix"
"logs" {with-test} "logs" {with-test}
"odoc" {with-doc} "odoc" {with-doc}
] ]