Compare commits

..

2 Commits

Author SHA1 Message Date
Adam Hopkins
5e12edbc38 Allow non-conforming ErrorHandlers (#2259)
* Allow non-conforming ErrorHandlers

* Rename to legacy lookup

* Updated depnotice

* Bump version

* Fix formatting

* Remove unused import

* Fix error messages
2021-10-03 01:02:56 +03:00
Adam Hopkins
50a606adee Merge pull request #2256 from sanic-org/current-release
Mergeback
2021-10-02 22:21:34 +03:00
12 changed files with 153 additions and 202 deletions

View File

@@ -1 +1 @@
__version__ = "21.9.3"
__version__ = "21.9.1"

View File

@@ -173,18 +173,18 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.asgi = False
self.auto_reload = False
self.blueprints: Dict[str, Blueprint] = {}
self.config: Config = config or Config(
load_env=load_env,
env_prefix=env_prefix,
app=self,
self.config = config or Config(
load_env=load_env, env_prefix=env_prefix
)
self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace()
self.configure_logging = configure_logging
self.ctx = ctx or SimpleNamespace()
self.debug = None
self.error_handler: ErrorHandler = error_handler or ErrorHandler()
self.error_handler = error_handler or ErrorHandler(
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.is_running = False
self.is_stopping = False
self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.reload_dirs: Set[Path] = set()
@@ -1474,9 +1474,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
async def _startup(self):
self.signalize()
self.finalize()
ErrorHandler.finalize(
self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT
)
ErrorHandler.finalize(self.error_handler)
TouchUp.run(self)
async def _server_event(

View File

@@ -1,9 +1,7 @@
from __future__ import annotations
from inspect import isclass
from os import environ
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union
from warnings import warn
from sanic.errorpages import check_error_format
@@ -12,10 +10,6 @@ from sanic.http import Http
from .utils import load_module_from_file_location, str_to_bool
if TYPE_CHECKING: # no cov
from sanic import Sanic
SANIC_PREFIX = "SANIC_"
BASE_LOGO = """
@@ -77,14 +71,11 @@ class Config(dict):
load_env: Optional[Union[bool, str]] = True,
env_prefix: Optional[str] = SANIC_PREFIX,
keep_alive: Optional[bool] = None,
*,
app: Optional[Sanic] = None,
):
defaults = defaults or {}
super().__init__({**DEFAULT_CONFIG, **defaults})
self._app = app
self._LOGO = BASE_LOGO
self.LOGO = BASE_LOGO
if keep_alive is not None:
self.KEEP_ALIVE = keep_alive
@@ -106,7 +97,6 @@ class Config(dict):
self._configure_header_size()
self._check_error_format()
self._init = True
def __getattr__(self, attr):
try:
@@ -114,51 +104,16 @@ class Config(dict):
except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'")
def __setattr__(self, attr, value) -> None:
self.update({attr: value})
def __setitem__(self, attr, value) -> None:
self.update({attr: value})
def update(self, *other, **kwargs) -> None:
other_mapping = {k: v for item in other for k, v in dict(item).items()}
super().update(*other, **kwargs)
for attr, value in {**other_mapping, **kwargs}.items():
self._post_set(attr, value)
def _post_set(self, attr, value) -> None:
if self.get("_init"):
if attr in (
"REQUEST_MAX_HEADER_SIZE",
"REQUEST_BUFFER_SIZE",
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
if self.app and value != self.app.error_handler.fallback:
if self.app.error_handler.fallback != "auto":
warn(
"Overriding non-default ErrorHandler fallback "
"value. Changing from "
f"{self.app.error_handler.fallback} to {value}."
)
self.app.error_handler.fallback = value
elif attr == "LOGO":
self._LOGO = value
warn(
"Setting the config.LOGO is deprecated and will no longer "
"be supported starting in v22.6.",
DeprecationWarning,
)
@property
def app(self):
return self._app
@property
def LOGO(self):
return self._LOGO
def __setattr__(self, attr, value):
self[attr] = value
if attr in (
"REQUEST_MAX_HEADER_SIZE",
"REQUEST_BUFFER_SIZE",
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
def _configure_header_size(self):
Http.set_header_max_size(

View File

@@ -393,8 +393,7 @@ def exception_response(
# from the route
if request.route:
try:
if request.route.ctx.error_format:
render_format = request.route.ctx.error_format
render_format = request.route.ctx.error_format
except AttributeError:
...

View File

@@ -38,14 +38,7 @@ class ErrorHandler:
self.base = base
@classmethod
def finalize(cls, error_handler, fallback: Optional[str] = None):
if (
fallback
and fallback != "auto"
and error_handler.fallback == "auto"
):
error_handler.fallback = fallback
def finalize(cls, error_handler):
if not isinstance(error_handler, cls):
error_logger.warning(
f"Error handler is non-conforming: {type(error_handler)}"

View File

@@ -105,6 +105,7 @@ class Http(metaclass=TouchUpMeta):
self.keep_alive = True
self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch
self.init_for_request()
def init_for_request(self):
"""Init/reset all per-request variables."""
@@ -128,20 +129,14 @@ class Http(metaclass=TouchUpMeta):
"""
HTTP 1.1 connection handler
"""
# Handle requests while the connection stays reusable
while self.keep_alive and self.stage is Stage.IDLE:
self.init_for_request()
# Wait for incoming bytes (in IDLE stage)
if not self.recv_buffer:
await self._receive_more()
self.stage = Stage.REQUEST
while True: # As long as connection stays keep-alive
try:
# Receive and handle a request
self.stage = Stage.REQUEST
self.response_func = self.http1_response_header
await self.http1_request_header()
self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
@@ -192,6 +187,16 @@ class Http(metaclass=TouchUpMeta):
if self.response:
self.response.stream = None
# Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive:
break
self.init_for_request()
# Wait for the next request
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self): # no cov
"""
Receive and parse request header into self.request.
@@ -294,6 +299,7 @@ class Http(metaclass=TouchUpMeta):
# Remove header and its trailing CRLF
del buf[: pos + 4]
self.stage = Stage.HANDLER
self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1

View File

@@ -918,7 +918,7 @@ class RouteMixin:
return route
def _determine_error_format(self, handler) -> Optional[str]:
def _determine_error_format(self, handler) -> str:
if not isinstance(handler, CompositionView):
try:
src = dedent(getsource(handler))
@@ -930,7 +930,7 @@ class RouteMixin:
except (OSError, TypeError):
...
return None
return "auto"
def _get_response_types(self, node):
types = set()

View File

@@ -139,10 +139,11 @@ class Router(BaseRouter):
route.ctx.stream = stream
route.ctx.hosts = hosts
route.ctx.static = static
route.ctx.error_format = error_format
route.ctx.error_format = (
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
if error_format:
check_error_format(route.ctx.error_format)
check_error_format(route.ctx.error_format)
routes.append(route)

View File

@@ -3,7 +3,6 @@ from os import environ
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
from unittest.mock import Mock
import pytest
@@ -351,40 +350,3 @@ def test_update_from_lowercase_key(app):
d = {"test_setting_value": 1}
app.update_config(d)
assert "test_setting_value" not in app.config
def test_deprecation_notice_when_setting_logo(app):
message = (
"Setting the config.LOGO is deprecated and will no longer be "
"supported starting in v22.6."
)
with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo"
def test_config_set_methods(app, monkeypatch):
post_set = Mock()
monkeypatch.setattr(Config, "_post_set", post_set)
app.config.FOO = 1
post_set.assert_called_once_with("FOO", 1)
post_set.reset_mock()
app.config["FOO"] = 2
post_set.assert_called_once_with("FOO", 2)
post_set.reset_mock()
app.config.update({"FOO": 3})
post_set.assert_called_once_with("FOO", 3)
post_set.reset_mock()
app.config.update([("FOO", 4)])
post_set.assert_called_once_with("FOO", 4)
post_set.reset_mock()
app.config.update(FOO=5)
post_set.assert_called_once_with("FOO", 5)
post_set.reset_mock()
app.config.update_config({"FOO": 6})
post_set.assert_called_once_with("FOO", 6)

View File

@@ -1,10 +1,8 @@
import pytest
from sanic import Sanic
from sanic.config import Config
from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException
from sanic.handlers import ErrorHandler
from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text
@@ -273,72 +271,3 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected):
)
assert response.content_type == expected
def test_allow_fallback_error_format_set_main_process_start(app):
@app.main_process_start
async def start(app, _):
app.config.FALLBACK_ERROR_FORMAT = "text"
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_setting_fallback_to_non_default_raise_warning(app):
app.error_handler = ErrorHandler(fallback="text")
assert app.error_handler.fallback == "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to auto."
),
):
app.config.FALLBACK_ERROR_FORMAT = "auto"
assert app.error_handler.fallback == "auto"
app.config.FALLBACK_ERROR_FORMAT = "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to json."
),
):
app.config.FALLBACK_ERROR_FORMAT = "json"
assert app.error_handler.fallback == "json"
def test_allow_fallback_error_format_in_config_injection():
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app = Sanic("test", config=MyConfig())
@app.route("/error", methods=["GET", "POST"])
def err(request):
raise Exception("something went wrong")
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_allow_fallback_error_format_in_config_replacement(app):
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app.config = MyConfig()
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"

