diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml index f67866e2..8fa91fe3 100644 --- a/src/Tiny_httpd.ml +++ b/src/Tiny_httpd.ml @@ -82,16 +82,36 @@ module Byte_stream = struct let of_chan = of_chan_ ~close:close_in let of_chan_close_noerr = of_chan_ ~close:close_in_noerr - let rec iter f (self:t) : unit = + let of_fd (fd:Unix.file_descr) : t = + let i = ref 0 in + let len = ref 0 in + let buf = Bytes.make 4096 ' ' in + { bs_fill_buf=(fun () -> + if !i >= !len then ( + i := 0; + len := Unix.read fd buf 0 (Bytes.length buf); + ); + buf, !i, !len - !i); + bs_consume=(fun n -> i := !i + n); + bs_close=(fun () -> Unix.close fd) + } + + let rec iter_full f (self:t) : unit = let s, i, len = self.bs_fill_buf () in if len=0 then ( self.bs_close(); ) else ( - f s i len; - self.bs_consume len; - (iter [@tailcall]) f self + let n = f s i len in + assert (n f s i len; len) self + + let to_fd (fd:Unix.file_descr) (self:t) = + iter_full (fun s i len -> Unix.write fd s i len) self + let to_chan (oc:out_channel) (self:t) = iter (fun s i len -> output oc s i len) self @@ -449,7 +469,7 @@ module Request = struct in Ok (Some {meth; host; path; headers; body=()}) with - | End_of_file | Sys_error _ -> Ok None + | End_of_file | Sys_error _ | Unix.Unix_error _ -> Ok None | Bad_req (c,s) -> Error (c,s) | e -> Error (400, Printexc.to_string e) @@ -558,32 +578,42 @@ module Response = struct Format.fprintf out "{@[code=%d;@ headers=%a;@ body=%a@]}" self.code Headers.pp self.headers pp_body self.body + let rec write_fd_ (fd:Unix.file_descr) s i len : unit = + let n = Unix.write fd s i len in + if n < len then ( + (write_fd_ [@tailcall]) fd s (i+n) (len-n) + ) + + let write_fd_str_ fd s = + write_fd_ fd (Bytes.unsafe_of_string s) 0 (String.length s) + (* print a stream as a series of chunks *) - let output_stream_chunked_ (oc:out_channel) (str:byte_stream) : unit = + let output_stream_chunked_ (fd:Unix.file_descr) (str:byte_stream) : unit = let continue = ref true in while !continue do (* next chunk *) let s, i, len = str.bs_fill_buf () in - Printf.fprintf oc "%x\r\n" len; - output oc s i len; + write_fd_str_ fd @@ Printf.sprintf "%x\r\n" len; + write_fd_ fd s i len; str.bs_consume len; if len = 0 then ( continue := false; ); - output_string oc "\r\n"; + write_fd_str_ fd "\r\n"; done; () - let output_ (oc:out_channel) (self:t) : unit = - Printf.fprintf oc "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); - List.iter (fun (k,v) -> Printf.fprintf oc "%s: %s\r\n" k v) self.headers; - output_string oc "\r\n"; + let output_ (fd:Unix.file_descr) (self:t) : unit = + write_fd_str_ fd + (Printf.sprintf "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code)); + List.iter (fun (k,v) -> write_fd_str_ fd @@ Printf.sprintf "%s: %s\r\n" k v) self.headers; + write_fd_str_ fd "\r\n"; begin match self.body with | `String "" -> () - | `String s -> output_string oc s; - | `Stream str -> output_stream_chunked_ oc str; + | `String s -> write_fd_str_ fd s; + | `Stream str -> output_stream_chunked_ fd str; end; - flush oc + () end (* semaphore, for limiting concurrency. *) @@ -688,10 +718,8 @@ let find_map f l = in aux f l let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = - let ic = Unix.in_channel_of_descr client_sock in - let oc = Unix.out_channel_of_descr client_sock in let buf = Buf_.create() in - let is = Byte_stream.of_chan ic in + let is = Byte_stream.of_fd client_sock in let continue = ref true in while !continue && self.running do _debug (fun k->k "read next request"); @@ -700,8 +728,8 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = | Error (c,s) -> let res = Response.make_raw ~code:c s in begin - try Response.output_ oc res - with Sys_error _ -> () + try Response.output_ client_sock res + with Sys_error _ | Unix.Unix_error _ -> () end; continue := false | Ok (Some req) -> @@ -717,7 +745,7 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = begin match Request.get_header ~f:String.trim req "Expect" with | Some "100-continue" -> _debug (fun k->k "send back: 100 CONTINUE"); - Response.output_ oc (Response.make_raw ~code:100 ""); + Response.output_ client_sock (Response.make_raw ~code:100 ""); | Some s -> bad_reqf 417 "unknown expectation %s" s | None -> () end; @@ -748,13 +776,13 @@ let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = Response.fail ~code:500 "server error: %s" (Printexc.to_string e) in begin - try Response.output_ oc res - with Sys_error _ -> continue := false + try Response.output_ client_sock res + with Sys_error _ | Unix.Unix_error _ -> continue := false end | exception Bad_req (code,s) -> - Response.output_ oc (Response.make_raw ~code s); + Response.output_ client_sock (Response.make_raw ~code s); continue := false - | exception Sys_error _ -> + | exception (Sys_error _ | Unix.Unix_error _) -> continue := false; (* connection broken somehow *) done; _debug (fun k->k "done with client, exiting"); @@ -778,18 +806,14 @@ let run (self:t) : (unit,_) result = Unix.setsockopt_optint sock Unix.SO_LINGER None; let inet_addr = Unix.inet_addr_of_string self.addr in Unix.bind sock (Unix.ADDR_INET (inet_addr, self.port)); - Unix.listen sock (self.sem_max_connections.Sem_.n); - - ignore @@ Thread.create (fun () -> - while true do - _debug (fun k->k "sem: %d" self.sem_max_connections.n); - Unix.sleep 1; - done) (); - + Unix.listen sock (2 * self.sem_max_connections.Sem_.n); while self.running do (* limit concurrency *) - let client_sock, _ = Unix.accept sock in Sem_.acquire self.sem_max_connections; + let client_sock, _ = Unix.accept sock in + (try Unix.setsockopt_optint client_sock Unix.SO_LINGER None with _->()); + (try Unix.setsockopt_float client_sock SO_RCVTIMEO 5. with _ -> ()); + (try Unix.setsockopt_float client_sock SO_SNDTIMEO 5. with _ -> ()); self.new_thread (fun () -> try diff --git a/src/Tiny_httpd.mli b/src/Tiny_httpd.mli index d5b4824d..1142fd28 100644 --- a/src/Tiny_httpd.mli +++ b/src/Tiny_httpd.mli @@ -1,4 +1,3 @@ - (** {1 Tiny Http Server} This library implements a very simple, basic HTTP/1.1 server using blocking @@ -112,6 +111,18 @@ module Byte_stream : sig val of_chan : in_channel -> t (** Make a buffered stream from the given channel. *) + val of_fd : Unix.file_descr -> t + (** Make a buffered stream from the given unix filedescriptor, which + must be readable. + @since 0.4 + *) + + val to_fd : Unix.file_descr -> t -> unit + (** Make a buffered stream from the given unix filedescriptor, which + must be readable. + @since 0.4 + *) + val of_chan_close_noerr : in_channel -> t (** Same as {!of_chan} but the [close] method will never fail. *) @@ -121,8 +132,13 @@ module Byte_stream : sig val of_string : string -> t + val iter_full : (bytes -> int -> int -> int) -> t -> unit + (** Iterate on the chunks, consuming chunks partially. + The function returns how many bytes were actually consumed. + @since 0.4 *) + val iter : (bytes -> int -> int -> unit) -> t -> unit - (** Iterate on the chunks of the stream + (** Iterate on the chunks of the stream. @since 0.3 *) val to_chan : out_channel -> t -> unit