Rewritten request body parser.
This commit is contained in:
		| @@ -81,6 +81,7 @@ class HttpProtocol(asyncio.Protocol): | |||||||
|         "access_log", |         "access_log", | ||||||
|         # connection management |         # connection management | ||||||
|         "_total_request_size", |         "_total_request_size", | ||||||
|  |         "_request_bytes_left", | ||||||
|         "_status", |         "_status", | ||||||
|         "_time", |         "_time", | ||||||
|         "_last_response_time", |         "_last_response_time", | ||||||
| @@ -235,28 +236,39 @@ class HttpProtocol(asyncio.Protocol): | |||||||
|                     ) |                     ) | ||||||
|                 except: |                 except: | ||||||
|                     raise InvalidUsage("Bad Request") |                     raise InvalidUsage("Bad Request") | ||||||
|                 del buf[:pos + 4] |  | ||||||
|                 if headers.get(EXPECT_HEADER): |                 if headers.get(EXPECT_HEADER): | ||||||
|                     self._status = Status.EXPECT |                     self._status = Status.EXPECT | ||||||
|                     self.expect_handler() |                     self.expect_handler() | ||||||
|                 self.state["requests_count"] += 1 |                 self.state["requests_count"] += 1 | ||||||
|                 # Run handler |                 # Prepare a request object from the header received | ||||||
|                 self.request = self.request_class( |                 self.request = self.request_class( | ||||||
|                     url_bytes=self.url.encode(), |                     url_bytes=self.url.encode(), | ||||||
|                     headers=headers, |                     headers=headers, | ||||||
|                     version="1.1", |                     version=protocol[-3:], | ||||||
|                     method=method, |                     method=method, | ||||||
|                     transport=self.transport, |                     transport=self.transport, | ||||||
|                     app=self.app, |                     app=self.app, | ||||||
|                 ) |                 ) | ||||||
|                 self.state["request_body"] = request_body = ( |                 # Prepare for request body | ||||||
|                     int(headers.get("content-length", 0)) or |                 body = ( | ||||||
|                     headers.get("transfer-encoding") == "chunked" |                     headers.get("transfer-encoding") == "chunked" | ||||||
|  |                     or int(headers.get("content-length", 0)) | ||||||
|                 ) |                 ) | ||||||
|                 if request_body: |                 if body: | ||||||
|                     self.request.stream = StreamBuffer( |                     self.request.stream = StreamBuffer( | ||||||
|                         self.request_buffer_queue_size, protocol=self, |                         self.request_buffer_queue_size, protocol=self, | ||||||
|                     ) |                     ) | ||||||
|  |                     if body is True: | ||||||
|  |                         self._request_chunked = True | ||||||
|  |                         self._request_bytes_left = 0 | ||||||
|  |                         pos -= 2  # One CRLF stays in buffer | ||||||
|  |                     else: | ||||||
|  |                         self._request_chunked = False | ||||||
|  |                         self._request_bytes_left = body | ||||||
|  |                 # Remove header and its trailing CRLF | ||||||
|  |                 del buf[:pos + 4] | ||||||
|  |                 # Run handler | ||||||
|                 self._status, self._time = Status.HANDLER, current_time() |                 self._status, self._time = Status.HANDLER, current_time() | ||||||
|                 await self.request_handler( |                 await self.request_handler( | ||||||
|                     self.request, self.write_response, self.stream_response |                     self.request, self.write_response, self.stream_response | ||||||
| @@ -270,37 +282,41 @@ class HttpProtocol(asyncio.Protocol): | |||||||
|             self.close() |             self.close() | ||||||
|  |  | ||||||
|     async def request_body(self): |     async def request_body(self): | ||||||
|         rb = self.state["request_body"] |         buf = self._buffer | ||||||
|         if rb is True: |         if self._request_chunked and self._request_bytes_left == 0: | ||||||
|             # This code is crap and needs rewriting |             # Process a chunk header: \r\n<size>[;<chunk extensions>]\r\n | ||||||
|             while b"\r\n" not in self._buffer: |             while True: | ||||||
|  |                 pos = buf.find(b"\r\n", 3) | ||||||
|  |                 if pos != -1: | ||||||
|  |                     break | ||||||
|  |                 if len(buf) > 64: | ||||||
|  |                     self.keep_alive = False | ||||||
|  |                     raise InvalidUsage("Bad chunked encoding") | ||||||
|                 await self.receive_more() |                 await self.receive_more() | ||||||
|             pos = self._buffer.find(b"\r\n") |             try: | ||||||
|             size = int(self._buffer[:pos]) |                 size = int(buf[2:pos].split(b";", 1)[0].decode(), 16) | ||||||
|             if self._total_request_size + size > self.request_max_size: |             except: | ||||||
|                 self.keep_alive = False |                 self.keep_alive = False | ||||||
|                 raise PayloadTooLarge("Payload Too Large") |                 raise InvalidUsage("Bad chunked encoding") | ||||||
|             self._total_request_size += pos + 4 + size |             self._request_bytes_left = size | ||||||
|             if size == 0: |             self._total_request_size += pos + 2 | ||||||
|                 self.state["request_body"] = 0 |             del buf[:pos + 2] | ||||||
|  |             if self._request_bytes_left <= 0: | ||||||
|  |                 self._request_chunked = False | ||||||
|                 return None |                 return None | ||||||
|             while len(self._buffer) < pos + 4 + size: |         # At this point we are good to read/return _request_bytes_left | ||||||
|  |         if self._request_bytes_left: | ||||||
|  |             if not buf: | ||||||
|                 await self.receive_more() |                 await self.receive_more() | ||||||
|             body = self._buffer[pos + 2: pos + 2 + size] |             data = bytes(buf[:self._request_bytes_left]) | ||||||
|             del body[:pos + 4 + size] |             size = len(data) | ||||||
|             return body |             del buf[:size] | ||||||
|         elif rb > 0: |             self._request_bytes_left -= size | ||||||
|             if not self._buffer: |  | ||||||
|                 await self.receive_more() |  | ||||||
|             body = self._buffer[:rb] |  | ||||||
|             size = len(body) |  | ||||||
|             del self._buffer[:rb] |  | ||||||
|             self.state["request_body"] -= size |  | ||||||
|             self._total_request_size += size |             self._total_request_size += size | ||||||
|             if self._total_request_size > self.request_max_size: |             if self._total_request_size > self.request_max_size: | ||||||
|                 self.keep_alive = False |                 self.keep_alive = False | ||||||
|                 raise PayloadTooLarge("Payload Too Large") |                 raise PayloadTooLarge("Payload Too Large") | ||||||
|             return body |             return data | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def expect_handler(self): |     def expect_handler(self): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 L. Kärkkäinen
					L. Kärkkäinen