Rewritten request body parser.

This commit is contained in:
L. Kärkkäinen 2020-02-26 13:03:33 +02:00
parent ef4f233fba
commit 42d86bcd5a

View File

@ -81,6 +81,7 @@ class HttpProtocol(asyncio.Protocol):
"access_log", "access_log",
# connection management # connection management
"_total_request_size", "_total_request_size",
"_request_bytes_left",
"_status", "_status",
"_time", "_time",
"_last_response_time", "_last_response_time",
@ -235,28 +236,39 @@ class HttpProtocol(asyncio.Protocol):
) )
except: except:
raise InvalidUsage("Bad Request") raise InvalidUsage("Bad Request")
del buf[:pos + 4]
if headers.get(EXPECT_HEADER): if headers.get(EXPECT_HEADER):
self._status = Status.EXPECT self._status = Status.EXPECT
self.expect_handler() self.expect_handler()
self.state["requests_count"] += 1 self.state["requests_count"] += 1
# Run handler # Prepare a request object from the header received
self.request = self.request_class( self.request = self.request_class(
url_bytes=self.url.encode(), url_bytes=self.url.encode(),
headers=headers, headers=headers,
version="1.1", version=protocol[-3:],
method=method, method=method,
transport=self.transport, transport=self.transport,
app=self.app, app=self.app,
) )
self.state["request_body"] = request_body = ( # Prepare for request body
int(headers.get("content-length", 0)) or body = (
headers.get("transfer-encoding") == "chunked" headers.get("transfer-encoding") == "chunked"
or int(headers.get("content-length", 0))
) )
if request_body: if body:
self.request.stream = StreamBuffer( self.request.stream = StreamBuffer(
self.request_buffer_queue_size, protocol=self, self.request_buffer_queue_size, protocol=self,
) )
if body is True:
self._request_chunked = True
self._request_bytes_left = 0
pos -= 2 # One CRLF stays in buffer
else:
self._request_chunked = False
self._request_bytes_left = body
# Remove header and its trailing CRLF
del buf[:pos + 4]
# Run handler
self._status, self._time = Status.HANDLER, current_time() self._status, self._time = Status.HANDLER, current_time()
await self.request_handler( await self.request_handler(
self.request, self.write_response, self.stream_response self.request, self.write_response, self.stream_response
@ -270,37 +282,41 @@ class HttpProtocol(asyncio.Protocol):
self.close() self.close()
async def request_body(self): async def request_body(self):
rb = self.state["request_body"] buf = self._buffer
if rb is True: if self._request_chunked and self._request_bytes_left == 0:
# This code is crap and needs rewriting # Process a chunk header: \r\n<size>[;<chunk extensions>]\r\n
while b"\r\n" not in self._buffer: while True:
pos = buf.find(b"\r\n", 3)
if pos != -1:
break
if len(buf) > 64:
self.keep_alive = False
raise InvalidUsage("Bad chunked encoding")
await self.receive_more() await self.receive_more()
pos = self._buffer.find(b"\r\n") try:
size = int(self._buffer[:pos]) size = int(buf[2:pos].split(b";", 1)[0].decode(), 16)
if self._total_request_size + size > self.request_max_size: except:
self.keep_alive = False self.keep_alive = False
raise PayloadTooLarge("Payload Too Large") raise InvalidUsage("Bad chunked encoding")
self._total_request_size += pos + 4 + size self._request_bytes_left = size
if size == 0: self._total_request_size += pos + 2
self.state["request_body"] = 0 del buf[:pos + 2]
if self._request_bytes_left <= 0:
self._request_chunked = False
return None return None
while len(self._buffer) < pos + 4 + size: # At this point we are good to read/return _request_bytes_left
if self._request_bytes_left:
if not buf:
await self.receive_more() await self.receive_more()
body = self._buffer[pos + 2: pos + 2 + size] data = bytes(buf[:self._request_bytes_left])
del body[:pos + 4 + size] size = len(data)
return body del buf[:size]
elif rb > 0: self._request_bytes_left -= size
if not self._buffer:
await self.receive_more()
body = self._buffer[:rb]
size = len(body)
del self._buffer[:rb]
self.state["request_body"] -= size
self._total_request_size += size self._total_request_size += size
if self._total_request_size > self.request_max_size: if self._total_request_size > self.request_max_size:
self.keep_alive = False self.keep_alive = False
raise PayloadTooLarge("Payload Too Large") raise PayloadTooLarge("Payload Too Large")
return body return data
return None return None
def expect_handler(self): def expect_handler(self):