Fix issues with after request handling in HTTP pipelining (#2201)

* Clean up after a request is complete, before the next pipelined request.

* Limit the size of request body consumed after handler has finished.

* Linter error.

* Add unit test re: bad headers

Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
L. Kärkkäinen 2021-07-28 04:38:28 -04:00 committed by Adam Hopkins
parent a6e78b70ab
commit ba1c73d947
No known key found for this signature in database
GPG Key ID: 6A33C08203F67A28
2 changed files with 47 additions and 9 deletions

View File

@ -95,19 +95,23 @@ class Http:
self._receive_more = protocol.receive_more self._receive_more = protocol.receive_more
self.recv_buffer = protocol.recv_buffer self.recv_buffer = protocol.recv_buffer
self.protocol = protocol self.protocol = protocol
self.expecting_continue: bool = False self.keep_alive = True
self.stage: Stage = Stage.IDLE self.stage: Stage = Stage.IDLE
self.init_for_request()
def init_for_request(self):
"""Init/reset all per-request variables."""
self.exception = None
self.expecting_continue: bool = False
self.head_only = None
self.request_body = None self.request_body = None
self.request_bytes = None self.request_bytes = None
self.request_bytes_left = None self.request_bytes_left = None
self.request_max_size = protocol.request_max_size self.request_max_size = self.protocol.request_max_size
self.keep_alive = True
self.head_only = None
self.request: Request = None self.request: Request = None
self.response: BaseHTTPResponse = None self.response: BaseHTTPResponse = None
self.exception = None
self.url = None
self.upgrade_websocket = False self.upgrade_websocket = False
self.url = None
def __bool__(self): def __bool__(self):
"""Test if request handling is in progress""" """Test if request handling is in progress"""
@ -148,7 +152,10 @@ class Http:
if self.request_body: if self.request_body:
if self.response and 200 <= self.response.status < 300: if self.response and 200 <= self.response.status < 300:
error_logger.error(f"{self.request} body not consumed.") error_logger.error(f"{self.request} body not consumed.")
# Limit the size because the handler may have set it infinite
self.request_max_size = min(
self.request_max_size, self.protocol.request_max_size
)
try: try:
async for _ in self: async for _ in self:
pass pass
@ -160,11 +167,19 @@ class Http:
await sleep(0.001) await sleep(0.001)
self.keep_alive = False self.keep_alive = False
# Clean up to free memory and for the next request
if self.request:
self.request.stream = None
if self.response:
self.response.stream = None
self.init_for_request()
# Exit and disconnect if no more requests can be taken # Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive: if self.stage is not Stage.IDLE or not self.keep_alive:
break break
# Wait for next request # Wait for the next request
if not self.recv_buffer: if not self.recv_buffer:
await self._receive_more() await self._receive_more()

View File

@ -1,7 +1,7 @@
from httpx import AsyncByteStream from httpx import AsyncByteStream
from sanic_testing.reusable import ReusableClient from sanic_testing.reusable import ReusableClient
from sanic.response import json from sanic.response import json, text
def test_no_body_requests(app): def test_no_body_requests(app):
@ -80,3 +80,26 @@ def test_streaming_body_requests(app):
assert response1.json["data"] == response2.json["data"] == data assert response1.json["data"] == response2.json["data"] == data
assert response1.json["request_id"] != response2.json["request_id"] assert response1.json["request_id"] != response2.json["request_id"]
assert response1.json["connection_id"] == response2.json["connection_id"] assert response1.json["connection_id"] == response2.json["connection_id"]
def test_bad_headers(app):
@app.get("/")
async def handler(request):
return text("")
@app.on_response
async def reqid(request, response):
response.headers["x-request-id"] = request.id
client = ReusableClient(app, port=1234)
bad_headers = {"bad": "bad" * 5_000}
with client:
_, response1 = client.get("/")
_, response2 = client.get("/", headers=bad_headers)
assert response1.status == 200
assert response2.status == 413
assert (
response1.headers["x-request-id"] != response2.headers["x-request-id"]
)