From e5aed4c067e28998ed72a707b616ecc9d048e9e0 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 25 Oct 2020 15:01:53 +0200 Subject: [PATCH] Ignore writing headers when in ASGI mode (#1957) * Ignore writing headers when in ASGI mode for streaming responses * Move asgi set on streaming until after response type check * Adds multidict==5.0.0 to pass tests * Bump version to 20.9.1 --- examples/run_asgi.py | 6 ++---- sanic/__version__.py | 2 +- sanic/asgi.py | 2 ++ sanic/response.py | 22 +++++++++++++++------- tests/test_response.py | 16 ++++++++++++++++ 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/examples/run_asgi.py b/examples/run_asgi.py index 44be25f5..e54d5d5d 100644 --- a/examples/run_asgi.py +++ b/examples/run_asgi.py @@ -7,8 +7,8 @@ """ from pathlib import Path -from sanic import Sanic, response +from sanic import Sanic, response app = Sanic(__name__) @@ -42,9 +42,7 @@ async def handler_file(request): @app.route("/file_stream") async def handler_file_stream(request): - return await response.file_stream( - Path("../") / "setup.py", chunk_size=1024 - ) + return await response.file_stream(Path("../") / "setup.py", chunk_size=1024) @app.route("/stream", stream=True) diff --git a/sanic/__version__.py b/sanic/__version__.py index d59f2279..0d8f82c6 100644 --- a/sanic/__version__.py +++ b/sanic/__version__.py @@ -1 +1 @@ -__version__ = "20.9.0" +__version__ = "20.9.1" diff --git a/sanic/asgi.py b/sanic/asgi.py index 5ec13cf4..2a3c4540 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -350,6 +350,8 @@ class ASGIApp: if name not in (b"Set-Cookie",) ] + response.asgi = True + if "content-length" not in response.headers and not isinstance( response, StreamingHTTPResponse ): diff --git a/sanic/response.py b/sanic/response.py index 24033336..1f7b12de 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -22,6 +22,9 @@ except ImportError: class BaseHTTPResponse: + def __init__(self): + self.asgi = False + def _encode_body(self, data): return data.encode() if hasattr(data, "encode") else data @@ -80,6 +83,8 @@ class StreamingHTTPResponse(BaseHTTPResponse): content_type="text/plain; charset=utf-8", chunked=True, ): + super().__init__() + self.content_type = content_type self.streaming_fn = streaming_fn self.status = status @@ -109,13 +114,14 @@ class StreamingHTTPResponse(BaseHTTPResponse): """ if version != "1.1": self.chunked = False - headers = self.get_headers( - version, - keep_alive=keep_alive, - keep_alive_timeout=keep_alive_timeout, - ) - await self.protocol.push_data(headers) - await self.protocol.drain() + if not getattr(self, "asgi", False): + headers = self.get_headers( + version, + keep_alive=keep_alive, + keep_alive_timeout=keep_alive_timeout, + ) + await self.protocol.push_data(headers) + await self.protocol.drain() await self.streaming_fn(self) if self.chunked: await self.protocol.push_data(b"0\r\n\r\n") @@ -143,6 +149,8 @@ class HTTPResponse(BaseHTTPResponse): content_type=None, body_bytes=b"", ): + super().__init__() + self.content_type = content_type self.body = body_bytes if body is None else self._encode_body(body) self.status = status diff --git a/tests/test_response.py b/tests/test_response.py index 625317d5..6e2f2a9a 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -235,6 +235,12 @@ def test_chunked_streaming_returns_correct_content(streaming_app): assert response.text == "foo,bar" +@pytest.mark.asyncio +async def test_chunked_streaming_returns_correct_content_asgi(streaming_app): + request, response = await streaming_app.asgi_client.get("/") + assert response.text == "4\r\nfoo,\r\n3\r\nbar\r\n0\r\n\r\n" + + def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): request, response = non_chunked_streaming_app.test_client.get("/") assert "Transfer-Encoding" not in response.headers @@ -242,6 +248,16 @@ def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): assert response.headers["Content-Length"] == "7" +@pytest.mark.asyncio +async def test_non_chunked_streaming_adds_correct_headers_asgi( + non_chunked_streaming_app, +): + request, response = await non_chunked_streaming_app.asgi_client.get("/") + assert "Transfer-Encoding" not in response.headers + assert response.headers["Content-Type"] == "text/csv" + assert response.headers["Content-Length"] == "7" + + def test_non_chunked_streaming_returns_correct_content( non_chunked_streaming_app, ):