Prevent sending multiple or mixed responses on a single request (#2327)

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Zhiwei 2021-12-09 03:00:18 -07:00 committed by GitHub
parent b2a1bc69f5
commit 96c027bad5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 405 additions and 53 deletions

View File

@ -42,7 +42,7 @@ from typing import (
Union, Union,
) )
from urllib.parse import urlencode, urlunparse from urllib.parse import urlencode, urlunparse
from warnings import filterwarnings from warnings import filterwarnings, warn
from sanic_routing.exceptions import ( # type: ignore from sanic_routing.exceptions import ( # type: ignore
FinalizationError, FinalizationError,
@ -67,6 +67,7 @@ from sanic.exceptions import (
URLBuildError, URLBuildError,
) )
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.http import Stage
from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, error_logger, logger from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, error_logger, logger
from sanic.mixins.listeners import ListenerEvent from sanic.mixins.listeners import ListenerEvent
from sanic.models.futures import ( from sanic.models.futures import (
@ -736,6 +737,50 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
context={"request": request, "exception": exception}, context={"request": request, "exception": exception},
) )
if (
request.stream is not None
and request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
# ----------------- deprecated -----------------
handler = self.error_handler._lookup(
exception, request.name if request else None
)
if handler:
warn(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"Therefore, the response from your custom exception "
f"handler {handler.__name__} will not be sent to the "
"client. Beginning in v22.6, Sanic will stop executing "
"custom exception handlers in this scenario. Exception "
"handlers should only be used to generate the exception "
"responses. If you would like to perform any other "
"action on a raised exception, please consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
DeprecationWarning,
)
try:
response = self.error_handler.response(request, exception)
if isawaitable(response):
response = await response
except BaseException as e:
logger.error("An error occurred in the exception handler.")
error_logger.exception(e)
# ----------------------------------------------
return
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
@ -765,6 +810,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
) )
if response is not None: if response is not None:
try: try:
request.reset_response()
response = await request.respond(response) response = await request.respond(response)
except BaseException: except BaseException:
# Skip response middleware # Skip response middleware
@ -874,7 +920,16 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
if isawaitable(response): if isawaitable(response):
response = await response response = await response
if response is not None: if request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response) response = await request.respond(response)
elif not hasattr(handler, "is_websocket"): elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore response = request.stream.response # type: ignore

View File

