diff --git a/sanic/server.py b/sanic/server.py index fd31680e..7d33e44f 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -56,6 +56,40 @@ class CIDict(dict): return super().__contains__(key.casefold()) +class RequestBuffer: + def __init__(self): + self.saved = b'' + self.passed_header = False + self.received = b'' + + def put(self, data): + if not self.passed_header: + self.saved += data + if b'\r\n\r\n' in self.saved: + data = self.saved + self.saved = None + self.passed_header = True + else: + try: + i = self.saved.rindex(b'\r\n') + except ValueError: + data = None + else: + data = self.saved[:i] + self.saved = self.saved[i:] + if data: + self.received += data + + def pop(self): + received = self.received + self.received = b'' + return received + + def finalize(self): + if self.saved: + self.received += self.saved + + class HttpProtocol(asyncio.Protocol): __slots__ = ( # event loop, connection @@ -99,6 +133,7 @@ class HttpProtocol(asyncio.Protocol): self._request_handler_task = None self._request_stream_task = None self._keep_alive = keep_alive + self._buffer = None self.state = state if state else {} if 'requests_count' not in self.state: self.state['requests_count'] = 0 @@ -146,6 +181,18 @@ class HttpProtocol(asyncio.Protocol): # -------------------------------------------- # def data_received(self, data): + # Create parser if this is the first time we're receiving data + if self.parser is None: + assert self.request is None + self.headers = [] + self.parser = HttpRequestParser(self) + self._buffer = RequestBuffer() + + self._buffer.put(data) + data = self._buffer.pop() + if not data: + return + # Check for the request itself getting too large and exceeding # memory limits self._total_request_size += len(data) @@ -153,12 +200,6 @@ class HttpProtocol(asyncio.Protocol): exception = PayloadTooLarge('Payload Too Large') self.write_error(exception) - # Create parser if this is the first time we're receiving data - if self.parser is None: - assert self.request is None - self.headers = [] - self.parser = HttpRequestParser(self) - # requests count self.state['requests_count'] = self.state['requests_count'] + 1 @@ -357,6 +398,7 @@ class HttpProtocol(asyncio.Protocol): self._request_stream_task = None self._total_request_size = 0 self._is_stream_handler = False + self._buffer = None def close_if_idle(self): """Close the connection if a request is not being sent or received diff --git a/tests/test_bad_request.py b/tests/test_bad_request.py index bf595085..f951b096 100644 --- a/tests/test_bad_request.py +++ b/tests/test_bad_request.py @@ -9,7 +9,7 @@ def test_bad_request_response(): async def _request(sanic, loop): connect = asyncio.open_connection('127.0.0.1', 42101) reader, writer = await connect - writer.write(b'not http') + writer.write(b'not http\r\n\r\n') while True: line = await reader.readline() if not line: diff --git a/tests/test_requests.py b/tests/test_requests.py index 81fe1a5c..01c60ec3 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -9,6 +9,7 @@ from sanic import Sanic from sanic.exceptions import ServerError from sanic.response import json, text from sanic.request import DEFAULT_HTTP_CONTENT_TYPE +from sanic.server import RequestBuffer from sanic.testing import HOST, PORT @@ -336,3 +337,12 @@ def test_url_attributes_with_ssl(path, query, expected_url): assert parsed.path == request.path assert parsed.query == request.query_string assert parsed.netloc == request.host + + +def test_request_buffer(): + request = b'''GET /ping/ HTTP/1.1\r\nHost: github.com\r\nConnection: keep-alive\r\nCache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.115 Safari/537.36\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nAccept-Language: ko-KR,ko;q=0.8,en-US;q=0.6,en;q=0.4\r\n\r\n''' + buffer = RequestBuffer() + for i, byte in enumerate(request): + buffer.put(request[i:i+1]) + buffer.finalize() + assert request == buffer.pop() \ No newline at end of file