Release 21.9.3 (#2318)

* Allow non-conforming ErrorHandlers (#2259)

* Allow non-conforming ErrorHandlers

* Rename to legacy lookup

* Updated depnotice

* Bump version

* Fix formatting

* Remove unused import

* Fix error messages

* Add error format commit and merge conflicts

* Make HTTP connections start in IDLE stage, avoiding delays and error messages (#2268)

* Make all new connections start in IDLE stage, and switch to REQUEST stage only once any bytes are received from client. This makes new connections without any request obey keepalive timeout rather than request timeout like they currently do.

* Revert typo

* Remove request timeout endpoint test which is no longer working (still tested by mocking). Fix mock timeout test setup.

Co-authored-by: L. Karkkainen <tronic@users.noreply.github.com>

* Bump version

* Add error format from config replacement objects

* Cleanup mistaken print statement

* Cleanup reversions

* Bump version

Co-authored-by: L. Kärkkäinen <98187+Tronic@users.noreply.github.com>
Co-authored-by: L. Karkkainen <tronic@users.noreply.github.com>
This commit is contained in:
Adam Hopkins 2021-11-21 14:26:30 +02:00 committed by GitHub
parent f995612073
commit 8673021ad4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 259 additions and 155 deletions

View File

@ -1 +1 @@
__version__ = "21.9.0" __version__ = "21.9.3"

View File

@ -173,18 +173,18 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.asgi = False self.asgi = False
self.auto_reload = False self.auto_reload = False
self.blueprints: Dict[str, Blueprint] = {} self.blueprints: Dict[str, Blueprint] = {}
self.config = config or Config( self.config: Config = config or Config(
load_env=load_env, env_prefix=env_prefix load_env=load_env,
env_prefix=env_prefix,
app=self,
) )
self.configure_logging = configure_logging self.configure_logging: bool = configure_logging
self.ctx = ctx or SimpleNamespace() self.ctx: Any = ctx or SimpleNamespace()
self.debug = None self.debug = None
self.error_handler = error_handler or ErrorHandler( self.error_handler: ErrorHandler = error_handler or ErrorHandler()
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.is_running = False self.is_running = False
self.is_stopping = False self.is_stopping = False
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.reload_dirs: Set[Path] = set() self.reload_dirs: Set[Path] = set()
@ -1474,6 +1474,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
async def _startup(self): async def _startup(self):
self.signalize() self.signalize()
self.finalize() self.finalize()
ErrorHandler.finalize(
self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT
)
TouchUp.run(self) TouchUp.run(self)
async def _server_event( async def _server_event(

View File

@ -1,7 +1,9 @@
from __future__ import annotations
from inspect import isclass from inspect import isclass
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from warnings import warn from warnings import warn
from sanic.errorpages import check_error_format from sanic.errorpages import check_error_format
@ -10,6 +12,10 @@ from sanic.http import Http
from .utils import load_module_from_file_location, str_to_bool from .utils import load_module_from_file_location, str_to_bool
if TYPE_CHECKING: # no cov
from sanic import Sanic
SANIC_PREFIX = "SANIC_" SANIC_PREFIX = "SANIC_"
BASE_LOGO = """ BASE_LOGO = """
@ -71,11 +77,14 @@ class Config(dict):
load_env: Optional[Union[bool, str]] = True, load_env: Optional[Union[bool, str]] = True,
env_prefix: Optional[str] = SANIC_PREFIX, env_prefix: Optional[str] = SANIC_PREFIX,
keep_alive: Optional[bool] = None, keep_alive: Optional[bool] = None,
*,
app: Optional[Sanic] = None,
): ):
defaults = defaults or {} defaults = defaults or {}
super().__init__({**DEFAULT_CONFIG, **defaults}) super().__init__({**DEFAULT_CONFIG, **defaults})
self.LOGO = BASE_LOGO self._app = app
self._LOGO = BASE_LOGO
if keep_alive is not None: if keep_alive is not None:
self.KEEP_ALIVE = keep_alive self.KEEP_ALIVE = keep_alive
@ -97,6 +106,7 @@ class Config(dict):
self._configure_header_size() self._configure_header_size()
self._check_error_format() self._check_error_format()
self._init = True
def __getattr__(self, attr): def __getattr__(self, attr):
try: try:
@ -104,16 +114,51 @@ class Config(dict):
except KeyError as ke: except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'") raise AttributeError(f"Config has no '{ke.args[0]}'")
def __setattr__(self, attr, value): def __setattr__(self, attr, value) -> None:
self[attr] = value self.update({attr: value})
if attr in (
"REQUEST_MAX_HEADER_SIZE", def __setitem__(self, attr, value) -> None:
"REQUEST_BUFFER_SIZE", self.update({attr: value})
"REQUEST_MAX_SIZE",
): def update(self, *other, **kwargs) -> None:
self._configure_header_size() other_mapping = {k: v for item in other for k, v in dict(item).items()}
elif attr == "FALLBACK_ERROR_FORMAT": super().update(*other, **kwargs)
self._check_error_format() for attr, value in {**other_mapping, **kwargs}.items():
self._post_set(attr, value)
def _post_set(self, attr, value) -> None:
if self.get("_init"):
if attr in (
"REQUEST_MAX_HEADER_SIZE",
"REQUEST_BUFFER_SIZE",
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
if self.app and value != self.app.error_handler.fallback:
if self.app.error_handler.fallback != "auto":
warn(
"Overriding non-default ErrorHandler fallback "
"value. Changing from "
f"{self.app.error_handler.fallback} to {value}."
)
self.app.error_handler.fallback = value
elif attr == "LOGO":
self._LOGO = value
warn(
"Setting the config.LOGO is deprecated and will no longer "
"be supported starting in v22.6.",
DeprecationWarning,
)
@property
def app(self):
return self._app
@property
def LOGO(self):
return self._LOGO
def _configure_header_size(self): def _configure_header_size(self):
Http.set_header_max_size( Http.set_header_max_size(

View File

@ -393,7 +393,8 @@ def exception_response(
# from the route # from the route
if request.route: if request.route:
try: try:
render_format = request.route.ctx.error_format if request.route.ctx.error_format:
render_format = request.route.ctx.error_format
except AttributeError: except AttributeError:
... ...

View File

@ -1,3 +1,4 @@
from inspect import signature
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response
@ -25,7 +26,9 @@ class ErrorHandler:
""" """
# Beginning in v22.3, the base renderer will be TextRenderer # Beginning in v22.3, the base renderer will be TextRenderer
def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): def __init__(
self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer
):
self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = []
self.cached_handlers: Dict[ self.cached_handlers: Dict[
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
@ -34,6 +37,41 @@ class ErrorHandler:
self.fallback = fallback self.fallback = fallback
self.base = base self.base = base
@classmethod
def finalize(cls, error_handler, fallback: Optional[str] = None):
if (
fallback
and fallback != "auto"
and error_handler.fallback == "auto"
):
error_handler.fallback = fallback
if not isinstance(error_handler, cls):
error_logger.warning(
f"Error handler is non-conforming: {type(error_handler)}"
)
sig = signature(error_handler.lookup)
if len(sig.parameters) == 1:
error_logger.warning(
DeprecationWarning(
"You are using a deprecated error handler. The lookup "
"method should accept two positional parameters: "
"(exception, route_name: Optional[str]). "
"Until you upgrade your ErrorHandler.lookup, Blueprint "
"specific exceptions will not work properly. Beginning "
"in v22.3, the legacy style lookup method will not "
"work at all."
),
)
error_handler._lookup = error_handler._legacy_lookup
def _full_lookup(self, exception, route_name: Optional[str] = None):
return self.lookup(exception, route_name)
def _legacy_lookup(self, exception, route_name: Optional[str] = None):
return self.lookup(exception)
def add(self, exception, handler, route_names: Optional[List[str]] = None): def add(self, exception, handler, route_names: Optional[List[str]] = None):
""" """
Add a new exception handler to an already existing handler object. Add a new exception handler to an already existing handler object.
@ -56,7 +94,7 @@ class ErrorHandler:
else: else:
self.cached_handlers[(exception, None)] = handler self.cached_handlers[(exception, None)] = handler
def lookup(self, exception, route_name: Optional[str]): def lookup(self, exception, route_name: Optional[str] = None):
""" """
Lookup the existing instance of :class:`ErrorHandler` and fetch the Lookup the existing instance of :class:`ErrorHandler` and fetch the
registered handler for a specific type of exception. registered handler for a specific type of exception.
@ -94,6 +132,8 @@ class ErrorHandler:
handler = None handler = None
return handler return handler
_lookup = _full_lookup
def response(self, request, exception): def response(self, request, exception):
"""Fetches and executes an exception handler and returns a response """Fetches and executes an exception handler and returns a response
object object
@ -109,7 +149,7 @@ class ErrorHandler:
or registered handler for that type of exception. or registered handler for that type of exception.
""" """
route_name = request.name if request else None route_name = request.name if request else None
handler = self.lookup(exception, route_name) handler = self._lookup(exception, route_name)
response = None response = None
try: try:
if handler: if handler:

View File

@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta):
self.keep_alive = True self.keep_alive = True
self.stage: Stage = Stage.IDLE self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch self.dispatch = self.protocol.app.dispatch
self.init_for_request()
def init_for_request(self): def init_for_request(self):
"""Init/reset all per-request variables.""" """Init/reset all per-request variables."""
@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta):
""" """
HTTP 1.1 connection handler HTTP 1.1 connection handler
""" """
while True: # As long as connection stays keep-alive # Handle requests while the connection stays reusable
while self.keep_alive and self.stage is Stage.IDLE:
self.init_for_request()
# Wait for incoming bytes (in IDLE stage)
if not self.recv_buffer:
await self._receive_more()
self.stage = Stage.REQUEST
try: try:
# Receive and handle a request # Receive and handle a request
self.stage = Stage.REQUEST
self.response_func = self.http1_response_header self.response_func = self.http1_response_header
await self.http1_request_header() await self.http1_request_header()
self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request) await self.protocol.request_handler(self.request)
@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta):
if self.response: if self.response:
self.response.stream = None self.response.stream = None
# Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive:
break
self.init_for_request()
# Wait for the next request
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self): # no cov async def http1_request_header(self): # no cov
""" """
Receive and parse request header into self.request. Receive and parse request header into self.request.
@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta):
# Remove header and its trailing CRLF # Remove header and its trailing CRLF
del buf[: pos + 4] del buf[: pos + 4]
self.stage = Stage.HANDLER
self.request, request.stream = request, self self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1 self.protocol.state["requests_count"] += 1

View File

@ -918,7 +918,7 @@ class RouteMixin:
return route return route
def _determine_error_format(self, handler) -> str: def _determine_error_format(self, handler) -> Optional[str]:
if not isinstance(handler, CompositionView): if not isinstance(handler, CompositionView):
try: try:
src = dedent(getsource(handler)) src = dedent(getsource(handler))
@ -930,7 +930,7 @@ class RouteMixin:
except (OSError, TypeError): except (OSError, TypeError):
... ...
return "auto" return None
def _get_response_types(self, node): def _get_response_types(self, node):
types = set() types = set()

View File

@ -139,11 +139,10 @@ class Router(BaseRouter):
route.ctx.stream = stream route.ctx.stream = stream
route.ctx.hosts = hosts route.ctx.hosts = hosts
route.ctx.static = static route.ctx.static = static
route.ctx.error_format = ( route.ctx.error_format = error_format
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
check_error_format(route.ctx.error_format) if error_format:
check_error_format(route.ctx.error_format)
routes.append(route) routes.append(route)

View File

@ -3,6 +3,7 @@ from os import environ
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from textwrap import dedent from textwrap import dedent
from unittest.mock import Mock
import pytest import pytest
@ -350,3 +351,40 @@ def test_update_from_lowercase_key(app):
d = {"test_setting_value": 1} d = {"test_setting_value": 1}
app.update_config(d) app.update_config(d)
assert "test_setting_value" not in app.config assert "test_setting_value" not in app.config
def test_deprecation_notice_when_setting_logo(app):
message = (
"Setting the config.LOGO is deprecated and will no longer be "
"supported starting in v22.6."
)
with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo"
def test_config_set_methods(app, monkeypatch):
post_set = Mock()
monkeypatch.setattr(Config, "_post_set", post_set)
app.config.FOO = 1
post_set.assert_called_once_with("FOO", 1)
post_set.reset_mock()
app.config["FOO"] = 2
post_set.assert_called_once_with("FOO", 2)
post_set.reset_mock()
app.config.update({"FOO": 3})
post_set.assert_called_once_with("FOO", 3)
post_set.reset_mock()
app.config.update([("FOO", 4)])
post_set.assert_called_once_with("FOO", 4)
post_set.reset_mock()
app.config.update(FOO=5)
post_set.assert_called_once_with("FOO", 5)
post_set.reset_mock()
app.config.update_config({"FOO": 6})
post_set.assert_called_once_with("FOO", 6)

View File

@ -1,8 +1,10 @@
import pytest import pytest
from sanic import Sanic from sanic import Sanic
from sanic.config import Config
from sanic.errorpages import HTMLRenderer, exception_response from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException from sanic.exceptions import NotFound, SanicException
from sanic.handlers import ErrorHandler
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected):
) )
assert response.content_type == expected assert response.content_type == expected
def test_allow_fallback_error_format_set_main_process_start(app):
@app.main_process_start
async def start(app, _):
app.config.FALLBACK_ERROR_FORMAT = "text"
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_setting_fallback_to_non_default_raise_warning(app):
app.error_handler = ErrorHandler(fallback="text")
assert app.error_handler.fallback == "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to auto."
),
):
app.config.FALLBACK_ERROR_FORMAT = "auto"
assert app.error_handler.fallback == "auto"
app.config.FALLBACK_ERROR_FORMAT = "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to json."
),
):
app.config.FALLBACK_ERROR_FORMAT = "json"
assert app.error_handler.fallback == "json"
def test_allow_fallback_error_format_in_config_injection():
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app = Sanic("test", config=MyConfig())
@app.route("/error", methods=["GET", "POST"])
def err(request):
raise Exception("something went wrong")
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_allow_fallback_error_format_in_config_replacement(app):
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app.config = MyConfig()
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"

