Compare commits

...

4 Commits

Author SHA1 Message Date
Adam Hopkins
af1d289a45
Add error format from config replacement objects 2021-11-17 09:05:24 +02:00
Adam Hopkins
b20b3cb417
Bump version 2021-11-16 23:13:59 +02:00
L. Kärkkäinen
45c22f9af2
Make HTTP connections start in IDLE stage, avoiding delays and error messages (#2268)
* Make all new connections start in IDLE stage, and switch to REQUEST stage only once any bytes are received from client. This makes new connections without any request obey keepalive timeout rather than request timeout like they currently do.

* Revert typo

* Remove request timeout endpoint test which is no longer working (still tested by mocking). Fix mock timeout test setup.

Co-authored-by: L. Karkkainen <tronic@users.noreply.github.com>
2021-11-16 23:09:29 +02:00
Adam Hopkins
71d845786d
Add error format commit and merge conflicts 2021-11-16 23:08:55 +02:00
12 changed files with 204 additions and 156 deletions

View File

@ -1 +1 @@
__version__ = "21.9.1" __version__ = "21.9.2"

View File

@ -173,18 +173,16 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.asgi = False self.asgi = False
self.auto_reload = False self.auto_reload = False
self.blueprints: Dict[str, Blueprint] = {} self.blueprints: Dict[str, Blueprint] = {}
self.config = config or Config( self.config: Config = config or Config(
load_env=load_env, env_prefix=env_prefix load_env=load_env,
env_prefix=env_prefix,
app=self,
) )
self.configure_logging = configure_logging self.configure_logging: bool = configure_logging
self.ctx = ctx or SimpleNamespace() self.ctx: Any = ctx or SimpleNamespace()
self.debug = None self.debug = False
self.error_handler = error_handler or ErrorHandler( self.error_handler: ErrorHandler = error_handler or ErrorHandler()
fallback=self.config.FALLBACK_ERROR_FORMAT, self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
)
self.is_running = False
self.is_stopping = False
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.reload_dirs: Set[Path] = set() self.reload_dirs: Set[Path] = set()
@ -1474,7 +1472,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
async def _startup(self): async def _startup(self):
self.signalize() self.signalize()
self.finalize() self.finalize()
ErrorHandler.finalize(self.error_handler) ErrorHandler.finalize(
self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT
)
TouchUp.run(self) TouchUp.run(self)
async def _server_event( async def _server_event(

View File

@ -1,7 +1,9 @@
from __future__ import annotations
from inspect import isclass from inspect import isclass
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from warnings import warn from warnings import warn
from sanic.errorpages import check_error_format from sanic.errorpages import check_error_format
@ -10,6 +12,10 @@ from sanic.http import Http
from .utils import load_module_from_file_location, str_to_bool from .utils import load_module_from_file_location, str_to_bool
if TYPE_CHECKING: # no cov
from sanic import Sanic
SANIC_PREFIX = "SANIC_" SANIC_PREFIX = "SANIC_"
BASE_LOGO = """ BASE_LOGO = """
@ -71,11 +77,14 @@ class Config(dict):
load_env: Optional[Union[bool, str]] = True, load_env: Optional[Union[bool, str]] = True,
env_prefix: Optional[str] = SANIC_PREFIX, env_prefix: Optional[str] = SANIC_PREFIX,
keep_alive: Optional[bool] = None, keep_alive: Optional[bool] = None,
*,
app: Optional[Sanic] = None,
): ):
defaults = defaults or {} defaults = defaults or {}
super().__init__({**DEFAULT_CONFIG, **defaults}) super().__init__({**DEFAULT_CONFIG, **defaults})
self.LOGO = BASE_LOGO self._app = app
self._LOGO = ""
if keep_alive is not None: if keep_alive is not None:
self.KEEP_ALIVE = keep_alive self.KEEP_ALIVE = keep_alive
@ -97,6 +106,7 @@ class Config(dict):
self._configure_header_size() self._configure_header_size()
self._check_error_format() self._check_error_format()
self._init = True
def __getattr__(self, attr): def __getattr__(self, attr):
try: try:
@ -104,8 +114,20 @@ class Config(dict):
except KeyError as ke: except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'") raise AttributeError(f"Config has no '{ke.args[0]}'")
def __setattr__(self, attr, value): def __setattr__(self, attr, value) -> None:
self[attr] = value self.update({attr: value})
def __setitem__(self, attr, value) -> None:
self.update({attr: value})
def update(self, *other, **kwargs) -> None:
other_mapping = {k: v for item in other for k, v in dict(item).items()}
super().update(*other, **kwargs)
for attr, value in {**other_mapping, **kwargs}.items():
self._post_set(attr, value)
def _post_set(self, attr, value) -> None:
if self.get("_init"):
if attr in ( if attr in (
"REQUEST_MAX_HEADER_SIZE", "REQUEST_MAX_HEADER_SIZE",
"REQUEST_BUFFER_SIZE", "REQUEST_BUFFER_SIZE",
@ -114,6 +136,29 @@ class Config(dict):
self._configure_header_size() self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT": elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format() self._check_error_format()
if self.app and value != self.app.error_handler.fallback:
if self.app.error_handler.fallback != "auto":
warn(
"Overriding non-default ErrorHandler fallback "
"value. Changing from "
f"{self.app.error_handler.fallback} to {value}."
)
self.app.error_handler.fallback = value
elif attr == "LOGO":
self._LOGO = value
warn(
"Setting the config.LOGO is deprecated and will no longer "
"be supported starting in v22.6.",
DeprecationWarning,
)
@property
def app(self):
return self._app
@property
def LOGO(self):
return self._LOGO
def _configure_header_size(self): def _configure_header_size(self):
Http.set_header_max_size( Http.set_header_max_size(

View File

@ -383,6 +383,7 @@ def exception_response(
""" """
content_type = None content_type = None
print("exception_response", fallback)
if not renderer: if not renderer:
# Make sure we have something set # Make sure we have something set
renderer = base renderer = base
@ -393,6 +394,7 @@ def exception_response(
# from the route # from the route
if request.route: if request.route:
try: try:
if request.route.ctx.error_format:
render_format = request.route.ctx.error_format render_format = request.route.ctx.error_format
except AttributeError: except AttributeError:
... ...

View File

@ -38,7 +38,14 @@ class ErrorHandler:
self.base = base self.base = base
@classmethod @classmethod
def finalize(cls, error_handler): def finalize(cls, error_handler, fallback: Optional[str] = None):
if (
fallback
and fallback != "auto"
and error_handler.fallback == "auto"
):
error_handler.fallback = fallback
if not isinstance(error_handler, cls): if not isinstance(error_handler, cls):
error_logger.warning( error_logger.warning(
f"Error handler is non-conforming: {type(error_handler)}" f"Error handler is non-conforming: {type(error_handler)}"

View File

@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta):
self.keep_alive = True self.keep_alive = True
self.stage: Stage = Stage.IDLE self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch self.dispatch = self.protocol.app.dispatch
self.init_for_request()
def init_for_request(self): def init_for_request(self):
"""Init/reset all per-request variables.""" """Init/reset all per-request variables."""
@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta):
""" """
HTTP 1.1 connection handler HTTP 1.1 connection handler
""" """
while True: # As long as connection stays keep-alive # Handle requests while the connection stays reusable
while self.keep_alive and self.stage is Stage.IDLE:
self.init_for_request()
# Wait for incoming bytes (in IDLE stage)
if not self.recv_buffer:
await self._receive_more()
self.stage = Stage.REQUEST
try: try:
# Receive and handle a request # Receive and handle a request
self.stage = Stage.REQUEST
self.response_func = self.http1_response_header self.response_func = self.http1_response_header
await self.http1_request_header() await self.http1_request_header()
self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request) await self.protocol.request_handler(self.request)
@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta):
if self.response: if self.response:
self.response.stream = None self.response.stream = None
# Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive:
break
self.init_for_request()
# Wait for the next request
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self): # no cov async def http1_request_header(self): # no cov
""" """
Receive and parse request header into self.request. Receive and parse request header into self.request.
@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta):
# Remove header and its trailing CRLF # Remove header and its trailing CRLF
del buf[: pos + 4] del buf[: pos + 4]
self.stage = Stage.HANDLER
self.request, request.stream = request, self self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1 self.protocol.state["requests_count"] += 1

View File

@ -918,7 +918,7 @@ class RouteMixin:
return route return route
def _determine_error_format(self, handler) -> str: def _determine_error_format(self, handler) -> Optional[str]:
if not isinstance(handler, CompositionView): if not isinstance(handler, CompositionView):
try: try:
src = dedent(getsource(handler)) src = dedent(getsource(handler))
@ -930,7 +930,7 @@ class RouteMixin:
except (OSError, TypeError): except (OSError, TypeError):
... ...
return "auto" return None
def _get_response_types(self, node): def _get_response_types(self, node):
types = set() types = set()

View File

@ -139,10 +139,9 @@ class Router(BaseRouter):
route.ctx.stream = stream route.ctx.stream = stream
route.ctx.hosts = hosts route.ctx.hosts = hosts
route.ctx.static = static route.ctx.static = static
route.ctx.error_format = ( route.ctx.error_format = error_format
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
if error_format:
check_error_format(route.ctx.error_format) check_error_format(route.ctx.error_format)
routes.append(route) routes.append(route)

View File

@ -3,6 +3,7 @@ from os import environ
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from textwrap import dedent from textwrap import dedent
from unittest.mock import Mock
import pytest import pytest
@ -350,3 +351,40 @@ def test_update_from_lowercase_key(app):
d = {"test_setting_value": 1} d = {"test_setting_value": 1}
app.update_config(d) app.update_config(d)
assert "test_setting_value" not in app.config assert "test_setting_value" not in app.config
def test_deprecation_notice_when_setting_logo(app):
message = (
"Setting the config.LOGO is deprecated and will no longer be "
"supported starting in v22.6."
)
with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo"
def test_config_set_methods(app, monkeypatch):
post_set = Mock()
monkeypatch.setattr(Config, "_post_set", post_set)
app.config.FOO = 1
post_set.assert_called_once_with("FOO", 1)
post_set.reset_mock()
app.config["FOO"] = 2
post_set.assert_called_once_with("FOO", 2)
post_set.reset_mock()
app.config.update({"FOO": 3})
post_set.assert_called_once_with("FOO", 3)
post_set.reset_mock()
app.config.update([("FOO", 4)])
post_set.assert_called_once_with("FOO", 4)
post_set.reset_mock()
app.config.update(FOO=5)
post_set.assert_called_once_with("FOO", 5)
post_set.reset_mock()
app.config.update_config({"FOO": 6})
post_set.assert_called_once_with("FOO", 6)

View File

@ -1,8 +1,10 @@
import pytest import pytest
from sanic import Sanic from sanic import Sanic
from sanic.config import Config
from sanic.errorpages import HTMLRenderer, exception_response from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException from sanic.exceptions import NotFound, SanicException
from sanic.handlers import ErrorHandler
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected):
) )
assert response.content_type == expected assert response.content_type == expected
def test_allow_fallback_error_format_set_main_process_start(app):
@app.main_process_start
async def start(app, _):
app.config.FALLBACK_ERROR_FORMAT = "text"
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_setting_fallback_to_non_default_raise_warning(app):
app.error_handler = ErrorHandler(fallback="text")
assert app.error_handler.fallback == "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to auto."
),
):
app.config.FALLBACK_ERROR_FORMAT = "auto"
assert app.error_handler.fallback == "auto"
app.config.FALLBACK_ERROR_FORMAT = "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to json."
),
):
app.config.FALLBACK_ERROR_FORMAT = "json"
assert app.error_handler.fallback == "json"
def test_allow_fallback_error_format_in_config_injection():
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app = Sanic("test", config=MyConfig())
@app.route("/error", methods=["GET", "POST"])
def err(request):
raise Exception("something went wrong")
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_allow_fallback_error_format_in_config_replacement(app):
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app.config = MyConfig()
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"

