diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml index 53891e53..972f3ea2 100644 --- a/src/Tiny_httpd.ml +++ b/src/Tiny_httpd.ml @@ -18,61 +18,64 @@ let _debug k = Printf.kfprintf (fun oc -> Printf.fprintf oc "\n%!") stdout fmt) ) -type out = { fd : Unix.file_descr - ; buf : Bytes.t - ; mutable pos : int } +module Out = struct + type t = { fd : Unix.file_descr + ; buf : Bytes.t + ; mutable pos : int } -let out_of_descr fd = - let size = Unix.(getsockopt_int fd SO_SNDBUF) in - { fd; buf = Bytes.create size; pos = 0 } + let out_of_descr fd = + let size = Unix.(getsockopt_int fd SO_SNDBUF) in + { fd; buf = Bytes.create size; pos = 0 } -let rec write_out ({buf;fd;_} as oc) i len = - try - let _,w,_ = Unix.select [] [fd] [] (-1.0) in - if w <> [] then + let rec write_out ({buf;fd;_} as oc) i len = + try + let _,w,_ = Unix.select [] [fd] [] (-1.0) in + if w <> [] then + begin + let written = Unix.single_write fd buf i len in + if written < len then + write_out oc (i+written) (len-written) + end + else assert false + with Sys_blocked_io + | Unix.Unix_error((EAGAIN|EWOULDBLOCK),_,_) -> + write_out oc i len + | Unix.Unix_error((EPIPE|EBADF),_,_) -> + raise (Sys_error "broken pipe") + | e -> + Printf.eprintf "unexpected exception in write_out: %s\n%!" + (Printexc.to_string e); + assert false + + let rec output ({buf;pos;_} as oc) s i len = + let buf_len = Bytes.length buf in + let do_write, to_write = + if len >= buf_len - pos then + true, buf_len - pos + else + false, len + in + Bytes.blit s i buf pos to_write; + oc.pos <- pos + to_write; + if do_write then begin - let written = Unix.single_write fd buf i len in - if written < len then - write_out oc (i+written) (len-written) + write_out oc 0 buf_len; + oc.pos <- 0; + output oc s (i + to_write) (len - to_write) end - else assert false - with Sys_blocked_io - | Unix.Unix_error((EAGAIN|EWOULDBLOCK),_,_) -> - write_out oc i len - | Unix.Unix_error((EPIPE|EBADF),_,_) -> - raise (Sys_error "broken pipe") - | e -> - Printf.eprintf "unexpected exception in write_out: %s\n%!" - (Printexc.to_string e); - assert false -let rec output ({buf;pos;_} as oc) s i len = - let buf_len = Bytes.length buf in - let do_write, to_write = - if len >= buf_len - pos then - true, buf_len - pos - else - false, len - in - Bytes.blit s i buf pos to_write; - oc.pos <- pos + to_write; - if do_write then - begin - write_out oc 0 buf_len; - oc.pos <- 0; - output oc s (i + to_write) (len - to_write) - end + let output_string oc str = + let buf = Bytes.unsafe_of_string str in + output oc buf 0 (String.length str) -let output_string oc str = - let buf = Bytes.unsafe_of_string str in - output oc buf 0 (String.length str) + let fprintf oc format = + Printf.ksprintf (output_string oc) format -let fprintf oc format = - Printf.ksprintf (output_string oc) format + let flush oc = + write_out oc 0 oc.pos; + oc.pos <- 0 +end -let flush oc = - write_out oc 0 oc.pos; - oc.pos <- 0 module Buf_ = struct type t = { @@ -711,23 +714,23 @@ module Response = struct self.code Headers.pp self.headers pp_body self.body (* print a stream as a series of chunks *) - let output_stream_chunked_ (oc:out) (str:byte_stream) : unit = + let output_stream_chunked_ (oc:Out.t) (str:byte_stream) : unit = let continue = ref true in while !continue do (* next chunk *) let s, i, len = str.bs_fill_buf () in - fprintf oc "%x\r\n" len; - output oc s i len; + Out.fprintf oc "%x\r\n" len; + Out.output oc s i len; str.bs_consume len; if len = 0 then ( continue := false; ); - output_string oc "\r\n"; + Out.output_string oc "\r\n"; done; () - let output_ (oc:out) (self:t) : unit = - fprintf oc "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); + let output_ (oc:Out.t) (self:t) : unit = + Out.fprintf oc "HTTP/1.1 %d %s\r\n" self.code (Response_code.descr self.code); let body, is_chunked = match self.body with | `String s when String.length s > 1024 * 500 -> (* chunk-encode large bodies *) @@ -746,14 +749,14 @@ module Response = struct let self = {self with headers; body} in _debug (fun k->k "output response: %s" (Format.asprintf "%a" pp {self with body=`String "<…>"})); - List.iter (fun (k,v) -> fprintf oc "%s: %s\r\n" k v) headers; - output_string oc "\r\n"; + List.iter (fun (k,v) -> Out.fprintf oc "%s: %s\r\n" k v) headers; + Out.output_string oc "\r\n"; begin match body with | `String "" | `Void -> () - | `String s -> output_string oc s; + | `String s -> Out.output_string oc s; | `Stream str -> output_stream_chunked_ oc str; end; - flush oc + Out.flush oc end (* semaphore, for limiting concurrency. *) @@ -872,7 +875,7 @@ end (* a request handler. handles a single request. *) type cb_path_handler = - out -> + Out.t -> byte_stream Request.t -> resp:(Response.t -> unit) -> unit @@ -940,7 +943,7 @@ let add_path_handler_ | handler -> (* we have a handler, do we accept the request based on its headers? *) begin match accept req with - | Ok () -> Some (Ok (fun (_oc:out) req ~resp -> resp (handler (tr_req req)))) + | Ok () -> Some (Ok (fun (_oc:Out.t) req ~resp -> resp (handler (tr_req req)))) | Error _ as e -> Some e end | exception _ -> @@ -971,7 +974,7 @@ let add_route_handler_ | Some handler -> (* we have a handler, do we accept the request based on its headers? *) begin match accept req with - | Ok () -> Some (Ok (fun (oc:out) req ~resp -> tr_req oc req ~resp handler)) + | Ok () -> Some (Ok (fun (oc:Out.t) req ~resp -> tr_req oc req ~resp handler)) | Error _ as e -> Some e end | None -> @@ -993,7 +996,7 @@ let[@inline] _opt_iter ~f o = match o with | Some x -> f x let add_route_server_sent_handler ?accept self route f = - let tr_req (oc:out) req ~resp f = + let tr_req (oc:Out.t) req ~resp f = let req = Request.read_body_full req in let headers = ref Headers.(empty |> set "content-type" "text/event-stream") in @@ -1010,13 +1013,13 @@ let add_route_server_sent_handler ?accept self route f = let send_event ?event ?id ?retry ~data () : unit = send_response_idempotent_(); - _opt_iter event ~f:(fun e -> fprintf oc "data: %s\n" e); - _opt_iter id ~f:(fun e -> fprintf oc "id: %s\n" e); - _opt_iter retry ~f:(fun e -> fprintf oc "retry: %s\n" e); + _opt_iter event ~f:(fun e -> Out.fprintf oc "data: %s\n" e); + _opt_iter id ~f:(fun e -> Out.fprintf oc "id: %s\n" e); + _opt_iter retry ~f:(fun e -> Out.fprintf oc "retry: %s\n" e); let l = String.split_on_char '\n' data in - List.iter (fun s -> fprintf oc "data: %s\n" s) l; - output_string oc "\n"; (* finish group *) - flush oc + List.iter (fun s -> Out.fprintf oc "data: %s\n" s) l; + Out.output_string oc "\n"; (* finish group *) + Out.flush oc in let module SSG = struct let set_headers h = @@ -1058,7 +1061,7 @@ let find_map f l = let handle_client_ (self:t) (client_sock:Unix.file_descr) : unit = let _ = Unix.set_nonblock client_sock in - let oc = out_of_descr client_sock in + let oc = Out.out_of_descr client_sock in let buf = Buf_.create() in let is = Byte_stream.of_descr client_sock in let continue = ref true in