diff --git a/sanic/asgi.py b/sanic/asgi.py index f08cc454..cf29a0cc 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,6 +1,5 @@ import asyncio import warnings - from inspect import isawaitable from typing import ( Any, @@ -16,7 +15,6 @@ from typing import ( from urllib.parse import quote import sanic.app # noqa - from sanic.compat import Header from sanic.exceptions import InvalidUsage, ServerError from sanic.log import logger @@ -25,7 +23,6 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.server import StreamBuffer from sanic.websocket import WebSocketConnection - ASGIScope = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any] ASGISend = Callable[[ASGIMessage], Awaitable[None]] @@ -68,9 +65,7 @@ class MockProtocol: class MockTransport: _protocol: Optional[MockProtocol] - def __init__( - self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend - ) -> None: + def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None: self.scope = scope self._receive = receive self._send = send @@ -146,9 +141,7 @@ class Lifespan: ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) + response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop) if isawaitable(response): await response @@ -166,9 +159,7 @@ class Lifespan: ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) + response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop) if isawaitable(response): await response @@ -213,19 +204,13 @@ class ASGIApp: for key, value in scope.get("headers", []) ] ) - instance.do_stream = ( - True if headers.get("expect") == "100-continue" else False - ) + instance.do_stream = True if headers.get("expect") == "100-continue" else False instance.lifespan = Lifespan(instance) if scope["type"] == "lifespan": await instance.lifespan(scope, receive, send) else: - path = ( - scope["path"][1:] - if scope["path"].startswith("/") - else scope["path"] - ) + path = scope["path"][1:] if scope["path"].startswith("/") else scope["path"] url = "/".join([scope.get("root_path", ""), quote(path)]) url_bytes = url.encode("latin-1") url_bytes += b"?" + scope["query_string"] @@ -248,18 +233,11 @@ class ASGIApp: request_class = sanic_app.request_class or Request instance.request = request_class( - url_bytes, - headers, - version, - method, - instance.transport, - sanic_app, + url_bytes, headers, version, method, instance.transport, sanic_app, ) if sanic_app.is_request_stream: - is_stream_handler = sanic_app.router.is_stream_handler( - instance.request - ) + is_stream_handler = sanic_app.router.is_stream_handler(instance.request) if is_stream_handler: instance.request.stream = StreamBuffer( sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE @@ -313,6 +291,7 @@ class ASGIApp: """ Write the response. """ + response.asgi = True headers: List[Tuple[bytes, bytes]] = [] cookies: Dict[str, str] = {} try: @@ -338,9 +317,7 @@ class ASGIApp: type(response), ) exception = ServerError("Invalid response type") - response = self.sanic_app.error_handler.response( - self.request, exception - ) + response = self.sanic_app.error_handler.response(self.request, exception) headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() @@ -350,14 +327,10 @@ class ASGIApp: if "content-length" not in response.headers and not isinstance( response, StreamingHTTPResponse ): - headers += [ - (b"content-length", str(len(response.body)).encode("latin-1")) - ] + headers += [(b"content-length", str(len(response.body)).encode("latin-1"))] if "content-type" not in response.headers: - headers += [ - (b"content-type", str(response.content_type).encode("latin-1")) - ] + headers += [(b"content-type", str(response.content_type).encode("latin-1"))] if response.cookies: cookies.update( @@ -369,8 +342,7 @@ class ASGIApp: ) headers += [ - (b"set-cookie", cookie.encode("utf-8")) - for k, cookie in cookies.items() + (b"set-cookie", cookie.encode("utf-8")) for k, cookie in cookies.items() ] await self.transport.send( diff --git a/sanic/response.py b/sanic/response.py index 4a84cf47..2cb83987 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -10,7 +10,6 @@ from sanic.cookies import CookieJar from sanic.headers import format_http1 from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers - try: from ujson import dumps as json_dumps except ImportError: @@ -81,30 +80,25 @@ class StreamingHTTPResponse(BaseHTTPResponse): await self.protocol.push_data(data) await self.protocol.drain() - async def stream( - self, version="1.1", keep_alive=False, keep_alive_timeout=None - ): + async def stream(self, version="1.1", keep_alive=False, keep_alive_timeout=None): """Streams headers, runs the `streaming_fn` callback that writes content to the response body, then finalizes the response body. """ if version != "1.1": self.chunked = False headers = self.get_headers( - version, - keep_alive=keep_alive, - keep_alive_timeout=keep_alive_timeout, + version, keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout, ) - await self.protocol.push_data(headers) - await self.protocol.drain() + if not getattr(self, "asgi", False): + await self.protocol.push_data(headers) + await self.protocol.drain() await self.streaming_fn(self) if self.chunked: await self.protocol.push_data(b"0\r\n\r\n") # no need to await drain here after this write, because it is the # very last thing we write and nothing needs to wait for it. - def get_headers( - self, version="1.1", keep_alive=False, keep_alive_timeout=None - ): + def get_headers(self, version="1.1", keep_alive=False, keep_alive_timeout=None): # This is all returned in a kind-of funky way # We tried to make this as fast as possible in pure python timeout_header = b"" @@ -138,12 +132,7 @@ class HTTPResponse(BaseHTTPResponse): __slots__ = ("body", "status", "content_type", "headers", "_cookies") def __init__( - self, - body=None, - status=200, - headers=None, - content_type=None, - body_bytes=b"", + self, body=None, status=200, headers=None, content_type=None, body_bytes=b"", ): self.content_type = content_type @@ -184,9 +173,7 @@ class HTTPResponse(BaseHTTPResponse): else: status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - return ( - b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b" - ) % ( + return (b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b") % ( version.encode(), self.status, status, @@ -237,9 +224,7 @@ def json( ) -def text( - body, status=200, headers=None, content_type="text/plain; charset=utf-8" -): +def text(body, status=200, headers=None, content_type="text/plain; charset=utf-8"): """ Returns response object with body in text format. @@ -248,14 +233,10 @@ def text( :param headers: Custom Headers. :param content_type: the content type (string) of the response """ - return HTTPResponse( - body, status=status, headers=headers, content_type=content_type - ) + return HTTPResponse(body, status=status, headers=headers, content_type=content_type) -def raw( - body, status=200, headers=None, content_type="application/octet-stream" -): +def raw(body, status=200, headers=None, content_type="application/octet-stream"): """ Returns response object without encoding the body. @@ -265,10 +246,7 @@ def raw( :param content_type: the content type (string) of the response. """ return HTTPResponse( - body_bytes=body, - status=status, - headers=headers, - content_type=content_type, + body_bytes=body, status=status, headers=headers, content_type=content_type, ) @@ -281,20 +259,12 @@ def html(body, status=200, headers=None): :param headers: Custom Headers. """ return HTTPResponse( - body, - status=status, - headers=headers, - content_type="text/html; charset=utf-8", + body, status=status, headers=headers, content_type="text/html; charset=utf-8", ) async def file( - location, - status=200, - mime_type=None, - headers=None, - filename=None, - _range=None, + location, status=200, mime_type=None, headers=None, filename=None, _range=None, ): """Return a response object with file data. @@ -326,10 +296,7 @@ async def file( mime_type = mime_type or guess_type(filename)[0] or "text/plain" return HTTPResponse( - status=status, - headers=headers, - content_type=mime_type, - body_bytes=out_stream, + status=status, headers=headers, content_type=mime_type, body_bytes=out_stream, ) @@ -437,9 +404,7 @@ def stream( ) -def redirect( - to, headers=None, status=302, content_type="text/html; charset=utf-8" -): +def redirect(to, headers=None, status=302, content_type="text/html; charset=utf-8"): """Abort execution and cause a 302 redirect (by default). :param to: path or fully qualified URL to redirect to @@ -456,6 +421,5 @@ def redirect( # According to RFC 7231, a relative URI is now permitted. headers["Location"] = safe_to - return HTTPResponse( - status=status, headers=headers, content_type=content_type - ) + return HTTPResponse(status=status, headers=headers, content_type=content_type) + diff --git a/tests/test_response.py b/tests/test_response.py index c6e16dd2..07bfc18a 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,7 +1,6 @@ import asyncio import inspect import os - from collections import namedtuple from mimetypes import guess_type from random import choice @@ -9,7 +8,6 @@ from unittest.mock import MagicMock from urllib.parse import unquote import pytest - from aiofiles import os as async_os from sanic.response import ( @@ -25,7 +23,6 @@ from sanic.response import ( from sanic.server import HttpProtocol from sanic.testing import HOST, PORT - JSON_DATA = {"ok": True} @@ -103,14 +100,10 @@ def test_response_content_length(app): ) _, response = app.test_client.get("/response_with_space") - content_length_for_response_with_space = response.headers.get( - "Content-Length" - ) + content_length_for_response_with_space = response.headers.get("Content-Length") _, response = app.test_client.get("/response_without_space") - content_length_for_response_without_space = response.headers.get( - "Content-Length" - ) + content_length_for_response_without_space = response.headers.get("Content-Length") assert ( content_length_for_response_with_space @@ -232,6 +225,12 @@ def test_chunked_streaming_returns_correct_content(streaming_app): assert response.text == "foo,bar" +@pytest.mark.asyncio +async def test_chunked_streaming_returns_correct_content_asgi(streaming_app): + request, response = await streaming_app.asgi_client.get("/") + assert response.text == "4\r\nfoo,\r\n3\r\nbar\r\n0\r\n\r\n" + + def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): request, response = non_chunked_streaming_app.test_client.get("/") assert "Transfer-Encoding" not in response.headers @@ -239,9 +238,17 @@ def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): assert response.headers["Content-Length"] == "7" -def test_non_chunked_streaming_returns_correct_content( +@pytest.mark.asyncio +async def test_non_chunked_streaming_adds_correct_headers_asgi( non_chunked_streaming_app, ): + request, response = await non_chunked_streaming_app.asgi_client.get("/") + assert "Transfer-Encoding" not in response.headers + assert response.headers["Content-Type"] == "text/csv" + assert response.headers["Content-Length"] == "7" + + +def test_non_chunked_streaming_returns_correct_content(non_chunked_streaming_app,): request, response = non_chunked_streaming_app.test_client.get("/") assert response.text == "foo,bar" @@ -254,9 +261,7 @@ def test_stream_response_status_returns_correct_headers(status): @pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30]) -def test_stream_response_keep_alive_returns_correct_headers( - keep_alive_timeout, -): +def test_stream_response_keep_alive_returns_correct_headers(keep_alive_timeout,): response = StreamingHTTPResponse(sample_streaming_fn) headers = response.get_headers( keep_alive=True, keep_alive_timeout=keep_alive_timeout @@ -340,13 +345,9 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked( @streaming_app.listener("after_server_start") async def run_stream(app, loop): await response.stream(version="1.0") - assert response.protocol.transport.write.call_args_list[1][0][0] == ( - b"foo," - ) + assert response.protocol.transport.write.call_args_list[1][0][0] == (b"foo,") - assert response.protocol.transport.write.call_args_list[2][0][0] == ( - b"bar" - ) + assert response.protocol.transport.write.call_args_list[2][0][0] == (b"bar") assert len(response.protocol.transport.write.call_args_list) == 3 @@ -391,9 +392,7 @@ def get_file_content(static_file_directory, file_name): return file.read() -@pytest.mark.parametrize( - "file_name", ["test.file", "decode me.txt", "python.png"] -) +@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"]) @pytest.mark.parametrize("status", [200, 401]) def test_file_response(app, file_name, static_file_directory, status): @app.route("/files/", methods=["GET"]) @@ -420,9 +419,7 @@ def test_file_response(app, file_name, static_file_directory, status): ("python.png", "logo.png"), ], ) -def test_file_response_custom_filename( - app, source, dest, static_file_directory -): +def test_file_response_custom_filename(app, source, dest, static_file_directory): @app.route("/files/", methods=["GET"]) def file_route(request, filename): file_path = os.path.join(static_file_directory, filename) @@ -449,8 +446,7 @@ def test_file_head_response(app, file_name, static_file_directory): headers["Content-Length"] = str(stats.st_size) if request.method == "HEAD": return HTTPResponse( - headers=headers, - content_type=guess_type(file_path)[0] or "text/plain", + headers=headers, content_type=guess_type(file_path)[0] or "text/plain", ) else: return file( @@ -468,9 +464,7 @@ def test_file_head_response(app, file_name, static_file_directory): ) -@pytest.mark.parametrize( - "file_name", ["test.file", "decode me.txt", "python.png"] -) +@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"]) def test_file_stream_response(app, file_name, static_file_directory): @app.route("/files/", methods=["GET"]) def file_route(request, filename): @@ -496,9 +490,7 @@ def test_file_stream_response(app, file_name, static_file_directory): ("python.png", "logo.png"), ], ) -def test_file_stream_response_custom_filename( - app, source, dest, static_file_directory -): +def test_file_stream_response_custom_filename(app, source, dest, static_file_directory): @app.route("/files/", methods=["GET"]) def file_route(request, filename): file_path = os.path.join(static_file_directory, filename) @@ -527,8 +519,7 @@ def test_file_stream_head_response(app, file_name, static_file_directory): stats = await async_os.stat(file_path) headers["Content-Length"] = str(stats.st_size) return HTTPResponse( - headers=headers, - content_type=guess_type(file_path)[0] or "text/plain", + headers=headers, content_type=guess_type(file_path)[0] or "text/plain", ) else: return file_stream( @@ -551,12 +542,8 @@ def test_file_stream_head_response(app, file_name, static_file_directory): ) -@pytest.mark.parametrize( - "file_name", ["test.file", "decode me.txt", "python.png"] -) -@pytest.mark.parametrize( - "size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)] -) +@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"]) +@pytest.mark.parametrize("size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)]) def test_file_stream_response_range( app, file_name, static_file_directory, size, start, end ):