From c5070bd449f15efd5be0dacdef02bbfab8165b02 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 25 Oct 2020 14:32:18 +0200 Subject: [PATCH] Backport stream header fix (#1959) Resolve headers as body in ASGI mode * Bump version to 19.12.3 * Update multidict==5.0.0 --- sanic/__version__.py | 2 +- sanic/asgi.py | 2 ++ sanic/request.py | 10 +++++----- sanic/response.py | 12 ++++++++++-- sanic/router.py | 2 +- sanic/worker.py | 2 +- setup.py | 13 ++++--------- tests/test_keep_alive_timeout.py | 4 ++-- tests/test_response.py | 16 ++++++++++++++++ 9 files changed, 42 insertions(+), 21 deletions(-) diff --git a/sanic/__version__.py b/sanic/__version__.py index 4410ae56..d9f9ee85 100644 --- a/sanic/__version__.py +++ b/sanic/__version__.py @@ -1 +1 @@ -__version__ = "19.12.2" +__version__ = "19.12.3" diff --git a/sanic/asgi.py b/sanic/asgi.py index f08cc454..b5014b51 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -347,6 +347,8 @@ class ASGIApp: if name not in (b"Set-Cookie",) ] + response.asgi = True + if "content-length" not in response.headers and not isinstance( response, StreamingHTTPResponse ): diff --git a/sanic/request.py b/sanic/request.py index 246eb351..3c765fa3 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -129,27 +129,27 @@ class Request: def get(self, key, default=None): """.. deprecated:: 19.9 - Custom context is now stored in `request.custom_context.yourkey`""" + Custom context is now stored in `request.custom_context.yourkey`""" return self.ctx.__dict__.get(key, default) def __contains__(self, key): """.. deprecated:: 19.9 - Custom context is now stored in `request.custom_context.yourkey`""" + Custom context is now stored in `request.custom_context.yourkey`""" return key in self.ctx.__dict__ def __getitem__(self, key): """.. deprecated:: 19.9 - Custom context is now stored in `request.custom_context.yourkey`""" + Custom context is now stored in `request.custom_context.yourkey`""" return self.ctx.__dict__[key] def __delitem__(self, key): """.. deprecated:: 19.9 - Custom context is now stored in `request.custom_context.yourkey`""" + Custom context is now stored in `request.custom_context.yourkey`""" del self.ctx.__dict__[key] def __setitem__(self, key, value): """.. deprecated:: 19.9 - Custom context is now stored in `request.custom_context.yourkey`""" + Custom context is now stored in `request.custom_context.yourkey`""" setattr(self.ctx, key, value) def body_init(self): diff --git a/sanic/response.py b/sanic/response.py index 4a84cf47..4db94fc5 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -22,6 +22,9 @@ except ImportError: class BaseHTTPResponse: + def __init__(self): + self.asgi = False + def _encode_body(self, data): try: # Try to encode it regularly @@ -59,6 +62,8 @@ class StreamingHTTPResponse(BaseHTTPResponse): content_type="text/plain", chunked=True, ): + super().__init__() + self.content_type = content_type self.streaming_fn = streaming_fn self.status = status @@ -94,8 +99,9 @@ class StreamingHTTPResponse(BaseHTTPResponse): keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout, ) - await self.protocol.push_data(headers) - await self.protocol.drain() + if not getattr(self, "asgi", False): + await self.protocol.push_data(headers) + await self.protocol.drain() await self.streaming_fn(self) if self.chunked: await self.protocol.push_data(b"0\r\n\r\n") @@ -145,6 +151,8 @@ class HTTPResponse(BaseHTTPResponse): content_type=None, body_bytes=b"", ): + super().__init__() + self.content_type = content_type if body is not None: diff --git a/sanic/router.py b/sanic/router.py index 2d8817a3..698589a5 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -484,7 +484,7 @@ class Router: return route_handler, [], kwargs, route.uri, route.name def is_stream_handler(self, request): - """ Handler for request is stream or not. + """Handler for request is stream or not. :param request: Request object :return: bool """ diff --git a/sanic/worker.py b/sanic/worker.py index 777f12cf..d42662dc 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -174,7 +174,7 @@ class GunicornWorker(base.Worker): @staticmethod def _create_ssl_context(cfg): - """ Creates SSLContext instance for usage in asyncio.create_server. + """Creates SSLContext instance for usage in asyncio.create_server. See ssl.SSLSocket.__init__ for more details. """ ctx = ssl.SSLContext(cfg.ssl_version) diff --git a/setup.py b/setup.py index 019769cf..e762a088 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ import codecs import os import re import sys - from distutils.util import strtobool from setuptools import setup @@ -39,9 +38,7 @@ def open_local(paths, mode="r", encoding="utf8"): with open_local(["sanic", "__version__.py"], encoding="latin1") as fp: try: - version = re.findall( - r"^__version__ = \"([^']+)\"\r?$", fp.read(), re.M - )[0] + version = re.findall(r"^__version__ = \"([^']+)\"\r?$", fp.read(), re.M)[0] except IndexError: raise RuntimeError("Unable to determine version.") @@ -71,9 +68,7 @@ setup_kwargs = { ], } -env_dependency = ( - '; sys_platform != "win32" ' 'and implementation_name == "cpython"' -) +env_dependency = '; sys_platform != "win32" ' 'and implementation_name == "cpython"' ujson = "ujson>=1.35" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency @@ -83,13 +78,13 @@ requirements = [ ujson, "aiofiles>=0.3.0", "websockets>=7.0,<9.0", - "multidict>=4.0,<5.0", + "multidict==5.0.0", "httpx==0.9.3", ] tests_require = [ "pytest==5.2.1", - "multidict>=4.0,<5.0", + "multidict==5.0.0", "gunicorn", "pytest-cov", "httpcore==0.3.0", diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index a59d6c5b..bec433be 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -230,8 +230,8 @@ async def handler3(request): def test_keep_alive_timeout_reuse(): """If the server keep-alive timeout and client keep-alive timeout are - both longer than the delay, the client _and_ server will successfully - reuse the existing connection.""" + both longer than the delay, the client _and_ server will successfully + reuse the existing connection.""" try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) diff --git a/tests/test_response.py b/tests/test_response.py index c6e16dd2..488a76e7 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -232,6 +232,12 @@ def test_chunked_streaming_returns_correct_content(streaming_app): assert response.text == "foo,bar" +@pytest.mark.asyncio +async def test_chunked_streaming_returns_correct_content_asgi(streaming_app): + request, response = await streaming_app.asgi_client.get("/") + assert response.text == "4\r\nfoo,\r\n3\r\nbar\r\n0\r\n\r\n" + + def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): request, response = non_chunked_streaming_app.test_client.get("/") assert "Transfer-Encoding" not in response.headers @@ -239,6 +245,16 @@ def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): assert response.headers["Content-Length"] == "7" +@pytest.mark.asyncio +async def test_non_chunked_streaming_adds_correct_headers_asgi( + non_chunked_streaming_app, +): + request, response = await non_chunked_streaming_app.asgi_client.get("/") + assert "Transfer-Encoding" not in response.headers + assert response.headers["Content-Type"] == "text/csv" + assert response.headers["Content-Length"] == "7" + + def test_non_chunked_streaming_returns_correct_content( non_chunked_streaming_app, ):