diff --git a/src/ws/tiny_httpd_ws.ml b/src/ws/tiny_httpd_ws.ml index d886d57f..d070c77c 100644 --- a/src/ws/tiny_httpd_ws.ml +++ b/src/ws/tiny_httpd_ws.ml @@ -191,7 +191,7 @@ end module Reader = struct type state = | Begin (** At the beginning of a frame *) - | Reading_frame of { mutable remaining_bytes: int } + | Reading_frame of { mutable remaining_bytes: int; mutable num_read: int } (** Currently reading the payload of a frame with [remaining_bytes] left to read from the underlying [ic] *) | Close @@ -268,22 +268,26 @@ module Reader = struct self.header.payload_len self.header.mask);*) () - external apply_masking_ : key:bytes -> buf:bytes -> int -> int -> unit + external apply_masking_ : + key:bytes -> key_offset:int -> buf:bytes -> int -> int -> unit = "tiny_httpd_ws_apply_masking" [@@noalloc] (** Apply masking to the parsed data *) - let[@inline] apply_masking ~mask_key (buf : bytes) off len : unit = + let[@inline] apply_masking ~mask_key ~mask_offset (buf : bytes) off len : unit + = assert ( - Bytes.length mask_key = 4 && off >= 0 && off + len <= Bytes.length buf); - apply_masking_ ~key:mask_key ~buf off len + Bytes.length mask_key = 4 + && mask_offset >= 0 && off >= 0 + && off + len <= Bytes.length buf); + apply_masking_ ~key:mask_key ~key_offset:mask_offset ~buf off len let read_body_to_string (self : t) : string = let len = self.header.payload_len in let buf = Bytes.create len in IO.Input.really_input self.ic buf 0 len; if self.header.mask then - apply_masking ~mask_key:self.header.mask_key buf 0 len; + apply_masking ~mask_key:self.header.mask_key ~mask_offset:0 buf 0 len; Bytes.unsafe_to_string buf (** Skip bytes of the body *) @@ -303,33 +307,45 @@ module Reader = struct self.state <- Begin; read_rec self buf i len | Reading_frame r -> + Printf.printf "reading len=%d from frame remaining=%d (key=%S)\n%!" len + r.remaining_bytes + (Bytes.unsafe_to_string self.header.mask_key); let len = min len r.remaining_bytes in let n = IO.Input.input self.ic buf i len in + Printf.printf "got n=%d bytes\n%!" n; + Printf.printf "in buf: %S\n%!" (Bytes.sub_string buf i n); - (* update state *) - r.remaining_bytes <- r.remaining_bytes - n; - if r.remaining_bytes = 0 then self.state <- Begin; - + (* apply masking *) if self.header.mask then - apply_masking ~mask_key:self.header.mask_key buf i n + apply_masking ~mask_key:self.header.mask_key ~mask_offset:r.num_read buf + i n else ( Log.error (fun k -> k "websocket: client's frames must be masked"); raise Close_connection ); + + (* update state *) + r.remaining_bytes <- r.remaining_bytes - n; + r.num_read <- r.num_read + n; + if r.remaining_bytes = 0 then self.state <- Begin; + + Printf.printf "in buf (unmasked): %S\n%!" (Bytes.sub_string buf i n); n | Begin -> read_frame_header self; Log.debug (fun k -> - k "websocket: read frame of type=%s payload_len=%d" + k "websocket: read frame of type=%s payload_len=%d key=%S" (Frame_type.show self.header.ty) - self.header.payload_len); + self.header.payload_len + (Bytes.unsafe_to_string self.header.mask_key)); (match self.header.ty with | 0 -> (* continuation *) if self.last_ty = 1 || self.last_ty = 2 then self.state <- - Reading_frame { remaining_bytes = self.header.payload_len } + Reading_frame + { remaining_bytes = self.header.payload_len; num_read = 0 } else ( Log.error (fun k -> k "continuation frame coming after frame of type %s" @@ -340,12 +356,14 @@ module Reader = struct | 1 -> (* text *) self.state <- - Reading_frame { remaining_bytes = self.header.payload_len }; + Reading_frame + { remaining_bytes = self.header.payload_len; num_read = 0 }; read_rec self buf i len | 2 -> (* binary *) self.state <- - Reading_frame { remaining_bytes = self.header.payload_len }; + Reading_frame + { remaining_bytes = self.header.payload_len; num_read = 0 }; read_rec self buf i len | 8 -> (* close frame *) diff --git a/src/ws/tiny_httpd_ws.mli b/src/ws/tiny_httpd_ws.mli index 3582d010..77e71232 100644 --- a/src/ws/tiny_httpd_ws.mli +++ b/src/ws/tiny_httpd_ws.mli @@ -24,7 +24,8 @@ val add_route_handler : (**/**) module Private_ : sig - val apply_masking : mask_key:bytes -> bytes -> int -> int -> unit + val apply_masking : + mask_key:bytes -> mask_offset:int -> bytes -> int -> int -> unit end (**/**) diff --git a/src/ws/tiny_httpd_ws_stubs.c b/src/ws/tiny_httpd_ws_stubs.c index 3fe0a171..0813b527 100644 --- a/src/ws/tiny_httpd_ws_stubs.c +++ b/src/ws/tiny_httpd_ws_stubs.c @@ -3,18 +3,19 @@ #include #include -CAMLprim value tiny_httpd_ws_apply_masking(value _mask_key, value _buf, +CAMLprim value tiny_httpd_ws_apply_masking(value _mask_key, value _mask_offset, value _buf, value _offset, value _len) { - CAMLparam4(_mask_key, _buf, _offset, _len); + CAMLparam5(_mask_key, _mask_offset, _buf, _offset, _len); char const *mask_key = String_val(_mask_key); char *buf = Bytes_val(_buf); + intnat mask_offset = Int_val(_mask_offset); intnat offset = Int_val(_offset); intnat len = Int_val(_len); for (intnat i = 0; i < len; ++i) { unsigned char c = buf[offset + i]; - unsigned char c_m = mask_key[i & 0x3]; + unsigned char c_m = mask_key[(i + mask_offset) & 0x3]; buf[offset + i] = (unsigned char)(c ^ c_m); } CAMLreturn(Val_unit);