Resolve headers as body in ASGI mode

This commit is contained in:
Adam Hopkins 2020-10-25 10:40:08 +02:00
parent 2a44a27236
commit c09129ec63
3 changed files with 58 additions and 135 deletions

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import warnings import warnings
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import (
Any, Any,
@ -16,7 +15,6 @@ from typing import (
from urllib.parse import quote from urllib.parse import quote
import sanic.app # noqa import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger from sanic.log import logger
@ -25,7 +23,6 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import StreamBuffer from sanic.server import StreamBuffer
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any] ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any]
ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGISend = Callable[[ASGIMessage], Awaitable[None]]
@ -68,9 +65,7 @@ class MockProtocol:
class MockTransport: class MockTransport:
_protocol: Optional[MockProtocol] _protocol: Optional[MockProtocol]
def __init__( def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None:
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.scope = scope self.scope = scope
self._receive = receive self._receive = receive
self._send = send self._send = send
@ -146,9 +141,7 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) ) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners: for handler in listeners:
response = handler( response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response): if isawaitable(response):
await response await response
@ -166,9 +159,7 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners: for handler in listeners:
response = handler( response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response): if isawaitable(response):
await response await response
@ -213,19 +204,13 @@ class ASGIApp:
for key, value in scope.get("headers", []) for key, value in scope.get("headers", [])
] ]
) )
instance.do_stream = ( instance.do_stream = True if headers.get("expect") == "100-continue" else False
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance) instance.lifespan = Lifespan(instance)
if scope["type"] == "lifespan": if scope["type"] == "lifespan":
await instance.lifespan(scope, receive, send) await instance.lifespan(scope, receive, send)
else: else:
path = ( path = scope["path"][1:] if scope["path"].startswith("/") else scope["path"]
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)]) url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1") url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"] url_bytes += b"?" + scope["query_string"]
@ -248,18 +233,11 @@ class ASGIApp:
request_class = sanic_app.request_class or Request request_class = sanic_app.request_class or Request
instance.request = request_class( instance.request = request_class(
url_bytes, url_bytes, headers, version, method, instance.transport, sanic_app,
headers,
version,
method,
instance.transport,
sanic_app,
) )
if sanic_app.is_request_stream: if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler( is_stream_handler = sanic_app.router.is_stream_handler(instance.request)
instance.request
)
if is_stream_handler: if is_stream_handler:
instance.request.stream = StreamBuffer( instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
@ -313,6 +291,7 @@ class ASGIApp:
""" """
Write the response. Write the response.
""" """
response.asgi = True
headers: List[Tuple[bytes, bytes]] = [] headers: List[Tuple[bytes, bytes]] = []
cookies: Dict[str, str] = {} cookies: Dict[str, str] = {}
try: try:
@ -338,9 +317,7 @@ class ASGIApp:
type(response), type(response),
) )
exception = ServerError("Invalid response type") exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response( response = self.sanic_app.error_handler.response(self.request, exception)
self.request, exception
)
headers = [ headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1")) (str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items() for name, value in response.headers.items()
@ -350,14 +327,10 @@ class ASGIApp:
if "content-length" not in response.headers and not isinstance( if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse response, StreamingHTTPResponse
): ):
headers += [ headers += [(b"content-length", str(len(response.body)).encode("latin-1"))]
(b"content-length", str(len(response.body)).encode("latin-1"))
]
if "content-type" not in response.headers: if "content-type" not in response.headers:
headers += [ headers += [(b"content-type", str(response.content_type).encode("latin-1"))]
(b"content-type", str(response.content_type).encode("latin-1"))
]
if response.cookies: if response.cookies:
cookies.update( cookies.update(
@ -369,8 +342,7 @@ class ASGIApp:
) )
headers += [ headers += [
(b"set-cookie", cookie.encode("utf-8")) (b"set-cookie", cookie.encode("utf-8")) for k, cookie in cookies.items()
for k, cookie in cookies.items()
] ]
await self.transport.send( await self.transport.send(

View File

@ -10,7 +10,6 @@ from sanic.cookies import CookieJar
from sanic.headers import format_http1 from sanic.headers import format_http1
from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers
try: try:
from ujson import dumps as json_dumps from ujson import dumps as json_dumps
except ImportError: except ImportError:
@ -81,30 +80,25 @@ class StreamingHTTPResponse(BaseHTTPResponse):
await self.protocol.push_data(data) await self.protocol.push_data(data)
await self.protocol.drain() await self.protocol.drain()
async def stream( async def stream(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
"""Streams headers, runs the `streaming_fn` callback that writes """Streams headers, runs the `streaming_fn` callback that writes
content to the response body, then finalizes the response body. content to the response body, then finalizes the response body.
""" """
if version != "1.1": if version != "1.1":
self.chunked = False self.chunked = False
headers = self.get_headers( headers = self.get_headers(
version, version, keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout,
keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout,
) )
await self.protocol.push_data(headers) if not getattr(self, "asgi", False):
await self.protocol.drain() await self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self) await self.streaming_fn(self)
if self.chunked: if self.chunked:
await self.protocol.push_data(b"0\r\n\r\n") await self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the # no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it. # very last thing we write and nothing needs to wait for it.
def get_headers( def get_headers(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
# This is all returned in a kind-of funky way # This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python # We tried to make this as fast as possible in pure python
timeout_header = b"" timeout_header = b""
@ -138,12 +132,7 @@ class HTTPResponse(BaseHTTPResponse):
__slots__ = ("body", "status", "content_type", "headers", "_cookies") __slots__ = ("body", "status", "content_type", "headers", "_cookies")
def __init__( def __init__(
self, self, body=None, status=200, headers=None, content_type=None, body_bytes=b"",
body=None,
status=200,
headers=None,
content_type=None,
body_bytes=b"",
): ):
self.content_type = content_type self.content_type = content_type
@ -184,9 +173,7 @@ class HTTPResponse(BaseHTTPResponse):
else: else:
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
return ( return (b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b") % (
b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b"
) % (
version.encode(), version.encode(),
self.status, self.status,
status, status,
@ -237,9 +224,7 @@ def json(
) )
def text( def text(body, status=200, headers=None, content_type="text/plain; charset=utf-8"):
body, status=200, headers=None, content_type="text/plain; charset=utf-8"
):
""" """
Returns response object with body in text format. Returns response object with body in text format.
@ -248,14 +233,10 @@ def text(
:param headers: Custom Headers. :param headers: Custom Headers.
:param content_type: the content type (string) of the response :param content_type: the content type (string) of the response
""" """
return HTTPResponse( return HTTPResponse(body, status=status, headers=headers, content_type=content_type)
body, status=status, headers=headers, content_type=content_type
)
def raw( def raw(body, status=200, headers=None, content_type="application/octet-stream"):
body, status=200, headers=None, content_type="application/octet-stream"
):
""" """
Returns response object without encoding the body. Returns response object without encoding the body.
@ -265,10 +246,7 @@ def raw(
:param content_type: the content type (string) of the response. :param content_type: the content type (string) of the response.
""" """
return HTTPResponse( return HTTPResponse(
body_bytes=body, body_bytes=body, status=status, headers=headers, content_type=content_type,
status=status,
headers=headers,
content_type=content_type,
) )
@ -281,20 +259,12 @@ def html(body, status=200, headers=None):
:param headers: Custom Headers. :param headers: Custom Headers.
""" """
return HTTPResponse( return HTTPResponse(
body, body, status=status, headers=headers, content_type="text/html; charset=utf-8",
status=status,
headers=headers,
content_type="text/html; charset=utf-8",
) )
async def file( async def file(
location, location, status=200, mime_type=None, headers=None, filename=None, _range=None,
status=200,
mime_type=None,
headers=None,
filename=None,
_range=None,
): ):
"""Return a response object with file data. """Return a response object with file data.
@ -326,10 +296,7 @@ async def file(
mime_type = mime_type or guess_type(filename)[0] or "text/plain" mime_type = mime_type or guess_type(filename)[0] or "text/plain"
return HTTPResponse( return HTTPResponse(
status=status, status=status, headers=headers, content_type=mime_type, body_bytes=out_stream,
headers=headers,
content_type=mime_type,
body_bytes=out_stream,
) )
@ -437,9 +404,7 @@ def stream(
) )
def redirect( def redirect(to, headers=None, status=302, content_type="text/html; charset=utf-8"):
to, headers=None, status=302, content_type="text/html; charset=utf-8"
):
"""Abort execution and cause a 302 redirect (by default). """Abort execution and cause a 302 redirect (by default).
:param to: path or fully qualified URL to redirect to :param to: path or fully qualified URL to redirect to
@ -456,6 +421,5 @@ def redirect(
# According to RFC 7231, a relative URI is now permitted. # According to RFC 7231, a relative URI is now permitted.
headers["Location"] = safe_to headers["Location"] = safe_to
return HTTPResponse( return HTTPResponse(status=status, headers=headers, content_type=content_type)
status=status, headers=headers, content_type=content_type
)

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import inspect import inspect
import os import os
from collections import namedtuple from collections import namedtuple
from mimetypes import guess_type from mimetypes import guess_type
from random import choice from random import choice
@ -9,7 +8,6 @@ from unittest.mock import MagicMock
from urllib.parse import unquote from urllib.parse import unquote
import pytest import pytest
from aiofiles import os as async_os from aiofiles import os as async_os
from sanic.response import ( from sanic.response import (
@ -25,7 +23,6 @@ from sanic.response import (
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
JSON_DATA = {"ok": True} JSON_DATA = {"ok": True}
@ -103,14 +100,10 @@ def test_response_content_length(app):
) )
_, response = app.test_client.get("/response_with_space") _, response = app.test_client.get("/response_with_space")
content_length_for_response_with_space = response.headers.get( content_length_for_response_with_space = response.headers.get("Content-Length")
"Content-Length"
)
_, response = app.test_client.get("/response_without_space") _, response = app.test_client.get("/response_without_space")
content_length_for_response_without_space = response.headers.get( content_length_for_response_without_space = response.headers.get("Content-Length")
"Content-Length"
)
assert ( assert (
content_length_for_response_with_space content_length_for_response_with_space
@ -232,6 +225,12 @@ def test_chunked_streaming_returns_correct_content(streaming_app):
assert response.text == "foo,bar" assert response.text == "foo,bar"
@pytest.mark.asyncio
async def test_chunked_streaming_returns_correct_content_asgi(streaming_app):
request, response = await streaming_app.asgi_client.get("/")
assert response.text == "4\r\nfoo,\r\n3\r\nbar\r\n0\r\n\r\n"
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
request, response = non_chunked_streaming_app.test_client.get("/") request, response = non_chunked_streaming_app.test_client.get("/")
assert "Transfer-Encoding" not in response.headers assert "Transfer-Encoding" not in response.headers
@ -239,9 +238,17 @@ def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
assert response.headers["Content-Length"] == "7" assert response.headers["Content-Length"] == "7"
def test_non_chunked_streaming_returns_correct_content( @pytest.mark.asyncio
async def test_non_chunked_streaming_adds_correct_headers_asgi(
non_chunked_streaming_app, non_chunked_streaming_app,
): ):
request, response = await non_chunked_streaming_app.asgi_client.get("/")
assert "Transfer-Encoding" not in response.headers
assert response.headers["Content-Type"] == "text/csv"
assert response.headers["Content-Length"] == "7"
def test_non_chunked_streaming_returns_correct_content(non_chunked_streaming_app,):
request, response = non_chunked_streaming_app.test_client.get("/") request, response = non_chunked_streaming_app.test_client.get("/")
assert response.text == "foo,bar" assert response.text == "foo,bar"
@ -254,9 +261,7 @@ def test_stream_response_status_returns_correct_headers(status):
@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30]) @pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
def test_stream_response_keep_alive_returns_correct_headers( def test_stream_response_keep_alive_returns_correct_headers(keep_alive_timeout,):
keep_alive_timeout,
):
response = StreamingHTTPResponse(sample_streaming_fn) response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers( headers = response.get_headers(
keep_alive=True, keep_alive_timeout=keep_alive_timeout keep_alive=True, keep_alive_timeout=keep_alive_timeout
@ -340,13 +345,9 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
@streaming_app.listener("after_server_start") @streaming_app.listener("after_server_start")
async def run_stream(app, loop): async def run_stream(app, loop):
await response.stream(version="1.0") await response.stream(version="1.0")
assert response.protocol.transport.write.call_args_list[1][0][0] == ( assert response.protocol.transport.write.call_args_list[1][0][0] == (b"foo,")
b"foo,"
)
assert response.protocol.transport.write.call_args_list[2][0][0] == ( assert response.protocol.transport.write.call_args_list[2][0][0] == (b"bar")
b"bar"
)
assert len(response.protocol.transport.write.call_args_list) == 3 assert len(response.protocol.transport.write.call_args_list) == 3
@ -391,9 +392,7 @@ def get_file_content(static_file_directory, file_name):
return file.read() return file.read()
@pytest.mark.parametrize( @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"])
"file_name", ["test.file", "decode me.txt", "python.png"]
)
@pytest.mark.parametrize("status", [200, 401]) @pytest.mark.parametrize("status", [200, 401])
def test_file_response(app, file_name, static_file_directory, status): def test_file_response(app, file_name, static_file_directory, status):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
@ -420,9 +419,7 @@ def test_file_response(app, file_name, static_file_directory, status):
("python.png", "logo.png"), ("python.png", "logo.png"),
], ],
) )
def test_file_response_custom_filename( def test_file_response_custom_filename(app, source, dest, static_file_directory):
app, source, dest, static_file_directory
):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
file_path = os.path.join(static_file_directory, filename) file_path = os.path.join(static_file_directory, filename)
@ -449,8 +446,7 @@ def test_file_head_response(app, file_name, static_file_directory):
headers["Content-Length"] = str(stats.st_size) headers["Content-Length"] = str(stats.st_size)
if request.method == "HEAD": if request.method == "HEAD":
return HTTPResponse( return HTTPResponse(
headers=headers, headers=headers, content_type=guess_type(file_path)[0] or "text/plain",
content_type=guess_type(file_path)[0] or "text/plain",
) )
else: else:
return file( return file(
@ -468,9 +464,7 @@ def test_file_head_response(app, file_name, static_file_directory):
) )
@pytest.mark.parametrize( @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"])
"file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_stream_response(app, file_name, static_file_directory): def test_file_stream_response(app, file_name, static_file_directory):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
@ -496,9 +490,7 @@ def test_file_stream_response(app, file_name, static_file_directory):
("python.png", "logo.png"), ("python.png", "logo.png"),
], ],
) )
def test_file_stream_response_custom_filename( def test_file_stream_response_custom_filename(app, source, dest, static_file_directory):
app, source, dest, static_file_directory
):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
file_path = os.path.join(static_file_directory, filename) file_path = os.path.join(static_file_directory, filename)
@ -527,8 +519,7 @@ def test_file_stream_head_response(app, file_name, static_file_directory):
stats = await async_os.stat(file_path) stats = await async_os.stat(file_path)
headers["Content-Length"] = str(stats.st_size) headers["Content-Length"] = str(stats.st_size)
return HTTPResponse( return HTTPResponse(
headers=headers, headers=headers, content_type=guess_type(file_path)[0] or "text/plain",
content_type=guess_type(file_path)[0] or "text/plain",
) )
else: else:
return file_stream( return file_stream(
@ -551,12 +542,8 @@ def test_file_stream_head_response(app, file_name, static_file_directory):
) )
@pytest.mark.parametrize( @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"])
"file_name", ["test.file", "decode me.txt", "python.png"] @pytest.mark.parametrize("size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)])
)
@pytest.mark.parametrize(
"size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)]
)
def test_file_stream_response_range( def test_file_stream_response_range(
app, file_name, static_file_directory, size, start, end app, file_name, static_file_directory, size, start, end
): ):