diff --git a/sanic/http.py b/sanic/http.py index 16603c38..402238cb 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -95,19 +95,23 @@ class Http: self._receive_more = protocol.receive_more self.recv_buffer = protocol.recv_buffer self.protocol = protocol - self.expecting_continue: bool = False + self.keep_alive = True self.stage: Stage = Stage.IDLE + self.init_for_request() + + def init_for_request(self): + """Init/reset all per-request variables.""" + self.exception = None + self.expecting_continue: bool = False + self.head_only = None self.request_body = None self.request_bytes = None self.request_bytes_left = None - self.request_max_size = protocol.request_max_size - self.keep_alive = True - self.head_only = None + self.request_max_size = self.protocol.request_max_size self.request: Request = None self.response: BaseHTTPResponse = None - self.exception = None - self.url = None self.upgrade_websocket = False + self.url = None def __bool__(self): """Test if request handling is in progress""" @@ -148,7 +152,10 @@ class Http: if self.request_body: if self.response and 200 <= self.response.status < 300: error_logger.error(f"{self.request} body not consumed.") - + # Limit the size because the handler may have set it infinite + self.request_max_size = min( + self.request_max_size, self.protocol.request_max_size + ) try: async for _ in self: pass @@ -160,11 +167,19 @@ class Http: await sleep(0.001) self.keep_alive = False + # Clean up to free memory and for the next request + if self.request: + self.request.stream = None + if self.response: + self.response.stream = None + + self.init_for_request() + # Exit and disconnect if no more requests can be taken if self.stage is not Stage.IDLE or not self.keep_alive: break - # Wait for next request + # Wait for the next request if not self.recv_buffer: await self._receive_more() diff --git a/tests/test_pipelining.py b/tests/test_pipelining.py index 689a787b..2bb29c52 100644 --- a/tests/test_pipelining.py +++ b/tests/test_pipelining.py @@ -1,7 +1,7 @@ from httpx import AsyncByteStream from sanic_testing.reusable import ReusableClient -from sanic.response import json +from sanic.response import json, text def test_no_body_requests(app): @@ -80,3 +80,26 @@ def test_streaming_body_requests(app): assert response1.json["data"] == response2.json["data"] == data assert response1.json["request_id"] != response2.json["request_id"] assert response1.json["connection_id"] == response2.json["connection_id"] + + +def test_bad_headers(app): + @app.get("/") + async def handler(request): + return text("") + + @app.on_response + async def reqid(request, response): + response.headers["x-request-id"] = request.id + + client = ReusableClient(app, port=1234) + bad_headers = {"bad": "bad" * 5_000} + + with client: + _, response1 = client.get("/") + _, response2 = client.get("/", headers=bad_headers) + + assert response1.status == 200 + assert response2.status == 413 + assert ( + response1.headers["x-request-id"] != response2.headers["x-request-id"] + )