A buffer which helps httptools not to corrupt requests
This commit is contained in:
parent
00d4533022
commit
a987a20168
|
@ -56,6 +56,40 @@ class CIDict(dict):
|
||||||
return super().__contains__(key.casefold())
|
return super().__contains__(key.casefold())
|
||||||
|
|
||||||
|
|
||||||
|
class RequestBuffer:
|
||||||
|
def __init__(self):
|
||||||
|
self.saved = b''
|
||||||
|
self.passed_header = False
|
||||||
|
self.received = b''
|
||||||
|
|
||||||
|
def put(self, data):
|
||||||
|
if not self.passed_header:
|
||||||
|
self.saved += data
|
||||||
|
if b'\r\n\r\n' in self.saved:
|
||||||
|
data = self.saved
|
||||||
|
self.saved = None
|
||||||
|
self.passed_header = True
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
i = self.saved.rindex(b'\r\n')
|
||||||
|
except ValueError:
|
||||||
|
data = None
|
||||||
|
else:
|
||||||
|
data = self.saved[:i]
|
||||||
|
self.saved = self.saved[i:]
|
||||||
|
if data:
|
||||||
|
self.received += data
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
received = self.received
|
||||||
|
self.received = b''
|
||||||
|
return received
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
if self.saved:
|
||||||
|
self.received += self.saved
|
||||||
|
|
||||||
|
|
||||||
class HttpProtocol(asyncio.Protocol):
|
class HttpProtocol(asyncio.Protocol):
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
# event loop, connection
|
# event loop, connection
|
||||||
|
@ -99,6 +133,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self._request_handler_task = None
|
self._request_handler_task = None
|
||||||
self._request_stream_task = None
|
self._request_stream_task = None
|
||||||
self._keep_alive = keep_alive
|
self._keep_alive = keep_alive
|
||||||
|
self._buffer = None
|
||||||
self.state = state if state else {}
|
self.state = state if state else {}
|
||||||
if 'requests_count' not in self.state:
|
if 'requests_count' not in self.state:
|
||||||
self.state['requests_count'] = 0
|
self.state['requests_count'] = 0
|
||||||
|
@ -146,6 +181,18 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
# -------------------------------------------- #
|
# -------------------------------------------- #
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
|
# Create parser if this is the first time we're receiving data
|
||||||
|
if self.parser is None:
|
||||||
|
assert self.request is None
|
||||||
|
self.headers = []
|
||||||
|
self.parser = HttpRequestParser(self)
|
||||||
|
self._buffer = RequestBuffer()
|
||||||
|
|
||||||
|
self._buffer.put(data)
|
||||||
|
data = self._buffer.pop()
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
# Check for the request itself getting too large and exceeding
|
# Check for the request itself getting too large and exceeding
|
||||||
# memory limits
|
# memory limits
|
||||||
self._total_request_size += len(data)
|
self._total_request_size += len(data)
|
||||||
|
@ -153,12 +200,6 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
exception = PayloadTooLarge('Payload Too Large')
|
exception = PayloadTooLarge('Payload Too Large')
|
||||||
self.write_error(exception)
|
self.write_error(exception)
|
||||||
|
|
||||||
# Create parser if this is the first time we're receiving data
|
|
||||||
if self.parser is None:
|
|
||||||
assert self.request is None
|
|
||||||
self.headers = []
|
|
||||||
self.parser = HttpRequestParser(self)
|
|
||||||
|
|
||||||
# requests count
|
# requests count
|
||||||
self.state['requests_count'] = self.state['requests_count'] + 1
|
self.state['requests_count'] = self.state['requests_count'] + 1
|
||||||
|
|
||||||
|
@ -357,6 +398,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self._request_stream_task = None
|
self._request_stream_task = None
|
||||||
self._total_request_size = 0
|
self._total_request_size = 0
|
||||||
self._is_stream_handler = False
|
self._is_stream_handler = False
|
||||||
|
self._buffer = None
|
||||||
|
|
||||||
def close_if_idle(self):
|
def close_if_idle(self):
|
||||||
"""Close the connection if a request is not being sent or received
|
"""Close the connection if a request is not being sent or received
|
||||||
|
|
|
@ -9,7 +9,7 @@ def test_bad_request_response():
|
||||||
async def _request(sanic, loop):
|
async def _request(sanic, loop):
|
||||||
connect = asyncio.open_connection('127.0.0.1', 42101)
|
connect = asyncio.open_connection('127.0.0.1', 42101)
|
||||||
reader, writer = await connect
|
reader, writer = await connect
|
||||||
writer.write(b'not http')
|
writer.write(b'not http\r\n\r\n')
|
||||||
while True:
|
while True:
|
||||||
line = await reader.readline()
|
line = await reader.readline()
|
||||||
if not line:
|
if not line:
|
||||||
|
|
|
@ -9,6 +9,7 @@ from sanic import Sanic
|
||||||
from sanic.exceptions import ServerError
|
from sanic.exceptions import ServerError
|
||||||
from sanic.response import json, text
|
from sanic.response import json, text
|
||||||
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE
|
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE
|
||||||
|
from sanic.server import RequestBuffer
|
||||||
from sanic.testing import HOST, PORT
|
from sanic.testing import HOST, PORT
|
||||||
|
|
||||||
|
|
||||||
|
@ -336,3 +337,12 @@ def test_url_attributes_with_ssl(path, query, expected_url):
|
||||||
assert parsed.path == request.path
|
assert parsed.path == request.path
|
||||||
assert parsed.query == request.query_string
|
assert parsed.query == request.query_string
|
||||||
assert parsed.netloc == request.host
|
assert parsed.netloc == request.host
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_buffer():
|
||||||
|
request = b'''GET /ping/ HTTP/1.1\r\nHost: github.com\r\nConnection: keep-alive\r\nCache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.115 Safari/537.36\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nAccept-Language: ko-KR,ko;q=0.8,en-US;q=0.6,en;q=0.4\r\n\r\n'''
|
||||||
|
buffer = RequestBuffer()
|
||||||
|
for i, byte in enumerate(request):
|
||||||
|
buffer.put(request[i:i+1])
|
||||||
|
buffer.finalize()
|
||||||
|
assert request == buffer.pop()
|
Loading…
Reference in New Issue
Block a user