@ -7,8 +7,10 @@ import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.http import Stage
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
from sanic.request import Request from sanic.request import Request
from sanic.response import BaseHTTPResponse
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.server.websockets.connection import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
@ -83,6 +85,8 @@ class ASGIApp:
transport: MockTransport transport: MockTransport
lifespan: Lifespan lifespan: Lifespan
ws: Optional[WebSocketConnection] ws: Optional[WebSocketConnection]
stage: Stage
response: Optional[BaseHTTPResponse]
def __init__(self) -> None: def __init__(self) -> None:
self.ws = None self.ws = None
@ -95,6 +99,8 @@ class ASGIApp:
instance.sanic_app = sanic_app instance.sanic_app = sanic_app
instance.transport = MockTransport(scope, receive, send) instance.transport = MockTransport(scope, receive, send)
instance.transport.loop = sanic_app.loop instance.transport.loop = sanic_app.loop
instance.stage = Stage.IDLE
instance.response = None
setattr(instance.transport, "add_task", sanic_app.loop.create_task) setattr(instance.transport, "add_task", sanic_app.loop.create_task)
headers = Header( headers = Header(
@ -149,6 +155,8 @@ class ASGIApp:
""" """
Read and stream the body in chunks from an incoming ASGI message. Read and stream the body in chunks from an incoming ASGI message.
""" """
if self.stage is Stage.IDLE:
self.stage = Stage.REQUEST
message = await self.transport.receive() message = await self.transport.receive()
body = message.get("body", b"") body = message.get("body", b"")
if not message.get("more_body", False): if not message.get("more_body", False):
@ -163,11 +171,17 @@ class ASGIApp:
if data: if data:
yield data yield data
def respond(self, response): def respond(self, response: BaseHTTPResponse):
if self.stage is not Stage.HANDLER:
self.stage = Stage.FAILED
raise RuntimeError("Response already started")
if self.response is not None:
self.response.stream = None
response.stream, self.response = self, response response.stream, self.response = self, response
return response return response
async def send(self, data, end_stream): async def send(self, data, end_stream):
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
if self.response: if self.response:
response, self.response = self.response, None response, self.response = self.response, None
await self.transport.send( await self.transport.send(
@ -195,6 +209,7 @@ class ASGIApp:
Handle the incoming request. Handle the incoming request.
""" """
try: try:
self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request) await self.sanic_app.handle_request(self.request)
except Exception as e: except Exception as e:
await self.sanic_app.handle_exception(self.request, e) await self.sanic_app.handle_exception(self.request, e)

View File

@ -584,6 +584,11 @@ class Http(metaclass=TouchUpMeta):
self.stage = Stage.FAILED self.stage = Stage.FAILED
raise RuntimeError("Response already started") raise RuntimeError("Response already started")
# Disconnect any earlier but unused response object
if self.response is not None:
self.response.stream = None
# Connect and return the response
self.response, response.stream = response, self self.response, response.stream = response, self
return response return response

View File

@ -18,7 +18,6 @@ from sanic_routing.route import Route # type: ignore
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.app import Sanic from sanic.app import Sanic
from sanic.http import Http
import email.utils import email.utils
import uuid import uuid
@ -32,7 +31,7 @@ from httptools import parse_url # type: ignore
from sanic.compat import CancelledErrors, Header from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage, ServerError
from sanic.headers import ( from sanic.headers import (
AcceptContainer, AcceptContainer,
Options, Options,
@ -42,6 +41,7 @@ from sanic.headers import (
parse_host, parse_host,
parse_xforwarded, parse_xforwarded,
) )
from sanic.http import Http, Stage
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse from sanic.response import BaseHTTPResponse, HTTPResponse
@ -104,6 +104,7 @@ class Request:
"parsed_json", "parsed_json",
"parsed_forwarded", "parsed_forwarded",
"raw_url", "raw_url",
"responded",
"request_middleware_started", "request_middleware_started",
"route", "route",
"stream", "stream",
@ -155,6 +156,7 @@ class Request:
self.stream: Optional[Http] = None self.stream: Optional[Http] = None
self.route: Optional[Route] = None self.route: Optional[Route] = None
self._protocol = None self._protocol = None
self.responded: bool = False
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -164,6 +166,21 @@ class Request:
def generate_id(*_): def generate_id(*_):
return uuid.uuid4() return uuid.uuid4()
def reset_response(self):
try:
if (
self.stream is not None
and self.stream.stage is not Stage.HANDLER
):
raise ServerError(
"Cannot reset response because previous response was sent."
)
self.stream.response.stream = None
self.stream.response = None
self.responded = False
except AttributeError:
pass
async def respond( async def respond(
self, self,
response: Optional[BaseHTTPResponse] = None, response: Optional[BaseHTTPResponse] = None,
@ -172,13 +189,19 @@ class Request:
headers: Optional[Union[Header, Dict[str, str]]] = None, headers: Optional[Union[Header, Dict[str, str]]] = None,
content_type: Optional[str] = None, content_type: Optional[str] = None,
): ):
try:
if self.stream is not None and self.stream.response:
raise ServerError("Second respond call is not allowed.")
except AttributeError:
pass
# This logic of determining which response to use is subject to change # This logic of determining which response to use is subject to change
if response is None: if response is None:
response = (self.stream and self.stream.response) or HTTPResponse( response = HTTPResponse(
status=status, status=status,
headers=headers, headers=headers,
content_type=content_type, content_type=content_type,
) )
# Connect the response # Connect the response
if isinstance(response, BaseHTTPResponse) and self.stream: if isinstance(response, BaseHTTPResponse) and self.stream:
response = self.stream.respond(response) response = self.stream.respond(response)
@ -193,6 +216,7 @@ class Request:
error_logger.exception( error_logger.exception(
"Exception occurred in one of response middleware handlers" "Exception occurred in one of response middleware handlers"
) )
self.responded = True
return response return response
async def receive_body(self): async def receive_body(self):

View File

@ -3,6 +3,7 @@ from mimetypes import guess_type
from os import path from os import path
from pathlib import PurePath from pathlib import PurePath
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
AnyStr, AnyStr,
Callable, Callable,
@ -19,11 +20,15 @@ from warnings import warn
from sanic.compat import Header, open_async from sanic.compat import Header, open_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
from sanic.exceptions import SanicException, ServerError
from sanic.helpers import has_message_body, remove_entity_headers from sanic.helpers import has_message_body, remove_entity_headers
from sanic.http import Http from sanic.http import Http
from sanic.models.protocol_types import HTMLProtocol, Range from sanic.models.protocol_types import HTMLProtocol, Range
if TYPE_CHECKING:
from sanic.asgi import ASGIApp
try: try:
from ujson import dumps as json_dumps from ujson import dumps as json_dumps
except ImportError: except ImportError:
@ -45,7 +50,7 @@ class BaseHTTPResponse:
self.asgi: bool = False self.asgi: bool = False
self.body: Optional[bytes] = None self.body: Optional[bytes] = None
self.content_type: Optional[str] = None self.content_type: Optional[str] = None
self.stream: Http = None self.stream: Optional[Union[Http, ASGIApp]] = None
self.status: int = None self.status: int = None
self.headers = Header({}) self.headers = Header({})
self._cookies: Optional[CookieJar] = None self._cookies: Optional[CookieJar] = None
@ -112,8 +117,17 @@ class BaseHTTPResponse:
""" """
if data is None and end_stream is None: if data is None and end_stream is None:
end_stream = True end_stream = True
if end_stream and not data and self.stream.send is None: if self.stream is None:
return raise SanicException(
"No stream is connected to the response object instance."
)
if self.stream.send is None:
if end_stream and not data:
return
raise ServerError(
"Response stream was ended, no more response data is "
"allowed to be sent."
)
data = ( data = (
data.encode() # type: ignore data.encode() # type: ignore
if hasattr(data, "encode") if hasattr(data, "encode")

View File

@ -6,7 +6,8 @@ import string
import sys import sys
import uuid import uuid
from typing import Tuple from logging import LogRecord
from typing import Callable, List, Tuple
import pytest import pytest
@ -170,3 +171,16 @@ def run_startup(caplog):
return caplog.record_tuples return caplog.record_tuples
return run return run
@pytest.fixture(scope="function")
def message_in_records():
def msg_in_log(records: List[LogRecord], msg: str):
error_captured = False
for record in records:
if msg in record.message:
error_captured = True
break
return error_captured
return msg_in_log

View File

@ -1,15 +1,18 @@
import asyncio import asyncio
import logging import logging
from typing import Callable, List
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from pytest import LogCaptureFixture, MonkeyPatch
from sanic import Sanic, handlers from sanic import Sanic, handlers
from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.request import Request
from sanic.response import stream, text from sanic.response import stream, text
@ -90,35 +93,35 @@ def exception_handler_app():
return exception_handler_app return exception_handler_app
def test_invalid_usage_exception_handler(exception_handler_app): def test_invalid_usage_exception_handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/1") request, response = exception_handler_app.test_client.get("/1")
assert response.status == 400 assert response.status == 400
def test_server_error_exception_handler(exception_handler_app): def test_server_error_exception_handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/2") request, response = exception_handler_app.test_client.get("/2")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
def test_not_found_exception_handler(exception_handler_app): def test_not_found_exception_handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/3") request, response = exception_handler_app.test_client.get("/3")
assert response.status == 200 assert response.status == 200
def test_text_exception__handler(exception_handler_app): def test_text_exception__handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/random") request, response = exception_handler_app.test_client.get("/random")
assert response.status == 200 assert response.status == 200
assert response.text == "Done." assert response.text == "Done."
def test_async_exception_handler(exception_handler_app): def test_async_exception_handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/7") request, response = exception_handler_app.test_client.get("/7")
assert response.status == 200 assert response.status == 200
assert response.text == "foo,bar" assert response.text == "foo,bar"
def test_html_traceback_output_in_debug_mode(exception_handler_app): def test_html_traceback_output_in_debug_mode(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/4", debug=True) request, response = exception_handler_app.test_client.get("/4", debug=True)
assert response.status == 500 assert response.status == 500
soup = BeautifulSoup(response.body, "html.parser") soup = BeautifulSoup(response.body, "html.parser")
@ -133,12 +136,12 @@ def test_html_traceback_output_in_debug_mode(exception_handler_app):
) == summary_text ) == summary_text
def test_inherited_exception_handler(exception_handler_app): def test_inherited_exception_handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get("/5") request, response = exception_handler_app.test_client.get("/5")
assert response.status == 200 assert response.status == 200
def test_chained_exception_handler(exception_handler_app): def test_chained_exception_handler(exception_handler_app: Sanic):
request, response = exception_handler_app.test_client.get( request, response = exception_handler_app.test_client.get(
"/6/0", debug=True "/6/0", debug=True
) )
@ -157,7 +160,7 @@ def test_chained_exception_handler(exception_handler_app):
) == summary_text ) == summary_text
def test_exception_handler_lookup(exception_handler_app): def test_exception_handler_lookup(exception_handler_app: Sanic):
class CustomError(Exception): class CustomError(Exception):
pass pass
@ -205,13 +208,17 @@ def test_exception_handler_lookup(exception_handler_app):
) )
def test_exception_handler_processed_request_middleware(exception_handler_app): def test_exception_handler_processed_request_middleware(
exception_handler_app: Sanic,
):
request, response = exception_handler_app.test_client.get("/8") request, response = exception_handler_app.test_client.get("/8")
assert response.status == 200 assert response.status == 200
assert response.text == "Done." assert response.text == "Done."
def test_single_arg_exception_handler_notice(exception_handler_app, caplog): def test_single_arg_exception_handler_notice(
exception_handler_app: Sanic, caplog: LogCaptureFixture
):
class CustomErrorHandler(ErrorHandler): class CustomErrorHandler(ErrorHandler):
def lookup(self, exception): def lookup(self, exception):
return super().lookup(exception, None) return super().lookup(exception, None)
@ -233,7 +240,9 @@ def test_single_arg_exception_handler_notice(exception_handler_app, caplog):
assert response.status == 400 assert response.status == 400
def test_error_handler_noisy_log(exception_handler_app, monkeypatch): def test_error_handler_noisy_log(
exception_handler_app: Sanic, monkeypatch: MonkeyPatch
):
err_logger = Mock() err_logger = Mock()
monkeypatch.setattr(handlers, "error_logger", err_logger) monkeypatch.setattr(handlers, "error_logger", err_logger)
@ -246,3 +255,45 @@ def test_error_handler_noisy_log(exception_handler_app, monkeypatch):
err_logger.exception.assert_called_with( err_logger.exception.assert_called_with(
"Exception occurred while handling uri: %s", repr(request.url) "Exception occurred while handling uri: %s", repr(request.url)
) )
def test_exception_handler_response_was_sent(
app: Sanic,
caplog: LogCaptureFixture,
message_in_records: Callable[[List[logging.LogRecord], str], bool],
):
exception_handler_ran = False
@app.exception(ServerError)
async def exception_handler(request, exception):
nonlocal exception_handler_ran
exception_handler_ran = True
return text("Error")
@app.route("/1")
async def handler1(request: Request):
response = await request.respond()
await response.send("some text")
raise ServerError("Exception")
@app.route("/2")
async def handler2(request: Request):
response = await request.respond()
raise ServerError("Exception")
with caplog.at_level(logging.WARNING):
_, response = app.test_client.get("/1")
assert "some text" in response.text
# Change to assert warning not in the records in the future version.
message_in_records(
caplog.records,
(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"Therefore, the response from your custom exception "
),
)
_, response = app.test_client.get("/2")
assert "Error" in response.text

View File

@ -297,3 +297,27 @@ def test_middleware_added_response(app):
_, response = app.test_client.get("/") _, response = app.test_client.get("/")
assert response.json["foo"] == "bar" assert response.json["foo"] == "bar"
def test_middleware_return_response(app):
response_middleware_run_count = 0
request_middleware_run_count = 0
@app.on_response
def response(_, response):
nonlocal response_middleware_run_count
response_middleware_run_count += 1
@app.on_request
def request(_):
nonlocal request_middleware_run_count
request_middleware_run_count += 1
@app.get("/")
async def handler(request):
resp1 = await request.respond()
return resp1
_, response = app.test_client.get("/")
assert response_middleware_run_count == 1
assert request_middleware_run_count == 1

View File

@ -15,8 +15,8 @@ from sanic_testing.testing import (
) )
from sanic import Blueprint, Sanic from sanic import Blueprint, Sanic
from sanic.exceptions import ServerError from sanic.exceptions import SanicException, ServerError
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters
from sanic.response import html, json, text from sanic.response import html, json, text

