diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml index 285b62f9..7d62eb1e 100644 --- a/src/Tiny_httpd.ml +++ b/src/Tiny_httpd.ml @@ -301,10 +301,20 @@ end module Headers = struct type t = (string * string) list - let contains = List.mem_assoc - let get ?(f=fun x->x) x h = try Some (List.assoc x h |> f) with Not_found -> None - let remove x h = List.filter (fun (k,_) -> k<>x) h - let set x y h = (x,y) :: List.filter (fun (k,_) -> k<>x) h + let contains name headers = + let name' = String.lowercase_ascii name in + List.exists (fun (n, _) -> name'=n) headers + let get_exn ?(f=fun x->x) x h = + let x' = String.lowercase_ascii x in + List.assoc x' h |> f + let get ?(f=fun x -> x) x h = + try Some (get_exn ~f x h) with Not_found -> None + let remove x h = + let x' = String.lowercase_ascii x in + List.filter (fun (k,_) -> k<>x') h + let set x y h = + let x' = String.lowercase_ascii x in + (x',y) :: List.filter (fun (k,_) -> k<>x') h let pp out l = let pp_pair out (k,v) = Format.fprintf out "@[%s: %s@]" k v in Format.fprintf out "@[%a@]" (Format.pp_print_list pp_pair) l @@ -320,7 +330,7 @@ module Headers = struct try Scanf.sscanf line "%s@: %s@\r" (fun k v->k,v) with _ -> bad_reqf 400 "invalid header line: %S" line in - loop ((k,v)::acc) + loop ((String.lowercase_ascii k,v)::acc) ) in loop [] @@ -444,8 +454,9 @@ module Request = struct _debug (fun k->k "got meth: %s, path %S" (Meth.to_string meth) path); let headers = Headers.parse_ ~buf bs in let host = - try List.assoc "Host" headers - with Not_found -> bad_reqf 400 "No 'Host' header in request" + match Headers.get "Host" headers with + | None -> bad_reqf 400 "No 'Host' header in request" + | Some h -> h in Ok (Some {meth; host; path; headers; body=()}) with @@ -459,7 +470,7 @@ module Request = struct let parse_body_ ~tr_stream ~buf (req:byte_stream t) : byte_stream t resp_result = try let size = - match List.assoc "Content-Length" req.headers |> int_of_string with + match Headers.get_exn "Content-Length" req.headers |> int_of_string with | n -> n (* body of fixed size *) | exception Not_found -> 0 | exception _ -> bad_reqf 400 "invalid content-length" @@ -506,6 +517,8 @@ end | None -> assert_failure "should parse" | Some req -> assert_equal (Some "coucou") (Headers.get "Host" req.Request.headers); + assert_equal (Some "coucou") (Headers.get "host" req.Request.headers); + assert_equal (Some "11") (Headers.get "content-length" req.Request.headers); assert_equal "hello" req.Request.path; let req = Request.Internal_.parse_body req str |> Request.read_body_full in assert_equal ~printer:(fun s->s) "salutations" req.Request.body;