A buffer which helps httptools not to corrupt requests
This commit is contained in:
parent
00d4533022
commit
a987a20168
|
@ -56,6 +56,40 @@ class CIDict(dict):
|
|||
return super().__contains__(key.casefold())
|
||||
|
||||
|
||||
class RequestBuffer:
|
||||
def __init__(self):
|
||||
self.saved = b''
|
||||
self.passed_header = False
|
||||
self.received = b''
|
||||
|
||||
def put(self, data):
|
||||
if not self.passed_header:
|
||||
self.saved += data
|
||||
if b'\r\n\r\n' in self.saved:
|
||||
data = self.saved
|
||||
self.saved = None
|
||||
self.passed_header = True
|
||||
else:
|
||||
try:
|
||||
i = self.saved.rindex(b'\r\n')
|
||||
except ValueError:
|
||||
data = None
|
||||
else:
|
||||
data = self.saved[:i]
|
||||
self.saved = self.saved[i:]
|
||||
if data:
|
||||
self.received += data
|
||||
|
||||
def pop(self):
|
||||
received = self.received
|
||||
self.received = b''
|
||||
return received
|
||||
|
||||
def finalize(self):
|
||||
if self.saved:
|
||||
self.received += self.saved
|
||||
|
||||
|
||||
class HttpProtocol(asyncio.Protocol):
|
||||
__slots__ = (
|
||||
# event loop, connection
|
||||
|
@ -99,6 +133,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._request_handler_task = None
|
||||
self._request_stream_task = None
|
||||
self._keep_alive = keep_alive
|
||||
self._buffer = None
|
||||
self.state = state if state else {}
|
||||
if 'requests_count' not in self.state:
|
||||
self.state['requests_count'] = 0
|
||||
|
@ -146,6 +181,18 @@ class HttpProtocol(asyncio.Protocol):
|
|||
# -------------------------------------------- #
|
||||
|
||||
def data_received(self, data):
|
||||
# Create parser if this is the first time we're receiving data
|
||||
if self.parser is None:
|
||||
assert self.request is None
|
||||
self.headers = []
|
||||
self.parser = HttpRequestParser(self)
|
||||
self._buffer = RequestBuffer()
|
||||
|
||||
self._buffer.put(data)
|
||||
data = self._buffer.pop()
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Check for the request itself getting too large and exceeding
|
||||
# memory limits
|
||||
self._total_request_size += len(data)
|
||||
|
@ -153,12 +200,6 @@ class HttpProtocol(asyncio.Protocol):
|
|||
exception = PayloadTooLarge('Payload Too Large')
|
||||
self.write_error(exception)
|
||||
|
||||
# Create parser if this is the first time we're receiving data
|
||||
if self.parser is None:
|
||||
assert self.request is None
|
||||
self.headers = []
|
||||
self.parser = HttpRequestParser(self)
|
||||
|
||||
# requests count
|
||||
self.state['requests_count'] = self.state['requests_count'] + 1
|
||||
|
||||
|
@ -357,6 +398,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._request_stream_task = None
|
||||
self._total_request_size = 0
|
||||
self._is_stream_handler = False
|
||||
self._buffer = None
|
||||
|
||||
def close_if_idle(self):
|
||||
"""Close the connection if a request is not being sent or received
|
||||
|
|
|
@ -9,7 +9,7 @@ def test_bad_request_response():
|
|||
async def _request(sanic, loop):
|
||||
connect = asyncio.open_connection('127.0.0.1', 42101)
|
||||
reader, writer = await connect
|
||||
writer.write(b'not http')
|
||||
writer.write(b'not http\r\n\r\n')
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line:
|
||||
|
|
|
@ -9,6 +9,7 @@ from sanic import Sanic
|
|||
from sanic.exceptions import ServerError
|
||||
from sanic.response import json, text
|
||||
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE
|
||||
from sanic.server import RequestBuffer
|
||||
from sanic.testing import HOST, PORT
|
||||
|
||||
|
||||
|
@ -336,3 +337,12 @@ def test_url_attributes_with_ssl(path, query, expected_url):
|
|||
assert parsed.path == request.path
|
||||
assert parsed.query == request.query_string
|
||||
assert parsed.netloc == request.host
|
||||
|
||||
|
||||
def test_request_buffer():
|
||||
request = b'''GET /ping/ HTTP/1.1\r\nHost: github.com\r\nConnection: keep-alive\r\nCache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.115 Safari/537.36\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nAccept-Language: ko-KR,ko;q=0.8,en-US;q=0.6,en;q=0.4\r\n\r\n'''
|
||||
buffer = RequestBuffer()
|
||||
for i, byte in enumerate(request):
|
||||
buffer.put(request[i:i+1])
|
||||
buffer.finalize()
|
||||
assert request == buffer.pop()
|
Loading…
Reference in New Issue
Block a user