View File

@@ -0,0 +1,109 @@
import asyncio
import httpcore
import httpx
import pytest
from sanic_testing.testing import SanicTestClient
from sanic import Sanic
from sanic.response import text
class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection):
async def arequest(self, *args, **kwargs):
await asyncio.sleep(2)
return await super().arequest(*args, **kwargs)
async def _open_socket(self, *args, **kwargs):
retval = await super()._open_socket(*args, **kwargs)
if self._request_delay:
await asyncio.sleep(self._request_delay)
return retval
class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool):
def __init__(self, request_delay=None, *args, **kwargs):
self._request_delay = request_delay
super().__init__(*args, **kwargs)
async def _add_to_pool(self, connection, timeout):
connection.__class__ = DelayableHTTPConnection
connection._request_delay = self._request_delay
await super()._add_to_pool(connection, timeout)
class DelayableSanicSession(httpx.AsyncClient):
def __init__(self, request_delay=None, *args, **kwargs) -> None:
transport = DelayableSanicConnectionPool(request_delay=request_delay)
super().__init__(transport=transport, *args, **kwargs)
class DelayableSanicTestClient(SanicTestClient):
def __init__(self, app, request_delay=None):
super().__init__(app)
self._request_delay = request_delay
self._loop = None
def get_new_session(self):
return DelayableSanicSession(request_delay=self._request_delay)
@pytest.fixture
def request_no_timeout_app():
app = Sanic("test_request_no_timeout")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@pytest.fixture
def request_timeout_default_app():
app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
def test_default_server_error_request_timeout(request_timeout_default_app):
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/1")
assert response.status == 408
assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
_, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = {
"Upgrade": "websocket",
"Connection": "upgrade",
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version": "13",
}
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/ws1", headers=headers)
assert response.status == 408
assert "Request Timeout" in response.text

View File

@@ -26,7 +26,6 @@ def protocol(app, mock_transport):
protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport)
protocol._setup_connection()
protocol._http.init_for_request()
protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock()
return protocol