Resolve headers as body in ASGI mode
This commit is contained in:
parent
2a44a27236
commit
c09129ec63
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
@ -16,7 +15,6 @@ from typing import (
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
import sanic.app # noqa
|
import sanic.app # noqa
|
||||||
|
|
||||||
from sanic.compat import Header
|
from sanic.compat import Header
|
||||||
from sanic.exceptions import InvalidUsage, ServerError
|
from sanic.exceptions import InvalidUsage, ServerError
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
|
@ -25,7 +23,6 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||||
from sanic.server import StreamBuffer
|
from sanic.server import StreamBuffer
|
||||||
from sanic.websocket import WebSocketConnection
|
from sanic.websocket import WebSocketConnection
|
||||||
|
|
||||||
|
|
||||||
ASGIScope = MutableMapping[str, Any]
|
ASGIScope = MutableMapping[str, Any]
|
||||||
ASGIMessage = MutableMapping[str, Any]
|
ASGIMessage = MutableMapping[str, Any]
|
||||||
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
|
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
|
||||||
|
@ -68,9 +65,7 @@ class MockProtocol:
|
||||||
class MockTransport:
|
class MockTransport:
|
||||||
_protocol: Optional[MockProtocol]
|
_protocol: Optional[MockProtocol]
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None:
|
||||||
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
|
||||||
) -> None:
|
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
self._receive = receive
|
self._receive = receive
|
||||||
self._send = send
|
self._send = send
|
||||||
|
@ -146,9 +141,7 @@ class Lifespan:
|
||||||
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
|
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
|
||||||
|
|
||||||
for handler in listeners:
|
for handler in listeners:
|
||||||
response = handler(
|
response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
|
||||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
|
||||||
)
|
|
||||||
if isawaitable(response):
|
if isawaitable(response):
|
||||||
await response
|
await response
|
||||||
|
|
||||||
|
@ -166,9 +159,7 @@ class Lifespan:
|
||||||
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
|
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
|
||||||
|
|
||||||
for handler in listeners:
|
for handler in listeners:
|
||||||
response = handler(
|
response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
|
||||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
|
||||||
)
|
|
||||||
if isawaitable(response):
|
if isawaitable(response):
|
||||||
await response
|
await response
|
||||||
|
|
||||||
|
@ -213,19 +204,13 @@ class ASGIApp:
|
||||||
for key, value in scope.get("headers", [])
|
for key, value in scope.get("headers", [])
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
instance.do_stream = (
|
instance.do_stream = True if headers.get("expect") == "100-continue" else False
|
||||||
True if headers.get("expect") == "100-continue" else False
|
|
||||||
)
|
|
||||||
instance.lifespan = Lifespan(instance)
|
instance.lifespan = Lifespan(instance)
|
||||||
|
|
||||||
if scope["type"] == "lifespan":
|
if scope["type"] == "lifespan":
|
||||||
await instance.lifespan(scope, receive, send)
|
await instance.lifespan(scope, receive, send)
|
||||||
else:
|
else:
|
||||||
path = (
|
path = scope["path"][1:] if scope["path"].startswith("/") else scope["path"]
|
||||||
scope["path"][1:]
|
|
||||||
if scope["path"].startswith("/")
|
|
||||||
else scope["path"]
|
|
||||||
)
|
|
||||||
url = "/".join([scope.get("root_path", ""), quote(path)])
|
url = "/".join([scope.get("root_path", ""), quote(path)])
|
||||||
url_bytes = url.encode("latin-1")
|
url_bytes = url.encode("latin-1")
|
||||||
url_bytes += b"?" + scope["query_string"]
|
url_bytes += b"?" + scope["query_string"]
|
||||||
|
@ -248,18 +233,11 @@ class ASGIApp:
|
||||||
|
|
||||||
request_class = sanic_app.request_class or Request
|
request_class = sanic_app.request_class or Request
|
||||||
instance.request = request_class(
|
instance.request = request_class(
|
||||||
url_bytes,
|
url_bytes, headers, version, method, instance.transport, sanic_app,
|
||||||
headers,
|
|
||||||
version,
|
|
||||||
method,
|
|
||||||
instance.transport,
|
|
||||||
sanic_app,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if sanic_app.is_request_stream:
|
if sanic_app.is_request_stream:
|
||||||
is_stream_handler = sanic_app.router.is_stream_handler(
|
is_stream_handler = sanic_app.router.is_stream_handler(instance.request)
|
||||||
instance.request
|
|
||||||
)
|
|
||||||
if is_stream_handler:
|
if is_stream_handler:
|
||||||
instance.request.stream = StreamBuffer(
|
instance.request.stream = StreamBuffer(
|
||||||
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
|
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
|
||||||
|
@ -313,6 +291,7 @@ class ASGIApp:
|
||||||
"""
|
"""
|
||||||
Write the response.
|
Write the response.
|
||||||
"""
|
"""
|
||||||
|
response.asgi = True
|
||||||
headers: List[Tuple[bytes, bytes]] = []
|
headers: List[Tuple[bytes, bytes]] = []
|
||||||
cookies: Dict[str, str] = {}
|
cookies: Dict[str, str] = {}
|
||||||
try:
|
try:
|
||||||
|
@ -338,9 +317,7 @@ class ASGIApp:
|
||||||
type(response),
|
type(response),
|
||||||
)
|
)
|
||||||
exception = ServerError("Invalid response type")
|
exception = ServerError("Invalid response type")
|
||||||
response = self.sanic_app.error_handler.response(
|
response = self.sanic_app.error_handler.response(self.request, exception)
|
||||||
self.request, exception
|
|
||||||
)
|
|
||||||
headers = [
|
headers = [
|
||||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||||
for name, value in response.headers.items()
|
for name, value in response.headers.items()
|
||||||
|
@ -350,14 +327,10 @@ class ASGIApp:
|
||||||
if "content-length" not in response.headers and not isinstance(
|
if "content-length" not in response.headers and not isinstance(
|
||||||
response, StreamingHTTPResponse
|
response, StreamingHTTPResponse
|
||||||
):
|
):
|
||||||
headers += [
|
headers += [(b"content-length", str(len(response.body)).encode("latin-1"))]
|
||||||
(b"content-length", str(len(response.body)).encode("latin-1"))
|
|
||||||
]
|
|
||||||
|
|
||||||
if "content-type" not in response.headers:
|
if "content-type" not in response.headers:
|
||||||
headers += [
|
headers += [(b"content-type", str(response.content_type).encode("latin-1"))]
|
||||||
(b"content-type", str(response.content_type).encode("latin-1"))
|
|
||||||
]
|
|
||||||
|
|
||||||
if response.cookies:
|
if response.cookies:
|
||||||
cookies.update(
|
cookies.update(
|
||||||
|
@ -369,8 +342,7 @@ class ASGIApp:
|
||||||
)
|
)
|
||||||
|
|
||||||
headers += [
|
headers += [
|
||||||
(b"set-cookie", cookie.encode("utf-8"))
|
(b"set-cookie", cookie.encode("utf-8")) for k, cookie in cookies.items()
|
||||||
for k, cookie in cookies.items()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
await self.transport.send(
|
await self.transport.send(
|
||||||
|
|
|
@ -10,7 +10,6 @@ from sanic.cookies import CookieJar
|
||||||
from sanic.headers import format_http1
|
from sanic.headers import format_http1
|
||||||
from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers
|
from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ujson import dumps as json_dumps
|
from ujson import dumps as json_dumps
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -81,30 +80,25 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||||
await self.protocol.push_data(data)
|
await self.protocol.push_data(data)
|
||||||
await self.protocol.drain()
|
await self.protocol.drain()
|
||||||
|
|
||||||
async def stream(
|
async def stream(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None
|
|
||||||
):
|
|
||||||
"""Streams headers, runs the `streaming_fn` callback that writes
|
"""Streams headers, runs the `streaming_fn` callback that writes
|
||||||
content to the response body, then finalizes the response body.
|
content to the response body, then finalizes the response body.
|
||||||
"""
|
"""
|
||||||
if version != "1.1":
|
if version != "1.1":
|
||||||
self.chunked = False
|
self.chunked = False
|
||||||
headers = self.get_headers(
|
headers = self.get_headers(
|
||||||
version,
|
version, keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout,
|
||||||
keep_alive=keep_alive,
|
|
||||||
keep_alive_timeout=keep_alive_timeout,
|
|
||||||
)
|
)
|
||||||
await self.protocol.push_data(headers)
|
if not getattr(self, "asgi", False):
|
||||||
await self.protocol.drain()
|
await self.protocol.push_data(headers)
|
||||||
|
await self.protocol.drain()
|
||||||
await self.streaming_fn(self)
|
await self.streaming_fn(self)
|
||||||
if self.chunked:
|
if self.chunked:
|
||||||
await self.protocol.push_data(b"0\r\n\r\n")
|
await self.protocol.push_data(b"0\r\n\r\n")
|
||||||
# no need to await drain here after this write, because it is the
|
# no need to await drain here after this write, because it is the
|
||||||
# very last thing we write and nothing needs to wait for it.
|
# very last thing we write and nothing needs to wait for it.
|
||||||
|
|
||||||
def get_headers(
|
def get_headers(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None
|
|
||||||
):
|
|
||||||
# This is all returned in a kind-of funky way
|
# This is all returned in a kind-of funky way
|
||||||
# We tried to make this as fast as possible in pure python
|
# We tried to make this as fast as possible in pure python
|
||||||
timeout_header = b""
|
timeout_header = b""
|
||||||
|
@ -138,12 +132,7 @@ class HTTPResponse(BaseHTTPResponse):
|
||||||
__slots__ = ("body", "status", "content_type", "headers", "_cookies")
|
__slots__ = ("body", "status", "content_type", "headers", "_cookies")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, body=None, status=200, headers=None, content_type=None, body_bytes=b"",
|
||||||
body=None,
|
|
||||||
status=200,
|
|
||||||
headers=None,
|
|
||||||
content_type=None,
|
|
||||||
body_bytes=b"",
|
|
||||||
):
|
):
|
||||||
self.content_type = content_type
|
self.content_type = content_type
|
||||||
|
|
||||||
|
@ -184,9 +173,7 @@ class HTTPResponse(BaseHTTPResponse):
|
||||||
else:
|
else:
|
||||||
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
|
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
|
||||||
|
|
||||||
return (
|
return (b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b") % (
|
||||||
b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b"
|
|
||||||
) % (
|
|
||||||
version.encode(),
|
version.encode(),
|
||||||
self.status,
|
self.status,
|
||||||
status,
|
status,
|
||||||
|
@ -237,9 +224,7 @@ def json(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def text(
|
def text(body, status=200, headers=None, content_type="text/plain; charset=utf-8"):
|
||||||
body, status=200, headers=None, content_type="text/plain; charset=utf-8"
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Returns response object with body in text format.
|
Returns response object with body in text format.
|
||||||
|
|
||||||
|
@ -248,14 +233,10 @@ def text(
|
||||||
:param headers: Custom Headers.
|
:param headers: Custom Headers.
|
||||||
:param content_type: the content type (string) of the response
|
:param content_type: the content type (string) of the response
|
||||||
"""
|
"""
|
||||||
return HTTPResponse(
|
return HTTPResponse(body, status=status, headers=headers, content_type=content_type)
|
||||||
body, status=status, headers=headers, content_type=content_type
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def raw(
|
def raw(body, status=200, headers=None, content_type="application/octet-stream"):
|
||||||
body, status=200, headers=None, content_type="application/octet-stream"
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Returns response object without encoding the body.
|
Returns response object without encoding the body.
|
||||||
|
|
||||||
|
@ -265,10 +246,7 @@ def raw(
|
||||||
:param content_type: the content type (string) of the response.
|
:param content_type: the content type (string) of the response.
|
||||||
"""
|
"""
|
||||||
return HTTPResponse(
|
return HTTPResponse(
|
||||||
body_bytes=body,
|
body_bytes=body, status=status, headers=headers, content_type=content_type,
|
||||||
status=status,
|
|
||||||
headers=headers,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -281,20 +259,12 @@ def html(body, status=200, headers=None):
|
||||||
:param headers: Custom Headers.
|
:param headers: Custom Headers.
|
||||||
"""
|
"""
|
||||||
return HTTPResponse(
|
return HTTPResponse(
|
||||||
body,
|
body, status=status, headers=headers, content_type="text/html; charset=utf-8",
|
||||||
status=status,
|
|
||||||
headers=headers,
|
|
||||||
content_type="text/html; charset=utf-8",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def file(
|
async def file(
|
||||||
location,
|
location, status=200, mime_type=None, headers=None, filename=None, _range=None,
|
||||||
status=200,
|
|
||||||
mime_type=None,
|
|
||||||
headers=None,
|
|
||||||
filename=None,
|
|
||||||
_range=None,
|
|
||||||
):
|
):
|
||||||
"""Return a response object with file data.
|
"""Return a response object with file data.
|
||||||
|
|
||||||
|
@ -326,10 +296,7 @@ async def file(
|
||||||
|
|
||||||
mime_type = mime_type or guess_type(filename)[0] or "text/plain"
|
mime_type = mime_type or guess_type(filename)[0] or "text/plain"
|
||||||
return HTTPResponse(
|
return HTTPResponse(
|
||||||
status=status,
|
status=status, headers=headers, content_type=mime_type, body_bytes=out_stream,
|
||||||
headers=headers,
|
|
||||||
content_type=mime_type,
|
|
||||||
body_bytes=out_stream,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -437,9 +404,7 @@ def stream(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def redirect(
|
def redirect(to, headers=None, status=302, content_type="text/html; charset=utf-8"):
|
||||||
to, headers=None, status=302, content_type="text/html; charset=utf-8"
|
|
||||||
):
|
|
||||||
"""Abort execution and cause a 302 redirect (by default).
|
"""Abort execution and cause a 302 redirect (by default).
|
||||||
|
|
||||||
:param to: path or fully qualified URL to redirect to
|
:param to: path or fully qualified URL to redirect to
|
||||||
|
@ -456,6 +421,5 @@ def redirect(
|
||||||
# According to RFC 7231, a relative URI is now permitted.
|
# According to RFC 7231, a relative URI is now permitted.
|
||||||
headers["Location"] = safe_to
|
headers["Location"] = safe_to
|
||||||
|
|
||||||
return HTTPResponse(
|
return HTTPResponse(status=status, headers=headers, content_type=content_type)
|
||||||
status=status, headers=headers, content_type=content_type
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
from random import choice
|
from random import choice
|
||||||
|
@ -9,7 +8,6 @@ from unittest.mock import MagicMock
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from aiofiles import os as async_os
|
from aiofiles import os as async_os
|
||||||
|
|
||||||
from sanic.response import (
|
from sanic.response import (
|
||||||
|
@ -25,7 +23,6 @@ from sanic.response import (
|
||||||
from sanic.server import HttpProtocol
|
from sanic.server import HttpProtocol
|
||||||
from sanic.testing import HOST, PORT
|
from sanic.testing import HOST, PORT
|
||||||
|
|
||||||
|
|
||||||
JSON_DATA = {"ok": True}
|
JSON_DATA = {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,14 +100,10 @@ def test_response_content_length(app):
|
||||||
)
|
)
|
||||||
|
|
||||||
_, response = app.test_client.get("/response_with_space")
|
_, response = app.test_client.get("/response_with_space")
|
||||||
content_length_for_response_with_space = response.headers.get(
|
content_length_for_response_with_space = response.headers.get("Content-Length")
|
||||||
"Content-Length"
|
|
||||||
)
|
|
||||||
|
|
||||||
_, response = app.test_client.get("/response_without_space")
|
_, response = app.test_client.get("/response_without_space")
|
||||||
content_length_for_response_without_space = response.headers.get(
|
content_length_for_response_without_space = response.headers.get("Content-Length")
|
||||||
"Content-Length"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
content_length_for_response_with_space
|
content_length_for_response_with_space
|
||||||
|
@ -232,6 +225,12 @@ def test_chunked_streaming_returns_correct_content(streaming_app):
|
||||||
assert response.text == "foo,bar"
|
assert response.text == "foo,bar"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunked_streaming_returns_correct_content_asgi(streaming_app):
|
||||||
|
request, response = await streaming_app.asgi_client.get("/")
|
||||||
|
assert response.text == "4\r\nfoo,\r\n3\r\nbar\r\n0\r\n\r\n"
|
||||||
|
|
||||||
|
|
||||||
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
|
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
|
||||||
request, response = non_chunked_streaming_app.test_client.get("/")
|
request, response = non_chunked_streaming_app.test_client.get("/")
|
||||||
assert "Transfer-Encoding" not in response.headers
|
assert "Transfer-Encoding" not in response.headers
|
||||||
|
@ -239,9 +238,17 @@ def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
|
||||||
assert response.headers["Content-Length"] == "7"
|
assert response.headers["Content-Length"] == "7"
|
||||||
|
|
||||||
|
|
||||||
def test_non_chunked_streaming_returns_correct_content(
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_chunked_streaming_adds_correct_headers_asgi(
|
||||||
non_chunked_streaming_app,
|
non_chunked_streaming_app,
|
||||||
):
|
):
|
||||||
|
request, response = await non_chunked_streaming_app.asgi_client.get("/")
|
||||||
|
assert "Transfer-Encoding" not in response.headers
|
||||||
|
assert response.headers["Content-Type"] == "text/csv"
|
||||||
|
assert response.headers["Content-Length"] == "7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_chunked_streaming_returns_correct_content(non_chunked_streaming_app,):
|
||||||
request, response = non_chunked_streaming_app.test_client.get("/")
|
request, response = non_chunked_streaming_app.test_client.get("/")
|
||||||
assert response.text == "foo,bar"
|
assert response.text == "foo,bar"
|
||||||
|
|
||||||
|
@ -254,9 +261,7 @@ def test_stream_response_status_returns_correct_headers(status):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
|
@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
|
||||||
def test_stream_response_keep_alive_returns_correct_headers(
|
def test_stream_response_keep_alive_returns_correct_headers(keep_alive_timeout,):
|
||||||
keep_alive_timeout,
|
|
||||||
):
|
|
||||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||||
headers = response.get_headers(
|
headers = response.get_headers(
|
||||||
keep_alive=True, keep_alive_timeout=keep_alive_timeout
|
keep_alive=True, keep_alive_timeout=keep_alive_timeout
|
||||||
|
@ -340,13 +345,9 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
|
||||||
@streaming_app.listener("after_server_start")
|
@streaming_app.listener("after_server_start")
|
||||||
async def run_stream(app, loop):
|
async def run_stream(app, loop):
|
||||||
await response.stream(version="1.0")
|
await response.stream(version="1.0")
|
||||||
assert response.protocol.transport.write.call_args_list[1][0][0] == (
|
assert response.protocol.transport.write.call_args_list[1][0][0] == (b"foo,")
|
||||||
b"foo,"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.protocol.transport.write.call_args_list[2][0][0] == (
|
assert response.protocol.transport.write.call_args_list[2][0][0] == (b"bar")
|
||||||
b"bar"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(response.protocol.transport.write.call_args_list) == 3
|
assert len(response.protocol.transport.write.call_args_list) == 3
|
||||||
|
|
||||||
|
@ -391,9 +392,7 @@ def get_file_content(static_file_directory, file_name):
|
||||||
return file.read()
|
return file.read()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"])
|
||||||
"file_name", ["test.file", "decode me.txt", "python.png"]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("status", [200, 401])
|
@pytest.mark.parametrize("status", [200, 401])
|
||||||
def test_file_response(app, file_name, static_file_directory, status):
|
def test_file_response(app, file_name, static_file_directory, status):
|
||||||
@app.route("/files/<filename>", methods=["GET"])
|
@app.route("/files/<filename>", methods=["GET"])
|
||||||
|
@ -420,9 +419,7 @@ def test_file_response(app, file_name, static_file_directory, status):
|
||||||
("python.png", "logo.png"),
|
("python.png", "logo.png"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_file_response_custom_filename(
|
def test_file_response_custom_filename(app, source, dest, static_file_directory):
|
||||||
app, source, dest, static_file_directory
|
|
||||||
):
|
|
||||||
@app.route("/files/<filename>", methods=["GET"])
|
@app.route("/files/<filename>", methods=["GET"])
|
||||||
def file_route(request, filename):
|
def file_route(request, filename):
|
||||||
file_path = os.path.join(static_file_directory, filename)
|
file_path = os.path.join(static_file_directory, filename)
|
||||||
|
@ -449,8 +446,7 @@ def test_file_head_response(app, file_name, static_file_directory):
|
||||||
headers["Content-Length"] = str(stats.st_size)
|
headers["Content-Length"] = str(stats.st_size)
|
||||||
if request.method == "HEAD":
|
if request.method == "HEAD":
|
||||||
return HTTPResponse(
|
return HTTPResponse(
|
||||||
headers=headers,
|
headers=headers, content_type=guess_type(file_path)[0] or "text/plain",
|
||||||
content_type=guess_type(file_path)[0] or "text/plain",
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return file(
|
return file(
|
||||||
|
@ -468,9 +464,7 @@ def test_file_head_response(app, file_name, static_file_directory):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"])
|
||||||
"file_name", ["test.file", "decode me.txt", "python.png"]
|
|
||||||
)
|
|
||||||
def test_file_stream_response(app, file_name, static_file_directory):
|
def test_file_stream_response(app, file_name, static_file_directory):
|
||||||
@app.route("/files/<filename>", methods=["GET"])
|
@app.route("/files/<filename>", methods=["GET"])
|
||||||
def file_route(request, filename):
|
def file_route(request, filename):
|
||||||
|
@ -496,9 +490,7 @@ def test_file_stream_response(app, file_name, static_file_directory):
|
||||||
("python.png", "logo.png"),
|
("python.png", "logo.png"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_file_stream_response_custom_filename(
|
def test_file_stream_response_custom_filename(app, source, dest, static_file_directory):
|
||||||
app, source, dest, static_file_directory
|
|
||||||
):
|
|
||||||
@app.route("/files/<filename>", methods=["GET"])
|
@app.route("/files/<filename>", methods=["GET"])
|
||||||
def file_route(request, filename):
|
def file_route(request, filename):
|
||||||
file_path = os.path.join(static_file_directory, filename)
|
file_path = os.path.join(static_file_directory, filename)
|
||||||
|
@ -527,8 +519,7 @@ def test_file_stream_head_response(app, file_name, static_file_directory):
|
||||||
stats = await async_os.stat(file_path)
|
stats = await async_os.stat(file_path)
|
||||||
headers["Content-Length"] = str(stats.st_size)
|
headers["Content-Length"] = str(stats.st_size)
|
||||||
return HTTPResponse(
|
return HTTPResponse(
|
||||||
headers=headers,
|
headers=headers, content_type=guess_type(file_path)[0] or "text/plain",
|
||||||
content_type=guess_type(file_path)[0] or "text/plain",
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return file_stream(
|
return file_stream(
|
||||||
|
@ -551,12 +542,8 @@ def test_file_stream_head_response(app, file_name, static_file_directory):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt", "python.png"])
|
||||||
"file_name", ["test.file", "decode me.txt", "python.png"]
|
@pytest.mark.parametrize("size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)])
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)]
|
|
||||||
)
|
|
||||||
def test_file_stream_response_range(
|
def test_file_stream_response_range(
|
||||||
app, file_name, static_file_directory, size, start, end
|
app, file_name, static_file_directory, size, start, end
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user