A buffer that keeps httptools from corrupting requests

This commit is contained in:
Jeong YunWon 2017-06-30 21:46:19 +09:00
parent 00d4533022
commit a987a20168
3 changed files with 59 additions and 7 deletions

View File

@ -56,6 +56,40 @@ class CIDict(dict):
return super().__contains__(key.casefold())
class RequestBuffer:
    """Accumulate raw request bytes and release them only at safe
    boundaries, so the HTTP parser never sees a torn header line.

    Until the end-of-headers marker (``\\r\\n\\r\\n``) has been seen,
    incoming bytes are held back and flushed only up to the last
    complete CRLF; once the header section is complete, bytes pass
    straight through untouched.
    """

    def __init__(self):
        self.saved = b''            # bytes held back inside the header section
        self.passed_header = False  # True once CRLFCRLF has been observed
        self.received = b''         # bytes safe to hand to the parser

    def put(self, data):
        """Feed raw bytes in; move whatever is safe to release into
        ``received``."""
        if not self.passed_header:
            self.saved += data
            if b'\r\n\r\n' in self.saved:
                # Header section complete: release everything held back.
                data, self.saved = self.saved, None
                self.passed_header = True
            elif b'\r\n' in self.saved:
                # Release up to (not including) the last CRLF; keep the
                # tail so a terminator split across chunks is still found.
                cut = self.saved.rindex(b'\r\n')
                data, self.saved = self.saved[:cut], self.saved[cut:]
            else:
                # No complete line yet — hold everything.
                data = None
        if data:
            self.received += data

    def pop(self):
        """Return all released bytes and reset the ready buffer."""
        ready, self.received = self.received, b''
        return ready

    def finalize(self):
        """Flush any still-held bytes (e.g. a request whose header
        section never terminated)."""
        if self.saved:
            self.received += self.saved
class HttpProtocol(asyncio.Protocol):
__slots__ = (
# event loop, connection
@ -99,6 +133,7 @@ class HttpProtocol(asyncio.Protocol):
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = keep_alive
self._buffer = None
self.state = state if state else {}
if 'requests_count' not in self.state:
self.state['requests_count'] = 0
@ -146,6 +181,18 @@ class HttpProtocol(asyncio.Protocol):
# -------------------------------------------- #
def data_received(self, data):
# Create parser if this is the first time we're receiving data
if self.parser is None:
assert self.request is None
self.headers = []
self.parser = HttpRequestParser(self)
self._buffer = RequestBuffer()
self._buffer.put(data)
data = self._buffer.pop()
if not data:
return
# Check for the request itself getting too large and exceeding
# memory limits
self._total_request_size += len(data)
@ -153,12 +200,6 @@ class HttpProtocol(asyncio.Protocol):
exception = PayloadTooLarge('Payload Too Large')
self.write_error(exception)
# Create parser if this is the first time we're receiving data
if self.parser is None:
assert self.request is None
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state['requests_count'] = self.state['requests_count'] + 1
@ -357,6 +398,7 @@ class HttpProtocol(asyncio.Protocol):
self._request_stream_task = None
self._total_request_size = 0
self._is_stream_handler = False
self._buffer = None
def close_if_idle(self):
"""Close the connection if a request is not being sent or received

View File

@ -9,7 +9,7 @@ def test_bad_request_response():
async def _request(sanic, loop):
connect = asyncio.open_connection('127.0.0.1', 42101)
reader, writer = await connect
writer.write(b'not http')
writer.write(b'not http\r\n\r\n')
while True:
line = await reader.readline()
if not line:

View File

@ -9,6 +9,7 @@ from sanic import Sanic
from sanic.exceptions import ServerError
from sanic.response import json, text
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE
from sanic.server import RequestBuffer
from sanic.testing import HOST, PORT
@ -336,3 +337,12 @@ def test_url_attributes_with_ssl(path, query, expected_url):
assert parsed.path == request.path
assert parsed.query == request.query_string
assert parsed.netloc == request.host
def test_request_buffer():
    """Feed a realistic request one byte at a time — the worst-case
    fragmentation — and check the buffer reassembles it losslessly."""
    request = b'''GET /ping/ HTTP/1.1\r\nHost: github.com\r\nConnection: keep-alive\r\nCache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.115 Safari/537.36\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nAccept-Language: ko-KR,ko;q=0.8,en-US;q=0.6,en;q=0.4\r\n\r\n'''
    buffer = RequestBuffer()
    for offset in range(len(request)):
        buffer.put(request[offset:offset + 1])
    buffer.finalize()
    assert request == buffer.pop()