Prevent sending multiple or mixed responses on a single request (#2327)
Co-authored-by: Adam Hopkins <adam@amhopkins.com> Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
b2a1bc69f5
commit
96c027bad5
59
sanic/app.py
59
sanic/app.py
|
@ -42,7 +42,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
from urllib.parse import urlencode, urlunparse
|
||||
from warnings import filterwarnings
|
||||
from warnings import filterwarnings, warn
|
||||
|
||||
from sanic_routing.exceptions import ( # type: ignore
|
||||
FinalizationError,
|
||||
|
@ -67,6 +67,7 @@ from sanic.exceptions import (
|
|||
URLBuildError,
|
||||
)
|
||||
from sanic.handlers import ErrorHandler
|
||||
from sanic.http import Stage
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, error_logger, logger
|
||||
from sanic.mixins.listeners import ListenerEvent
|
||||
from sanic.models.futures import (
|
||||
|
@ -736,6 +737,50 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
|||
context={"request": request, "exception": exception},
|
||||
)
|
||||
|
||||
if (
|
||||
request.stream is not None
|
||||
and request.stream.stage is not Stage.HANDLER
|
||||
):
|
||||
error_logger.exception(exception, exc_info=True)
|
||||
logger.error(
|
||||
"The error response will not be sent to the client for "
|
||||
f'the following exception:"{exception}". A previous response '
|
||||
"has at least partially been sent."
|
||||
)
|
||||
|
||||
# ----------------- deprecated -----------------
|
||||
handler = self.error_handler._lookup(
|
||||
exception, request.name if request else None
|
||||
)
|
||||
if handler:
|
||||
warn(
|
||||
"An error occurred while handling the request after at "
|
||||
"least some part of the response was sent to the client. "
|
||||
"Therefore, the response from your custom exception "
|
||||
f"handler {handler.__name__} will not be sent to the "
|
||||
"client. Beginning in v22.6, Sanic will stop executing "
|
||||
"custom exception handlers in this scenario. Exception "
|
||||
"handlers should only be used to generate the exception "
|
||||
"responses. If you would like to perform any other "
|
||||
"action on a raised exception, please consider using a "
|
||||
"signal handler like "
|
||||
'`@app.signal("http.lifecycle.exception")`\n'
|
||||
"For further information, please see the docs: "
|
||||
"https://sanicframework.org/en/guide/advanced/"
|
||||
"signals.html",
|
||||
DeprecationWarning,
|
||||
)
|
||||
try:
|
||||
response = self.error_handler.response(request, exception)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
except BaseException as e:
|
||||
logger.error("An error occurred in the exception handler.")
|
||||
error_logger.exception(e)
|
||||
# ----------------------------------------------
|
||||
|
||||
return
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
|
@ -765,6 +810,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
|||
)
|
||||
if response is not None:
|
||||
try:
|
||||
request.reset_response()
|
||||
response = await request.respond(response)
|
||||
except BaseException:
|
||||
# Skip response middleware
|
||||
|
@ -874,7 +920,16 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
|||
if isawaitable(response):
|
||||
response = await response
|
||||
|
||||
if response is not None:
|
||||
if request.responded:
|
||||
if response is not None:
|
||||
error_logger.error(
|
||||
"The response object returned by the route handler "
|
||||
"will not be sent to client. The request has already "
|
||||
"been responded to."
|
||||
)
|
||||
if request.stream is not None:
|
||||
response = request.stream.response
|
||||
elif response is not None:
|
||||
response = await request.respond(response)
|
||||
elif not hasattr(handler, "is_websocket"):
|
||||
response = request.stream.response # type: ignore
|
||||
|
|
|
@ -7,8 +7,10 @@ import sanic.app # noqa
|
|||
|
||||
from sanic.compat import Header
|
||||
from sanic.exceptions import ServerError
|
||||
from sanic.http import Stage
|
||||
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
|
||||
from sanic.request import Request
|
||||
from sanic.response import BaseHTTPResponse
|
||||
from sanic.server import ConnInfo
|
||||
from sanic.server.websockets.connection import WebSocketConnection
|
||||
|
||||
|
@ -83,6 +85,8 @@ class ASGIApp:
|
|||
transport: MockTransport
|
||||
lifespan: Lifespan
|
||||
ws: Optional[WebSocketConnection]
|
||||
stage: Stage
|
||||
response: Optional[BaseHTTPResponse]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.ws = None
|
||||
|
@ -95,6 +99,8 @@ class ASGIApp:
|
|||
instance.sanic_app = sanic_app
|
||||
instance.transport = MockTransport(scope, receive, send)
|
||||
instance.transport.loop = sanic_app.loop
|
||||
instance.stage = Stage.IDLE
|
||||
instance.response = None
|
||||
setattr(instance.transport, "add_task", sanic_app.loop.create_task)
|
||||
|
||||
headers = Header(
|
||||
|
@ -149,6 +155,8 @@ class ASGIApp:
|
|||
"""
|
||||
Read and stream the body in chunks from an incoming ASGI message.
|
||||
"""
|
||||
if self.stage is Stage.IDLE:
|
||||
self.stage = Stage.REQUEST
|
||||
message = await self.transport.receive()
|
||||
body = message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
|
@ -163,11 +171,17 @@ class ASGIApp:
|
|||
if data:
|
||||
yield data
|
||||
|
||||
def respond(self, response):
|
||||
def respond(self, response: BaseHTTPResponse):
|
||||
if self.stage is not Stage.HANDLER:
|
||||
self.stage = Stage.FAILED
|
||||
raise RuntimeError("Response already started")
|
||||
if self.response is not None:
|
||||
self.response.stream = None
|
||||
response.stream, self.response = self, response
|
||||
return response
|
||||
|
||||
async def send(self, data, end_stream):
|
||||
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
|
||||
if self.response:
|
||||
response, self.response = self.response, None
|
||||
await self.transport.send(
|
||||
|
@ -195,6 +209,7 @@ class ASGIApp:
|
|||
Handle the incoming request.
|
||||
"""
|
||||
try:
|
||||
self.stage = Stage.HANDLER
|
||||
await self.sanic_app.handle_request(self.request)
|
||||
except Exception as e:
|
||||
await self.sanic_app.handle_exception(self.request, e)
|
||||
|
|
|
@ -584,6 +584,11 @@ class Http(metaclass=TouchUpMeta):
|
|||
self.stage = Stage.FAILED
|
||||
raise RuntimeError("Response already started")
|
||||
|
||||
# Disconnect any earlier but unused response object
|
||||
if self.response is not None:
|
||||
self.response.stream = None
|
||||
|
||||
# Connect and return the response
|
||||
self.response, response.stream = response, self
|
||||
return response
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ from sanic_routing.route import Route # type: ignore
|
|||
if TYPE_CHECKING:
|
||||
from sanic.server import ConnInfo
|
||||
from sanic.app import Sanic
|
||||
from sanic.http import Http
|
||||
|
||||
import email.utils
|
||||
import uuid
|
||||
|
@ -32,7 +31,7 @@ from httptools import parse_url # type: ignore
|
|||
|
||||
from sanic.compat import CancelledErrors, Header
|
||||
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.exceptions import InvalidUsage, ServerError
|
||||
from sanic.headers import (
|
||||
AcceptContainer,
|
||||
Options,
|
||||
|
@ -42,6 +41,7 @@ from sanic.headers import (
|
|||
parse_host,
|
||||
parse_xforwarded,
|
||||
)
|
||||
from sanic.http import Http, Stage
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.models.protocol_types import TransportProtocol
|
||||
from sanic.response import BaseHTTPResponse, HTTPResponse
|
||||
|
@ -104,6 +104,7 @@ class Request:
|
|||
"parsed_json",
|
||||
"parsed_forwarded",
|
||||
"raw_url",
|
||||
"responded",
|
||||
"request_middleware_started",
|
||||
"route",
|
||||
"stream",
|
||||
|
@ -155,6 +156,7 @@ class Request:
|
|||
self.stream: Optional[Http] = None
|
||||
self.route: Optional[Route] = None
|
||||
self._protocol = None
|
||||
self.responded: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
|
@ -164,6 +166,21 @@ class Request:
|
|||
def generate_id(*_):
|
||||
return uuid.uuid4()
|
||||
|
||||
def reset_response(self):
|
||||
try:
|
||||
if (
|
||||
self.stream is not None
|
||||
and self.stream.stage is not Stage.HANDLER
|
||||
):
|
||||
raise ServerError(
|
||||
"Cannot reset response because previous response was sent."
|
||||
)
|
||||
self.stream.response.stream = None
|
||||
self.stream.response = None
|
||||
self.responded = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
async def respond(
|
||||
self,
|
||||
response: Optional[BaseHTTPResponse] = None,
|
||||
|
@ -172,13 +189,19 @@ class Request:
|
|||
headers: Optional[Union[Header, Dict[str, str]]] = None,
|
||||
content_type: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
if self.stream is not None and self.stream.response:
|
||||
raise ServerError("Second respond call is not allowed.")
|
||||
except AttributeError:
|
||||
pass
|
||||
# This logic of determining which response to use is subject to change
|
||||
if response is None:
|
||||
response = (self.stream and self.stream.response) or HTTPResponse(
|
||||
response = HTTPResponse(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
# Connect the response
|
||||
if isinstance(response, BaseHTTPResponse) and self.stream:
|
||||
response = self.stream.respond(response)
|
||||
|
@ -193,6 +216,7 @@ class Request:
|
|||
error_logger.exception(
|
||||
"Exception occurred in one of response middleware handlers"
|
||||
)
|
||||
self.responded = True
|
||||
return response
|
||||
|
||||
async def receive_body(self):
|
||||
|
|
|
@ -3,6 +3,7 @@ from mimetypes import guess_type
|
|||
from os import path
|
||||
from pathlib import PurePath
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AnyStr,
|
||||
Callable,
|
||||
|
@ -19,11 +20,15 @@ from warnings import warn
|
|||
from sanic.compat import Header, open_async
|
||||
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
|
||||
from sanic.cookies import CookieJar
|
||||
from sanic.exceptions import SanicException, ServerError
|
||||
from sanic.helpers import has_message_body, remove_entity_headers
|
||||
from sanic.http import Http
|
||||
from sanic.models.protocol_types import HTMLProtocol, Range
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.asgi import ASGIApp
|
||||
|
||||
try:
|
||||
from ujson import dumps as json_dumps
|
||||
except ImportError:
|
||||
|
@ -45,7 +50,7 @@ class BaseHTTPResponse:
|
|||
self.asgi: bool = False
|
||||
self.body: Optional[bytes] = None
|
||||
self.content_type: Optional[str] = None
|
||||
self.stream: Http = None
|
||||
self.stream: Optional[Union[Http, ASGIApp]] = None
|
||||
self.status: int = None
|
||||
self.headers = Header({})
|
||||
self._cookies: Optional[CookieJar] = None
|
||||
|
@ -112,8 +117,17 @@ class BaseHTTPResponse:
|
|||
"""
|
||||
if data is None and end_stream is None:
|
||||
end_stream = True
|
||||
if end_stream and not data and self.stream.send is None:
|
||||
return
|
||||
if self.stream is None:
|
||||
raise SanicException(
|
||||
"No stream is connected to the response object instance."
|
||||
)
|
||||
if self.stream.send is None:
|
||||
if end_stream and not data:
|
||||
return
|
||||
raise ServerError(
|
||||
"Response stream was ended, no more response data is "
|
||||
"allowed to be sent."
|
||||
)
|
||||
data = (
|
||||
data.encode() # type: ignore
|
||||
if hasattr(data, "encode")
|
||||
|
|
|
@ -6,7 +6,8 @@ import string
|
|||
import sys
|
||||
import uuid
|
||||
|
||||
from typing import Tuple
|
||||
from logging import LogRecord
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -170,3 +171,16 @@ def run_startup(caplog):
|
|||
return caplog.record_tuples
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def message_in_records():
|
||||
def msg_in_log(records: List[LogRecord], msg: str):
|
||||
error_captured = False
|
||||
for record in records:
|
||||
if msg in record.message:
|
||||
error_captured = True
|
||||
break
|
||||
return error_captured
|
||||
|
||||
return msg_in_log
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import Callable, List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from pytest import LogCaptureFixture, MonkeyPatch
|
||||
|
||||
from sanic import Sanic, handlers
|
||||
from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError
|
||||
from sanic.handlers import ErrorHandler
|
||||
from sanic.request import Request
|
||||
from sanic.response import stream, text
|
||||
|
||||
|
||||
|
@ -90,35 +93,35 @@ def exception_handler_app():
|
|||
return exception_handler_app
|
||||
|
||||
|
||||
def test_invalid_usage_exception_handler(exception_handler_app):
|
||||
def test_invalid_usage_exception_handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/1")
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
def test_server_error_exception_handler(exception_handler_app):
|
||||
def test_server_error_exception_handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/2")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
def test_not_found_exception_handler(exception_handler_app):
|
||||
def test_not_found_exception_handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/3")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_text_exception__handler(exception_handler_app):
|
||||
def test_text_exception__handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/random")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
||||
|
||||
def test_async_exception_handler(exception_handler_app):
|
||||
def test_async_exception_handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/7")
|
||||
assert response.status == 200
|
||||
assert response.text == "foo,bar"
|
||||
|
||||
|
||||
def test_html_traceback_output_in_debug_mode(exception_handler_app):
|
||||
def test_html_traceback_output_in_debug_mode(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/4", debug=True)
|
||||
assert response.status == 500
|
||||
soup = BeautifulSoup(response.body, "html.parser")
|
||||
|
@ -133,12 +136,12 @@ def test_html_traceback_output_in_debug_mode(exception_handler_app):
|
|||
) == summary_text
|
||||
|
||||
|
||||
def test_inherited_exception_handler(exception_handler_app):
|
||||
def test_inherited_exception_handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get("/5")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_chained_exception_handler(exception_handler_app):
|
||||
def test_chained_exception_handler(exception_handler_app: Sanic):
|
||||
request, response = exception_handler_app.test_client.get(
|
||||
"/6/0", debug=True
|
||||
)
|
||||
|
@ -157,7 +160,7 @@ def test_chained_exception_handler(exception_handler_app):
|
|||
) == summary_text
|
||||
|
||||
|
||||
def test_exception_handler_lookup(exception_handler_app):
|
||||
def test_exception_handler_lookup(exception_handler_app: Sanic):
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -205,13 +208,17 @@ def test_exception_handler_lookup(exception_handler_app):
|
|||
)
|
||||
|
||||
|
||||
def test_exception_handler_processed_request_middleware(exception_handler_app):
|
||||
def test_exception_handler_processed_request_middleware(
|
||||
exception_handler_app: Sanic,
|
||||
):
|
||||
request, response = exception_handler_app.test_client.get("/8")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
||||
|
||||
def test_single_arg_exception_handler_notice(exception_handler_app, caplog):
|
||||
def test_single_arg_exception_handler_notice(
|
||||
exception_handler_app: Sanic, caplog: LogCaptureFixture
|
||||
):
|
||||
class CustomErrorHandler(ErrorHandler):
|
||||
def lookup(self, exception):
|
||||
return super().lookup(exception, None)
|
||||
|
@ -233,7 +240,9 @@ def test_single_arg_exception_handler_notice(exception_handler_app, caplog):
|
|||
assert response.status == 400
|
||||
|
||||
|
||||
def test_error_handler_noisy_log(exception_handler_app, monkeypatch):
|
||||
def test_error_handler_noisy_log(
|
||||
exception_handler_app: Sanic, monkeypatch: MonkeyPatch
|
||||
):
|
||||
err_logger = Mock()
|
||||
monkeypatch.setattr(handlers, "error_logger", err_logger)
|
||||
|
||||
|
@ -246,3 +255,45 @@ def test_error_handler_noisy_log(exception_handler_app, monkeypatch):
|
|||
err_logger.exception.assert_called_with(
|
||||
"Exception occurred while handling uri: %s", repr(request.url)
|
||||
)
|
||||
|
||||
|
||||
def test_exception_handler_response_was_sent(
|
||||
app: Sanic,
|
||||
caplog: LogCaptureFixture,
|
||||
message_in_records: Callable[[List[logging.LogRecord], str], bool],
|
||||
):
|
||||
exception_handler_ran = False
|
||||
|
||||
@app.exception(ServerError)
|
||||
async def exception_handler(request, exception):
|
||||
nonlocal exception_handler_ran
|
||||
exception_handler_ran = True
|
||||
return text("Error")
|
||||
|
||||
@app.route("/1")
|
||||
async def handler1(request: Request):
|
||||
response = await request.respond()
|
||||
await response.send("some text")
|
||||
raise ServerError("Exception")
|
||||
|
||||
@app.route("/2")
|
||||
async def handler2(request: Request):
|
||||
response = await request.respond()
|
||||
raise ServerError("Exception")
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
_, response = app.test_client.get("/1")
|
||||
assert "some text" in response.text
|
||||
|
||||
# Change to assert warning not in the records in the future version.
|
||||
message_in_records(
|
||||
caplog.records,
|
||||
(
|
||||
"An error occurred while handling the request after at "
|
||||
"least some part of the response was sent to the client. "
|
||||
"Therefore, the response from your custom exception "
|
||||
),
|
||||
)
|
||||
|
||||
_, response = app.test_client.get("/2")
|
||||
assert "Error" in response.text
|
||||
|
|
|
@ -297,3 +297,27 @@ def test_middleware_added_response(app):
|
|||
|
||||
_, response = app.test_client.get("/")
|
||||
assert response.json["foo"] == "bar"
|
||||
|
||||
|
||||
def test_middleware_return_response(app):
|
||||
response_middleware_run_count = 0
|
||||
request_middleware_run_count = 0
|
||||
|
||||
@app.on_response
|
||||
def response(_, response):
|
||||
nonlocal response_middleware_run_count
|
||||
response_middleware_run_count += 1
|
||||
|
||||
@app.on_request
|
||||
def request(_):
|
||||
nonlocal request_middleware_run_count
|
||||
request_middleware_run_count += 1
|
||||
|
||||
@app.get("/")
|
||||
async def handler(request):
|
||||
resp1 = await request.respond()
|
||||
return resp1
|
||||
|
||||
_, response = app.test_client.get("/")
|
||||
assert response_middleware_run_count == 1
|
||||
assert request_middleware_run_count == 1
|
||||
|
|
|
@ -15,8 +15,8 @@ from sanic_testing.testing import (
|
|||
)
|
||||
|
||||
from sanic import Blueprint, Sanic
|
||||
from sanic.exceptions import ServerError
|
||||
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters
|
||||
from sanic.exceptions import SanicException, ServerError
|
||||
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters
|
||||
from sanic.response import html, json, text
|
||||
|
||||
|
||||
|
|
|
@ -3,15 +3,18 @@ import inspect
|
|||
import os
|
||||
|
||||
from collections import namedtuple
|
||||
from logging import ERROR, LogRecord
|
||||
from mimetypes import guess_type
|
||||
from random import choice
|
||||
from typing import Callable, List
|
||||
from urllib.parse import unquote
|
||||
|
||||
import pytest
|
||||
|
||||
from aiofiles import os as async_os
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic import Request, Sanic
|
||||
from sanic.response import (
|
||||
HTTPResponse,
|
||||
empty,
|
||||
|
@ -33,7 +36,7 @@ def test_response_body_not_a_string(app):
|
|||
random_num = choice(range(1000))
|
||||
|
||||
@app.route("/hello")
|
||||
async def hello_route(request):
|
||||
async def hello_route(request: Request):
|
||||
return text(random_num)
|
||||
|
||||
request, response = app.test_client.get("/hello")
|
||||
|
@ -51,7 +54,7 @@ def test_method_not_allowed():
|
|||
app = Sanic("app")
|
||||
|
||||
@app.get("/")
|
||||
async def test_get(request):
|
||||
async def test_get(request: Request):
|
||||
return response.json({"hello": "world"})
|
||||
|
||||
request, response = app.test_client.head("/")
|
||||
|
@ -67,7 +70,7 @@ def test_method_not_allowed():
|
|||
app.router.reset()
|
||||
|
||||
@app.post("/")
|
||||
async def test_post(request):
|
||||
async def test_post(request: Request):
|
||||
return response.json({"hello": "world"})
|
||||
|
||||
request, response = app.test_client.head("/")
|
||||
|
@ -89,7 +92,7 @@ def test_method_not_allowed():
|
|||
|
||||
def test_response_header(app):
|
||||
@app.get("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"})
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
|
@ -102,14 +105,14 @@ def test_response_header(app):
|
|||
|
||||
def test_response_content_length(app):
|
||||
@app.get("/response_with_space")
|
||||
async def response_with_space(request):
|
||||
async def response_with_space(request: Request):
|
||||
return json(
|
||||
{"message": "Data", "details": "Some Details"},
|
||||
headers={"CONTENT-TYPE": "application/json"},
|
||||
)
|
||||
|
||||
@app.get("/response_without_space")
|
||||
async def response_without_space(request):
|
||||
async def response_without_space(request: Request):
|
||||
return json(
|
||||
{"message": "Data", "details": "Some Details"},
|
||||
headers={"CONTENT-TYPE": "application/json"},
|
||||
|
@ -135,7 +138,7 @@ def test_response_content_length(app):
|
|||
|
||||
def test_response_content_length_with_different_data_types(app):
|
||||
@app.get("/")
|
||||
async def get_data_with_different_types(request):
|
||||
async def get_data_with_different_types(request: Request):
|
||||
# Indentation issues in the Response is intentional. Please do not fix
|
||||
return json(
|
||||
{"bool": True, "none": None, "string": "string", "number": -1},
|
||||
|
@ -149,23 +152,23 @@ def test_response_content_length_with_different_data_types(app):
|
|||
@pytest.fixture
|
||||
def json_app(app):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
return json(JSON_DATA)
|
||||
|
||||
@app.get("/no-content")
|
||||
async def no_content_handler(request):
|
||||
async def no_content_handler(request: Request):
|
||||
return json(JSON_DATA, status=204)
|
||||
|
||||
@app.get("/no-content/unmodified")
|
||||
async def no_content_unmodified_handler(request):
|
||||
async def no_content_unmodified_handler(request: Request):
|
||||
return json(None, status=304)
|
||||
|
||||
@app.get("/unmodified")
|
||||
async def unmodified_handler(request):
|
||||
async def unmodified_handler(request: Request):
|
||||
return json(JSON_DATA, status=304)
|
||||
|
||||
@app.delete("/")
|
||||
async def delete_handler(request):
|
||||
async def delete_handler(request: Request):
|
||||
return json(None, status=204)
|
||||
|
||||
return app
|
||||
|
@ -207,7 +210,7 @@ def test_no_content(json_app):
|
|||
@pytest.fixture
|
||||
def streaming_app(app):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
return stream(
|
||||
sample_streaming_fn,
|
||||
content_type="text/csv",
|
||||
|
@ -219,7 +222,7 @@ def streaming_app(app):
|
|||
@pytest.fixture
|
||||
def non_chunked_streaming_app(app):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
return stream(
|
||||
sample_streaming_fn,
|
||||
headers={"Content-Length": "7"},
|
||||
|
@ -276,7 +279,7 @@ def test_non_chunked_streaming_returns_correct_content(
|
|||
|
||||
def test_stream_response_with_cookies(app):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
response = stream(sample_streaming_fn, content_type="text/csv")
|
||||
response.cookies["test"] = "modified"
|
||||
response.cookies["test"] = "pass"
|
||||
|
@ -288,7 +291,7 @@ def test_stream_response_with_cookies(app):
|
|||
|
||||
def test_stream_response_without_cookies(app):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
return stream(sample_streaming_fn, content_type="text/csv")
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
|
@ -314,7 +317,7 @@ def get_file_content(static_file_directory, file_name):
|
|||
"file_name", ["test.file", "decode me.txt", "python.png"]
|
||||
)
|
||||
@pytest.mark.parametrize("status", [200, 401])
|
||||
def test_file_response(app, file_name, static_file_directory, status):
|
||||
def test_file_response(app: Sanic, file_name, static_file_directory, status):
|
||||
@app.route("/files/<filename>", methods=["GET"])
|
||||
def file_route(request, filename):
|
||||
file_path = os.path.join(static_file_directory, filename)
|
||||
|
@ -340,7 +343,7 @@ def test_file_response(app, file_name, static_file_directory, status):
|
|||
],
|
||||
)
|
||||
def test_file_response_custom_filename(
|
||||
app, source, dest, static_file_directory
|
||||
app: Sanic, source, dest, static_file_directory
|
||||
):
|
||||
@app.route("/files/<filename>", methods=["GET"])
|
||||
def file_route(request, filename):
|
||||
|
@ -358,7 +361,7 @@ def test_file_response_custom_filename(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"])
|
||||
def test_file_head_response(app, file_name, static_file_directory):
|
||||
def test_file_head_response(app: Sanic, file_name, static_file_directory):
|
||||
@app.route("/files/<filename>", methods=["GET", "HEAD"])
|
||||
async def file_route(request, filename):
|
||||
file_path = os.path.join(static_file_directory, filename)
|
||||
|
@ -391,7 +394,7 @@ def test_file_head_response(app, file_name, static_file_directory):
|
|||
@pytest.mark.parametrize(
|
||||
"file_name", ["test.file", "decode me.txt", "python.png"]
|
||||
)
|
||||
def test_file_stream_response(app, file_name, static_file_directory):
|
||||
def test_file_stream_response(app: Sanic, file_name, static_file_directory):
|
||||
@app.route("/files/<filename>", methods=["GET"])
|
||||
def file_route(request, filename):
|
||||
file_path = os.path.join(static_file_directory, filename)
|
||||
|
@ -417,7 +420,7 @@ def test_file_stream_response(app, file_name, static_file_directory):
|
|||
],
|
||||
)
|
||||
def test_file_stream_response_custom_filename(
|
||||
app, source, dest, static_file_directory
|
||||
app: Sanic, source, dest, static_file_directory
|
||||
):
|
||||
@app.route("/files/<filename>", methods=["GET"])
|
||||
def file_route(request, filename):
|
||||
|
@ -435,7 +438,9 @@ def test_file_stream_response_custom_filename(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"])
|
||||
def test_file_stream_head_response(app, file_name, static_file_directory):
|
||||
def test_file_stream_head_response(
|
||||
app: Sanic, file_name, static_file_directory
|
||||
):
|
||||
@app.route("/files/<filename>", methods=["GET", "HEAD"])
|
||||
async def file_route(request, filename):
|
||||
file_path = os.path.join(static_file_directory, filename)
|
||||
|
@ -479,7 +484,7 @@ def test_file_stream_head_response(app, file_name, static_file_directory):
|
|||
"size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)]
|
||||
)
|
||||
def test_file_stream_response_range(
|
||||
app, file_name, static_file_directory, size, start, end
|
||||
app: Sanic, file_name, static_file_directory, size, start, end
|
||||
):
|
||||
|
||||
Range = namedtuple("Range", ["size", "start", "end", "total"])
|
||||
|
@ -508,7 +513,7 @@ def test_file_stream_response_range(
|
|||
|
||||
def test_raw_response(app):
|
||||
@app.get("/test")
|
||||
def handler(request):
|
||||
def handler(request: Request):
|
||||
return raw(b"raw_response")
|
||||
|
||||
request, response = app.test_client.get("/test")
|
||||
|
@ -518,7 +523,7 @@ def test_raw_response(app):
|
|||
|
||||
def test_empty_response(app):
|
||||
@app.get("/test")
|
||||
def handler(request):
|
||||
def handler(request: Request):
|
||||
return empty()
|
||||
|
||||
request, response = app.test_client.get("/test")
|
||||
|
@ -526,17 +531,162 @@ def test_empty_response(app):
|
|||
assert response.body == b""
|
||||
|
||||
|
||||
def test_direct_response_stream(app):
|
||||
def test_direct_response_stream(app: Sanic):
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def test(request: Request):
|
||||
response = await request.respond(content_type="text/csv")
|
||||
await response.send("foo,")
|
||||
await response.send("bar")
|
||||
await response.eof()
|
||||
return response
|
||||
|
||||
_, response = app.test_client.get("/")
|
||||
assert response.text == "foo,bar"
|
||||
assert response.headers["Transfer-Encoding"] == "chunked"
|
||||
assert response.headers["Content-Type"] == "text/csv"
|
||||
assert "Content-Length" not in response.headers
|
||||
|
||||
|
||||
def test_two_respond_calls(app: Sanic):
|
||||
@app.route("/")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond()
|
||||
await response.send("foo,")
|
||||
await response.send("bar")
|
||||
await response.eof()
|
||||
|
||||
|
||||
def test_multiple_responses(
|
||||
app: Sanic,
|
||||
caplog: LogCaptureFixture,
|
||||
message_in_records: Callable[[List[LogRecord], str], bool],
|
||||
):
|
||||
@app.route("/1")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond()
|
||||
await response.send("foo")
|
||||
response = await request.respond()
|
||||
|
||||
@app.route("/2")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond()
|
||||
response = await request.respond()
|
||||
await response.send("foo")
|
||||
|
||||
@app.get("/3")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond()
|
||||
await response.send("foo,")
|
||||
response = await request.respond()
|
||||
await response.send("bar")
|
||||
|
||||
@app.get("/4")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond(headers={"one": "one"})
|
||||
return json({"foo": "bar"}, headers={"one": "two"})
|
||||
|
||||
@app.get("/5")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond(headers={"one": "one"})
|
||||
await response.send("foo")
|
||||
return json({"foo": "bar"}, headers={"one": "two"})
|
||||
|
||||
@app.get("/6")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond(headers={"one": "one"})
|
||||
await response.send("foo, ")
|
||||
json_response = json({"foo": "bar"}, headers={"one": "two"})
|
||||
await response.send("bar")
|
||||
return json_response
|
||||
|
||||
error_msg0 = "Second respond call is not allowed."
|
||||
|
||||
error_msg1 = (
|
||||
"The error response will not be sent to the client for the following "
|
||||
'exception:"Second respond call is not allowed.". A previous '
|
||||
"response has at least partially been sent."
|
||||
)
|
||||
|
||||
error_msg2 = (
|
||||
"The response object returned by the route handler "
|
||||
"will not be sent to client. The request has already "
|
||||
"been responded to."
|
||||
)
|
||||
|
||||
error_msg3 = (
|
||||
"Response stream was ended, no more "
|
||||
"response data is allowed to be sent."
|
||||
)
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/1")
|
||||
assert response.status == 200
|
||||
assert message_in_records(caplog.records, error_msg0)
|
||||
assert message_in_records(caplog.records, error_msg1)
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/2")
|
||||
assert response.status == 500
|
||||
assert "500 — Internal Server Error" in response.text
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/3")
|
||||
assert response.status == 200
|
||||
assert "foo," in response.text
|
||||
assert message_in_records(caplog.records, error_msg0)
|
||||
assert message_in_records(caplog.records, error_msg1)
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/4")
|
||||
print(response.json)
|
||||
assert response.status == 200
|
||||
assert "foo" not in response.text
|
||||
assert "one" in response.headers
|
||||
assert response.headers["one"] == "one"
|
||||
|
||||
print(response.headers)
|
||||
assert message_in_records(caplog.records, error_msg2)
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/5")
|
||||
assert response.status == 200
|
||||
assert "foo" in response.text
|
||||
assert "one" in response.headers
|
||||
assert response.headers["one"] == "one"
|
||||
assert message_in_records(caplog.records, error_msg2)
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/6")
|
||||
assert "foo, bar" in response.text
|
||||
assert "one" in response.headers
|
||||
assert response.headers["one"] == "one"
|
||||
assert message_in_records(caplog.records, error_msg2)
|
||||
|
||||
|
||||
def send_response_after_eof_should_fail(
|
||||
app: Sanic,
|
||||
caplog: LogCaptureFixture,
|
||||
message_in_records: Callable[[List[LogRecord], str], bool],
|
||||
):
|
||||
@app.get("/")
|
||||
async def handler(request: Request):
|
||||
response = await request.respond()
|
||||
await response.send("foo, ")
|
||||
await response.eof()
|
||||
await response.send("bar")
|
||||
|
||||
error_msg1 = (
|
||||
"The error response will not be sent to the client for the following "
|
||||
'exception:"Second respond call is not allowed.". A previous '
|
||||
"response has at least partially been sent."
|
||||
)
|
||||
|
||||
error_msg2 = (
|
||||
"Response stream was ended, no more "
|
||||
"response data is allowed to be sent."
|
||||
)
|
||||
|
||||
with caplog.at_level(ERROR):
|
||||
_, response = app.test_client.get("/")
|
||||
assert "foo, " in response.text
|
||||
assert message_in_records(caplog.records, error_msg1)
|
||||
assert message_in_records(caplog.records, error_msg2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user