View File

@ -4,6 +4,7 @@ import warnings
import pytest import pytest
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from websockets.version import version as websockets_version
from sanic import Sanic from sanic import Sanic
from sanic.exceptions import ( from sanic.exceptions import (
@ -16,7 +17,6 @@ from sanic.exceptions import (
abort, abort,
) )
from sanic.response import text from sanic.response import text
from websockets.version import version as websockets_version
class SanicExceptionTestException(Exception): class SanicExceptionTestException(Exception):

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
import pytest import pytest
@ -206,3 +207,23 @@ def test_exception_handler_processed_request_middleware(exception_handler_app):
request, response = exception_handler_app.test_client.get("/8") request, response = exception_handler_app.test_client.get("/8")
assert response.status == 200 assert response.status == 200
assert response.text == "Done." assert response.text == "Done."
def test_single_arg_exception_handler_notice(exception_handler_app, caplog):
class CustomErrorHandler(ErrorHandler):
def lookup(self, exception):
return super().lookup(exception, None)
exception_handler_app.error_handler = CustomErrorHandler()
with caplog.at_level(logging.WARNING):
_, response = exception_handler_app.test_client.get("/1")
assert caplog.records[0].message == (
"You are using a deprecated error handler. The lookup method should "
"accept two positional parameters: (exception, route_name: "
"Optional[str]). Until you upgrade your ErrorHandler.lookup, "
"Blueprint specific exceptions will not work properly. Beginning in "
"v22.3, the legacy style lookup method will not work at all."
)
assert response.status == 400

View File

@ -1,109 +0,0 @@
import asyncio
import httpcore
import httpx
import pytest
from sanic_testing.testing import SanicTestClient
from sanic import Sanic
from sanic.response import text
class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection):
async def arequest(self, *args, **kwargs):
await asyncio.sleep(2)
return await super().arequest(*args, **kwargs)
async def _open_socket(self, *args, **kwargs):
retval = await super()._open_socket(*args, **kwargs)
if self._request_delay:
await asyncio.sleep(self._request_delay)
return retval
class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool):
def __init__(self, request_delay=None, *args, **kwargs):
self._request_delay = request_delay
super().__init__(*args, **kwargs)
async def _add_to_pool(self, connection, timeout):
connection.__class__ = DelayableHTTPConnection
connection._request_delay = self._request_delay
await super()._add_to_pool(connection, timeout)
class DelayableSanicSession(httpx.AsyncClient):
def __init__(self, request_delay=None, *args, **kwargs) -> None:
transport = DelayableSanicConnectionPool(request_delay=request_delay)
super().__init__(transport=transport, *args, **kwargs)
class DelayableSanicTestClient(SanicTestClient):
def __init__(self, app, request_delay=None):
super().__init__(app)
self._request_delay = request_delay
self._loop = None
def get_new_session(self):
return DelayableSanicSession(request_delay=self._request_delay)
@pytest.fixture
def request_no_timeout_app():
app = Sanic("test_request_no_timeout")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@pytest.fixture
def request_timeout_default_app():
app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
def test_default_server_error_request_timeout(request_timeout_default_app):
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/1")
assert response.status == 408
assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
_, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = {
"Upgrade": "websocket",
"Connection": "upgrade",
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version": "13",
}
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/ws1", headers=headers)
assert response.status == 408
assert "Request Timeout" in response.text

View File

@ -26,6 +26,7 @@ def protocol(app, mock_transport):
protocol = HttpProtocol(loop=loop, app=app) protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport) protocol.connection_made(mock_transport)
protocol._setup_connection() protocol._setup_connection()
protocol._http.init_for_request()
protocol._task = Mock(spec=asyncio.Task) protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock() protocol._task.cancel = Mock()
return protocol return protocol