View File

@ -3,15 +3,18 @@ import inspect
import os import os
from collections import namedtuple from collections import namedtuple
from logging import ERROR, LogRecord
from mimetypes import guess_type from mimetypes import guess_type
from random import choice from random import choice
from typing import Callable, List
from urllib.parse import unquote from urllib.parse import unquote
import pytest import pytest
from aiofiles import os as async_os from aiofiles import os as async_os
from pytest import LogCaptureFixture
from sanic import Sanic from sanic import Request, Sanic
from sanic.response import ( from sanic.response import (
HTTPResponse, HTTPResponse,
empty, empty,
@ -33,7 +36,7 @@ def test_response_body_not_a_string(app):
random_num = choice(range(1000)) random_num = choice(range(1000))
@app.route("/hello") @app.route("/hello")
async def hello_route(request): async def hello_route(request: Request):
return text(random_num) return text(random_num)
request, response = app.test_client.get("/hello") request, response = app.test_client.get("/hello")
@ -51,7 +54,7 @@ def test_method_not_allowed():
app = Sanic("app") app = Sanic("app")
@app.get("/") @app.get("/")
async def test_get(request): async def test_get(request: Request):
return response.json({"hello": "world"}) return response.json({"hello": "world"})
request, response = app.test_client.head("/") request, response = app.test_client.head("/")
@ -67,7 +70,7 @@ def test_method_not_allowed():
app.router.reset() app.router.reset()
@app.post("/") @app.post("/")
async def test_post(request): async def test_post(request: Request):
return response.json({"hello": "world"}) return response.json({"hello": "world"})
request, response = app.test_client.head("/") request, response = app.test_client.head("/")
@ -89,7 +92,7 @@ def test_method_not_allowed():
def test_response_header(app): def test_response_header(app):
@app.get("/") @app.get("/")
async def test(request): async def test(request: Request):
return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"}) return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"})
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
@ -102,14 +105,14 @@ def test_response_header(app):
def test_response_content_length(app): def test_response_content_length(app):
@app.get("/response_with_space") @app.get("/response_with_space")
async def response_with_space(request): async def response_with_space(request: Request):
return json( return json(
{"message": "Data", "details": "Some Details"}, {"message": "Data", "details": "Some Details"},
headers={"CONTENT-TYPE": "application/json"}, headers={"CONTENT-TYPE": "application/json"},
) )
@app.get("/response_without_space") @app.get("/response_without_space")
async def response_without_space(request): async def response_without_space(request: Request):
return json( return json(
{"message": "Data", "details": "Some Details"}, {"message": "Data", "details": "Some Details"},
headers={"CONTENT-TYPE": "application/json"}, headers={"CONTENT-TYPE": "application/json"},
@ -135,7 +138,7 @@ def test_response_content_length(app):
def test_response_content_length_with_different_data_types(app): def test_response_content_length_with_different_data_types(app):
@app.get("/") @app.get("/")
async def get_data_with_different_types(request): async def get_data_with_different_types(request: Request):
# Indentation issues in the Response is intentional. Please do not fix # Indentation issues in the Response is intentional. Please do not fix
return json( return json(
{"bool": True, "none": None, "string": "string", "number": -1}, {"bool": True, "none": None, "string": "string", "number": -1},
@ -149,23 +152,23 @@ def test_response_content_length_with_different_data_types(app):
@pytest.fixture @pytest.fixture
def json_app(app): def json_app(app):
@app.route("/") @app.route("/")
async def test(request): async def test(request: Request):
return json(JSON_DATA) return json(JSON_DATA)
@app.get("/no-content") @app.get("/no-content")
async def no_content_handler(request): async def no_content_handler(request: Request):
return json(JSON_DATA, status=204) return json(JSON_DATA, status=204)
@app.get("/no-content/unmodified") @app.get("/no-content/unmodified")
async def no_content_unmodified_handler(request): async def no_content_unmodified_handler(request: Request):
return json(None, status=304) return json(None, status=304)
@app.get("/unmodified") @app.get("/unmodified")
async def unmodified_handler(request): async def unmodified_handler(request: Request):
return json(JSON_DATA, status=304) return json(JSON_DATA, status=304)
@app.delete("/") @app.delete("/")
async def delete_handler(request): async def delete_handler(request: Request):
return json(None, status=204) return json(None, status=204)
return app return app
@ -207,7 +210,7 @@ def test_no_content(json_app):
@pytest.fixture @pytest.fixture
def streaming_app(app): def streaming_app(app):
@app.route("/") @app.route("/")
async def test(request): async def test(request: Request):
return stream( return stream(
sample_streaming_fn, sample_streaming_fn,
content_type="text/csv", content_type="text/csv",
@ -219,7 +222,7 @@ def streaming_app(app):
@pytest.fixture @pytest.fixture
def non_chunked_streaming_app(app): def non_chunked_streaming_app(app):
@app.route("/") @app.route("/")
async def test(request): async def test(request: Request):
return stream( return stream(
sample_streaming_fn, sample_streaming_fn,
headers={"Content-Length": "7"}, headers={"Content-Length": "7"},
@ -276,7 +279,7 @@ def test_non_chunked_streaming_returns_correct_content(
def test_stream_response_with_cookies(app): def test_stream_response_with_cookies(app):
@app.route("/") @app.route("/")
async def test(request): async def test(request: Request):
response = stream(sample_streaming_fn, content_type="text/csv") response = stream(sample_streaming_fn, content_type="text/csv")
response.cookies["test"] = "modified" response.cookies["test"] = "modified"
response.cookies["test"] = "pass" response.cookies["test"] = "pass"
@ -288,7 +291,7 @@ def test_stream_response_with_cookies(app):
def test_stream_response_without_cookies(app): def test_stream_response_without_cookies(app):
@app.route("/") @app.route("/")
async def test(request): async def test(request: Request):
return stream(sample_streaming_fn, content_type="text/csv") return stream(sample_streaming_fn, content_type="text/csv")
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
@ -314,7 +317,7 @@ def get_file_content(static_file_directory, file_name):
"file_name", ["test.file", "decode me.txt", "python.png"] "file_name", ["test.file", "decode me.txt", "python.png"]
) )
@pytest.mark.parametrize("status", [200, 401]) @pytest.mark.parametrize("status", [200, 401])
def test_file_response(app, file_name, static_file_directory, status): def test_file_response(app: Sanic, file_name, static_file_directory, status):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
file_path = os.path.join(static_file_directory, filename) file_path = os.path.join(static_file_directory, filename)
@ -340,7 +343,7 @@ def test_file_response(app, file_name, static_file_directory, status):
], ],
) )
def test_file_response_custom_filename( def test_file_response_custom_filename(
app, source, dest, static_file_directory app: Sanic, source, dest, static_file_directory
): ):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
@ -358,7 +361,7 @@ def test_file_response_custom_filename(
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"]) @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"])
def test_file_head_response(app, file_name, static_file_directory): def test_file_head_response(app: Sanic, file_name, static_file_directory):
@app.route("/files/<filename>", methods=["GET", "HEAD"]) @app.route("/files/<filename>", methods=["GET", "HEAD"])
async def file_route(request, filename): async def file_route(request, filename):
file_path = os.path.join(static_file_directory, filename) file_path = os.path.join(static_file_directory, filename)
@ -391,7 +394,7 @@ def test_file_head_response(app, file_name, static_file_directory):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"] "file_name", ["test.file", "decode me.txt", "python.png"]
) )
def test_file_stream_response(app, file_name, static_file_directory): def test_file_stream_response(app: Sanic, file_name, static_file_directory):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
file_path = os.path.join(static_file_directory, filename) file_path = os.path.join(static_file_directory, filename)
@ -417,7 +420,7 @@ def test_file_stream_response(app, file_name, static_file_directory):
], ],
) )
def test_file_stream_response_custom_filename( def test_file_stream_response_custom_filename(
app, source, dest, static_file_directory app: Sanic, source, dest, static_file_directory
): ):
@app.route("/files/<filename>", methods=["GET"]) @app.route("/files/<filename>", methods=["GET"])
def file_route(request, filename): def file_route(request, filename):
@ -435,7 +438,9 @@ def test_file_stream_response_custom_filename(
@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"]) @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"])
def test_file_stream_head_response(app, file_name, static_file_directory): def test_file_stream_head_response(
app: Sanic, file_name, static_file_directory
):
@app.route("/files/<filename>", methods=["GET", "HEAD"]) @app.route("/files/<filename>", methods=["GET", "HEAD"])
async def file_route(request, filename): async def file_route(request, filename):
file_path = os.path.join(static_file_directory, filename) file_path = os.path.join(static_file_directory, filename)
@ -479,7 +484,7 @@ def test_file_stream_head_response(app, file_name, static_file_directory):
"size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)] "size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)]
) )
def test_file_stream_response_range( def test_file_stream_response_range(
app, file_name, static_file_directory, size, start, end app: Sanic, file_name, static_file_directory, size, start, end
): ):
Range = namedtuple("Range", ["size", "start", "end", "total"]) Range = namedtuple("Range", ["size", "start", "end", "total"])
@ -508,7 +513,7 @@ def test_file_stream_response_range(
def test_raw_response(app): def test_raw_response(app):
@app.get("/test") @app.get("/test")
def handler(request): def handler(request: Request):
return raw(b"raw_response") return raw(b"raw_response")
request, response = app.test_client.get("/test") request, response = app.test_client.get("/test")
@ -518,7 +523,7 @@ def test_raw_response(app):
def test_empty_response(app): def test_empty_response(app):
@app.get("/test") @app.get("/test")
def handler(request): def handler(request: Request):
return empty() return empty()
request, response = app.test_client.get("/test") request, response = app.test_client.get("/test")
@ -526,17 +531,162 @@ def test_empty_response(app):
assert response.body == b"" assert response.body == b""
def test_direct_response_stream(app): def test_direct_response_stream(app: Sanic):
@app.route("/") @app.route("/")
async def test(request): async def test(request: Request):
response = await request.respond(content_type="text/csv") response = await request.respond(content_type="text/csv")
await response.send("foo,") await response.send("foo,")
await response.send("bar") await response.send("bar")
await response.eof() await response.eof()
return response
_, response = app.test_client.get("/") _, response = app.test_client.get("/")
assert response.text == "foo,bar" assert response.text == "foo,bar"
assert response.headers["Transfer-Encoding"] == "chunked" assert response.headers["Transfer-Encoding"] == "chunked"
assert response.headers["Content-Type"] == "text/csv" assert response.headers["Content-Type"] == "text/csv"
assert "Content-Length" not in response.headers assert "Content-Length" not in response.headers
def test_two_respond_calls(app: Sanic):
@app.route("/")
async def handler(request: Request):
response = await request.respond()
await response.send("foo,")
await response.send("bar")
await response.eof()
def test_multiple_responses(
app: Sanic,
caplog: LogCaptureFixture,
message_in_records: Callable[[List[LogRecord], str], bool],
):
@app.route("/1")
async def handler(request: Request):
response = await request.respond()
await response.send("foo")
response = await request.respond()
@app.route("/2")
async def handler(request: Request):
response = await request.respond()
response = await request.respond()
await response.send("foo")
@app.get("/3")
async def handler(request: Request):
response = await request.respond()
await response.send("foo,")
response = await request.respond()
await response.send("bar")
@app.get("/4")
async def handler(request: Request):
response = await request.respond(headers={"one": "one"})
return json({"foo": "bar"}, headers={"one": "two"})
@app.get("/5")
async def handler(request: Request):
response = await request.respond(headers={"one": "one"})
await response.send("foo")
return json({"foo": "bar"}, headers={"one": "two"})
@app.get("/6")
async def handler(request: Request):
response = await request.respond(headers={"one": "one"})
await response.send("foo, ")
json_response = json({"foo": "bar"}, headers={"one": "two"})
await response.send("bar")
return json_response
error_msg0 = "Second respond call is not allowed."
error_msg1 = (
"The error response will not be sent to the client for the following "
'exception:"Second respond call is not allowed.". A previous '
"response has at least partially been sent."
)
error_msg2 = (
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
error_msg3 = (
"Response stream was ended, no more "
"response data is allowed to be sent."
)
with caplog.at_level(ERROR):
_, response = app.test_client.get("/1")
assert response.status == 200
assert message_in_records(caplog.records, error_msg0)
assert message_in_records(caplog.records, error_msg1)
with caplog.at_level(ERROR):
_, response = app.test_client.get("/2")
assert response.status == 500
assert "500 — Internal Server Error" in response.text
with caplog.at_level(ERROR):
_, response = app.test_client.get("/3")
assert response.status == 200
assert "foo," in response.text
assert message_in_records(caplog.records, error_msg0)
assert message_in_records(caplog.records, error_msg1)
with caplog.at_level(ERROR):
_, response = app.test_client.get("/4")
print(response.json)
assert response.status == 200
assert "foo" not in response.text
assert "one" in response.headers
assert response.headers["one"] == "one"
print(response.headers)
assert message_in_records(caplog.records, error_msg2)
with caplog.at_level(ERROR):
_, response = app.test_client.get("/5")
assert response.status == 200
assert "foo" in response.text
assert "one" in response.headers
assert response.headers["one"] == "one"
assert message_in_records(caplog.records, error_msg2)
with caplog.at_level(ERROR):
_, response = app.test_client.get("/6")
assert "foo, bar" in response.text
assert "one" in response.headers
assert response.headers["one"] == "one"
assert message_in_records(caplog.records, error_msg2)
def send_response_after_eof_should_fail(
app: Sanic,
caplog: LogCaptureFixture,
message_in_records: Callable[[List[LogRecord], str], bool],
):
@app.get("/")
async def handler(request: Request):
response = await request.respond()
await response.send("foo, ")
await response.eof()
await response.send("bar")
error_msg1 = (
"The error response will not be sent to the client for the following "
'exception:"Second respond call is not allowed.". A previous '
"response has at least partially been sent."
)
error_msg2 = (
"Response stream was ended, no more "
"response data is allowed to be sent."
)
with caplog.at_level(ERROR):
_, response = app.test_client.get("/")
assert "foo, " in response.text
assert message_in_records(caplog.records, error_msg1)
assert message_in_records(caplog.records, error_msg2)