View File

@ -1,109 +0,0 @@
import asyncio
import httpcore
import httpx
import pytest
from sanic_testing.testing import SanicTestClient
from sanic import Sanic
from sanic.response import text
class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection):
async def arequest(self, *args, **kwargs):
await asyncio.sleep(2)
return await super().arequest(*args, **kwargs)
async def _open_socket(self, *args, **kwargs):
retval = await super()._open_socket(*args, **kwargs)
if self._request_delay:
await asyncio.sleep(self._request_delay)
return retval
class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool):
def __init__(self, request_delay=None, *args, **kwargs):
self._request_delay = request_delay
super().__init__(*args, **kwargs)
async def _add_to_pool(self, connection, timeout):
connection.__class__ = DelayableHTTPConnection
connection._request_delay = self._request_delay
await super()._add_to_pool(connection, timeout)
class DelayableSanicSession(httpx.AsyncClient):
def __init__(self, request_delay=None, *args, **kwargs) -> None:
transport = DelayableSanicConnectionPool(request_delay=request_delay)
super().__init__(transport=transport, *args, **kwargs)
class DelayableSanicTestClient(SanicTestClient):
def __init__(self, app, request_delay=None):
super().__init__(app)
self._request_delay = request_delay
self._loop = None
def get_new_session(self):
return DelayableSanicSession(request_delay=self._request_delay)
@pytest.fixture
def request_no_timeout_app():
app = Sanic("test_request_no_timeout")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@pytest.fixture
def request_timeout_default_app():
app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
def test_default_server_error_request_timeout(request_timeout_default_app):
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/1")
assert response.status == 408
assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
_, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = {
"Upgrade": "websocket",
"Connection": "upgrade",
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version": "13",
}
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/ws1", headers=headers)
assert response.status == 408
assert "Request Timeout" in response.text

View File

@ -26,6 +26,7 @@ def protocol(app, mock_transport):
protocol = HttpProtocol(loop=loop, app=app) protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport) protocol.connection_made(mock_transport)
protocol._setup_connection() protocol._setup_connection()
protocol._http.init_for_request()
protocol._task = Mock(spec=asyncio.Task) protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock() protocol._task.cancel = Mock()
return protocol return protocol