diff --git a/sanic/app.py b/sanic/app.py index 5ad99982..88ae1c6e 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -72,6 +72,7 @@ from sanic.models.futures import ( FutureException, FutureListener, FutureMiddleware, + FutureRegistry, FutureRoute, FutureSignal, FutureStatic, @@ -115,6 +116,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "_future_exceptions", "_future_listeners", "_future_middleware", + "_future_registry", "_future_routes", "_future_signals", "_future_statics", @@ -187,17 +189,18 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self._test_manager: Any = None self._blueprint_order: List[Blueprint] = [] self._delayed_tasks: List[str] = [] + self._future_registry: FutureRegistry = FutureRegistry() self._state: ApplicationState = ApplicationState(app=self) self.blueprints: Dict[str, Blueprint] = {} self.config: Config = config or Config( - load_env=load_env, env_prefix=env_prefix + load_env=load_env, + env_prefix=env_prefix, + app=self, ) self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() self.debug = False - self.error_handler: ErrorHandler = error_handler or ErrorHandler( - fallback=self.config.FALLBACK_ERROR_FORMAT, - ) + self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} @@ -957,6 +960,10 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): # Execution # -------------------------------------------------------------------- # + def make_coffee(self, *args, **kwargs): + self.state.coffee = True + self.run(*args, **kwargs) + def run( self, host: Optional[str] = None, @@ -1569,7 +1576,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): extra.update(self.config.MOTD_DISPLAY) logo = ( - get_logo() + get_logo(coffee=self.state.coffee) if self.config.LOGO == "" or self.config.LOGO is True else self.config.LOGO ) @@ -1635,9 +1642,12 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): raise e async def _startup(self): + self._future_registry.clear() self.signalize() self.finalize() - ErrorHandler.finalize(self.error_handler) + ErrorHandler.finalize( + self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT + ) TouchUp.run(self) async def _server_event( diff --git a/sanic/application/logo.py b/sanic/application/logo.py index 9e3bb2fa..56b8c0b1 100644 --- a/sanic/application/logo.py +++ b/sanic/application/logo.py @@ -10,6 +10,15 @@ BASE_LOGO = """ Build Fast. Run Fast. """ +COFFEE_LOGO = """\033[48;2;255;13;104m \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ▄████████▄ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ██ ██▀▀▄ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ███████████ █ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ███████████▄▄▀ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ▀███████▀ \033[0m +\033[48;2;255;13;104m \033[0m +Dark roast. No sugar.""" + COLOR_LOGO = """\033[48;2;255;13;104m \033[0m \033[38;2;255;255;255;48;2;255;13;104m ▄███ █████ ██ \033[0m \033[38;2;255;255;255;48;2;255;13;104m ██ \033[0m @@ -32,9 +41,9 @@ FULL_COLOR_LOGO = """ ansi_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") -def get_logo(full=False): +def get_logo(full=False, coffee=False): logo = ( - (FULL_COLOR_LOGO if full else COLOR_LOGO) + (FULL_COLOR_LOGO if full else (COFFEE_LOGO if coffee else COLOR_LOGO)) if sys.stdout.isatty() else BASE_LOGO ) diff --git a/sanic/application/state.py b/sanic/application/state.py index b03c30da..eb180708 100644 --- a/sanic/application/state.py +++ b/sanic/application/state.py @@ -34,6 +34,7 @@ class Mode(StrEnum): class ApplicationState: app: Sanic asgi: bool = field(default=False) + coffee: bool = field(default=False) fast: bool = field(default=False) host: str = field(default="") mode: Mode = field(default=Mode.PRODUCTION) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index e5e1d333..290773fa 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -4,6 +4,9 @@ import asyncio from collections import defaultdict from copy import deepcopy +from functools import wraps +from inspect import isfunction +from itertools import chain from types import SimpleNamespace from typing import ( TYPE_CHECKING, @@ -12,7 +15,9 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, + Tuple, Union, ) @@ -35,6 +40,32 @@ if TYPE_CHECKING: from sanic import Sanic # noqa +def lazy(func, as_decorator=True): + @wraps(func) + def decorator(bp, *args, **kwargs): + nonlocal as_decorator + kwargs["apply"] = False + pass_handler = None + + if args and isfunction(args[0]): + as_decorator = False + + def wrapper(handler): + future = func(bp, *args, **kwargs) + if as_decorator: + future = future(handler) + + if bp.registered: + for app in bp.apps: + bp.register(app, {}) + + return future + + return wrapper if as_decorator else wrapper(pass_handler) + + return decorator + + class Blueprint(BaseSanic): """ In *Sanic* terminology, a **Blueprint** is a logical collection of @@ -124,29 +155,16 @@ class Blueprint(BaseSanic): ) return self._apps - def route(self, *args, **kwargs): - kwargs["apply"] = False - return super().route(*args, **kwargs) + @property + def registered(self) -> bool: + return bool(self._apps) - def static(self, *args, **kwargs): - kwargs["apply"] = False - return super().static(*args, **kwargs) - - def middleware(self, *args, **kwargs): - kwargs["apply"] = False - return super().middleware(*args, **kwargs) - - def listener(self, *args, **kwargs): - kwargs["apply"] = False - return super().listener(*args, **kwargs) - - def exception(self, *args, **kwargs): - kwargs["apply"] = False - return super().exception(*args, **kwargs) - - def signal(self, event: str, *args, **kwargs): - kwargs["apply"] = False - return super().signal(event, *args, **kwargs) + exception = lazy(BaseSanic.exception) + listener = lazy(BaseSanic.listener) + middleware = lazy(BaseSanic.middleware) + route = lazy(BaseSanic.route) + signal = lazy(BaseSanic.signal) + static = lazy(BaseSanic.static, as_decorator=False) def reset(self): self._apps: Set[Sanic] = set() @@ -283,6 +301,7 @@ class Blueprint(BaseSanic): middleware = [] exception_handlers = [] listeners = defaultdict(list) + registered = set() # Routes for future in self._future_routes: @@ -309,12 +328,15 @@ class Blueprint(BaseSanic): ) name = app._generate_name(future.name) + host = future.host or self.host + if isinstance(host, list): + host = tuple(host) apply_route = FutureRoute( future.handler, uri[1:] if uri.startswith("//") else uri, future.methods, - future.host or self.host, + host, strict_slashes, future.stream, version, @@ -328,6 +350,10 @@ class Blueprint(BaseSanic): error_format, ) + if (self, apply_route) in app._future_registry: + continue + + registered.add(apply_route) route = app._apply_route(apply_route) operation = ( routes.extend if isinstance(route, list) else routes.append @@ -339,6 +365,11 @@ class Blueprint(BaseSanic): # Prepend the blueprint URI prefix if available uri = url_prefix + future.uri if url_prefix else future.uri apply_route = FutureStatic(uri, *future[1:]) + + if (self, apply_route) in app._future_registry: + continue + + registered.add(apply_route) route = app._apply_static(apply_route) routes.append(route) @@ -347,30 +378,51 @@ class Blueprint(BaseSanic): if route_names: # Middleware for future in self._future_middleware: + if (self, future) in app._future_registry: + continue middleware.append(app._apply_middleware(future, route_names)) # Exceptions for future in self._future_exceptions: + if (self, future) in app._future_registry: + continue exception_handlers.append( app._apply_exception_handler(future, route_names) ) # Event listeners - for listener in self._future_listeners: - listeners[listener.event].append(app._apply_listener(listener)) + for future in self._future_listeners: + if (self, future) in app._future_registry: + continue + listeners[future.event].append(app._apply_listener(future)) # Signals - for signal in self._future_signals: - signal.condition.update({"blueprint": self.name}) - app._apply_signal(signal) + for future in self._future_signals: + if (self, future) in app._future_registry: + continue + future.condition.update({"blueprint": self.name}) + app._apply_signal(future) - self.routes = [route for route in routes if isinstance(route, Route)] - self.websocket_routes = [ + self.routes += [route for route in routes if isinstance(route, Route)] + self.websocket_routes += [ route for route in self.routes if route.ctx.websocket ] - self.middlewares = middleware - self.exceptions = exception_handlers - self.listeners = dict(listeners) + self.middlewares += middleware + self.exceptions += exception_handlers + self.listeners.update(dict(listeners)) + + if self.registered: + self.register_futures( + self.apps, + self, + chain( + registered, + self._future_middleware, + self._future_exceptions, + self._future_listeners, + self._future_signals, + ), + ) async def dispatch(self, *args, **kwargs): condition = kwargs.pop("condition", {}) @@ -402,3 +454,10 @@ class Blueprint(BaseSanic): value = v break return value + + @staticmethod + def register_futures( + apps: Set[Sanic], bp: Blueprint, futures: Sequence[Tuple[Any, ...]] + ): + for app in apps: + app._future_registry.update(set((bp, item) for item in futures)) diff --git a/sanic/config.py b/sanic/config.py index 3961d91d..e08b3f60 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from inspect import isclass from os import environ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from warnings import warn from sanic.errorpages import check_error_format @@ -9,6 +11,10 @@ from sanic.http import Http from sanic.utils import load_module_from_file_location, str_to_bool +if TYPE_CHECKING: # no cov + from sanic import Sanic + + SANIC_PREFIX = "SANIC_" @@ -75,10 +81,13 @@ class Config(dict): load_env: Optional[Union[bool, str]] = True, env_prefix: Optional[str] = SANIC_PREFIX, keep_alive: Optional[bool] = None, + *, + app: Optional[Sanic] = None, ): defaults = defaults or {} super().__init__({**DEFAULT_CONFIG, **defaults}) + self._app = app self._LOGO = "" if keep_alive is not None: @@ -101,6 +110,7 @@ class Config(dict): self._configure_header_size() self._check_error_format() + self._init = True def __getattr__(self, attr): try: @@ -108,23 +118,47 @@ class Config(dict): except KeyError as ke: raise AttributeError(f"Config has no '{ke.args[0]}'") - def __setattr__(self, attr, value): - self[attr] = value - if attr in ( - "REQUEST_MAX_HEADER_SIZE", - "REQUEST_BUFFER_SIZE", - "REQUEST_MAX_SIZE", - ): - self._configure_header_size() - elif attr == "FALLBACK_ERROR_FORMAT": - self._check_error_format() - elif attr == "LOGO": - self._LOGO = value - warn( - "Setting the config.LOGO is deprecated and will no longer " - "be supported starting in v22.6.", - DeprecationWarning, - ) + def __setattr__(self, attr, value) -> None: + self.update({attr: value}) + + def __setitem__(self, attr, value) -> None: + self.update({attr: value}) + + def update(self, *other, **kwargs) -> None: + other_mapping = {k: v for item in other for k, v in dict(item).items()} + super().update(*other, **kwargs) + for attr, value in {**other_mapping, **kwargs}.items(): + self._post_set(attr, value) + + def _post_set(self, attr, value) -> None: + if self.get("_init"): + if attr in ( + "REQUEST_MAX_HEADER_SIZE", + "REQUEST_BUFFER_SIZE", + "REQUEST_MAX_SIZE", + ): + self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() + if self.app and value != self.app.error_handler.fallback: + if self.app.error_handler.fallback != "auto": + warn( + "Overriding non-default ErrorHandler fallback " + "value. Changing from " + f"{self.app.error_handler.fallback} to {value}." + ) + self.app.error_handler.fallback = value + elif attr == "LOGO": + self._LOGO = value + warn( + "Setting the config.LOGO is deprecated and will no longer " + "be supported starting in v22.6.", + DeprecationWarning, + ) + + @property + def app(self): + return self._app @property def LOGO(self): diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 82cdd57a..66ff6c95 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -25,12 +25,13 @@ from sanic.request import Request from sanic.response import HTTPResponse, html, json, text +dumps: t.Callable[..., str] try: from ujson import dumps dumps = partial(dumps, escape_forward_slashes=False) except ImportError: # noqa - from json import dumps # type: ignore + from json import dumps FALLBACK_TEXT = ( @@ -45,6 +46,8 @@ class BaseRenderer: Base class that all renderers must inherit from. """ + dumps = staticmethod(dumps) + def __init__(self, request, exception, debug): self.request = request self.exception = exception @@ -112,14 +115,16 @@ class HTMLRenderer(BaseRenderer): TRACEBACK_STYLE = """ html { font-family: sans-serif } h2 { color: #888; } - .tb-wrapper p { margin: 0 } + .tb-wrapper p, dl, dd { margin: 0 } .frame-border { margin: 1rem } - .frame-line > * { padding: 0.3rem 0.6rem } - .frame-line { margin-bottom: 0.3rem } - .frame-code { font-size: 16px; padding-left: 4ch } - .tb-wrapper { border: 1px solid #eee } - .tb-header { background: #eee; padding: 0.3rem; font-weight: bold } - .frame-descriptor { background: #e2eafb; font-size: 14px } + .frame-line > *, dt, dd { padding: 0.3rem 0.6rem } + .frame-line, dl { margin-bottom: 0.3rem } + .frame-code, dd { font-size: 16px; padding-left: 4ch } + .tb-wrapper, dl { border: 1px solid #eee } + .tb-header,.obj-header { + background: #eee; padding: 0.3rem; font-weight: bold + } + .frame-descriptor, dt { background: #e2eafb; font-size: 14px } """ TRACEBACK_WRAPPER_HTML = ( "
{exc_name}: {exc_value}
" @@ -138,6 +143,11 @@ class HTMLRenderer(BaseRenderer): "

{0.line}" "" ) + OBJECT_WRAPPER_HTML = ( + "

{title}
" + "
{display_html}
" + ) + OBJECT_DISPLAY_HTML = "
{key}
{value}
" OUTPUT_HTML = ( "" "{title}\n" @@ -152,7 +162,7 @@ class HTMLRenderer(BaseRenderer): title=self.title, text=self.text, style=self.TRACEBACK_STYLE, - body=self._generate_body(), + body=self._generate_body(full=True), ), status=self.status, ) @@ -163,7 +173,7 @@ class HTMLRenderer(BaseRenderer): title=self.title, text=self.text, style=self.TRACEBACK_STYLE, - body="", + body=self._generate_body(full=False), ), status=self.status, headers=self.headers, @@ -177,27 +187,49 @@ class HTMLRenderer(BaseRenderer): def title(self): return escape(f"⚠️ {super().title}") - def _generate_body(self): - _, exc_value, __ = sys.exc_info() - exceptions = [] - while exc_value: - exceptions.append(self._format_exc(exc_value)) - exc_value = exc_value.__cause__ + def _generate_body(self, *, full): + lines = [] + if full: + _, exc_value, __ = sys.exc_info() + exceptions = [] + while exc_value: + exceptions.append(self._format_exc(exc_value)) + exc_value = exc_value.__cause__ + + traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions)) + appname = escape(self.request.app.name) + name = escape(self.exception.__class__.__name__) + value = escape(self.exception) + path = escape(self.request.path) + lines += [ + f"

Traceback of {appname} " "(most recent call last):

", + f"{traceback_html}", + "

", + f"{name}: {value} " + f"while handling path {path}", + "

", + ] + + for attr, display in (("context", True), ("extra", bool(full))): + info = getattr(self.exception, attr, None) + if info and display: + lines.append(self._generate_object_display(info, attr)) - traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions)) - appname = escape(self.request.app.name) - name = escape(self.exception.__class__.__name__) - value = escape(self.exception) - path = escape(self.request.path) - lines = [ - f"

Traceback of {appname} (most recent call last):

", - f"{traceback_html}", - "

", - f"{name}: {value} while handling path {path}", - "

", - ] return "\n".join(lines) + def _generate_object_display( + self, obj: t.Dict[str, t.Any], descriptor: str + ) -> str: + display = "".join( + self.OBJECT_DISPLAY_HTML.format(key=key, value=value) + for key, value in obj.items() + ) + return self.OBJECT_WRAPPER_HTML.format( + title=descriptor.title(), + display_html=display, + obj_type=descriptor.lower(), + ) + def _format_exc(self, exc): frames = extract_tb(exc.__traceback__) frame_html = "".join( @@ -224,7 +256,7 @@ class TextRenderer(BaseRenderer): title=self.title, text=self.text, bar=("=" * len(self.title)), - body=self._generate_body(), + body=self._generate_body(full=True), ), status=self.status, ) @@ -235,7 +267,7 @@ class TextRenderer(BaseRenderer): title=self.title, text=self.text, bar=("=" * len(self.title)), - body="", + body=self._generate_body(full=False), ), status=self.status, headers=self.headers, @@ -245,21 +277,31 @@ class TextRenderer(BaseRenderer): def title(self): return f"⚠️ {super().title}" - def _generate_body(self): - _, exc_value, __ = sys.exc_info() - exceptions = [] + def _generate_body(self, *, full): + lines = [] + if full: + _, exc_value, __ = sys.exc_info() + exceptions = [] - lines = [ - f"{self.exception.__class__.__name__}: {self.exception} while " - f"handling path {self.request.path}", - f"Traceback of {self.request.app.name} (most recent call last):\n", - ] + lines += [ + f"{self.exception.__class__.__name__}: {self.exception} while " + f"handling path {self.request.path}", + f"Traceback of {self.request.app.name} " + "(most recent call last):\n", + ] - while exc_value: - exceptions.append(self._format_exc(exc_value)) - exc_value = exc_value.__cause__ + while exc_value: + exceptions.append(self._format_exc(exc_value)) + exc_value = exc_value.__cause__ - return "\n".join(lines + exceptions[::-1]) + lines += exceptions[::-1] + + for attr, display in (("context", True), ("extra", bool(full))): + info = getattr(self.exception, attr, None) + if info and display: + lines += self._generate_object_display_list(info, attr) + + return "\n".join(lines) def _format_exc(self, exc): frames = "\n\n".join( @@ -272,6 +314,13 @@ class TextRenderer(BaseRenderer): ) return f"{self.SPACER}{exc.__class__.__name__}: {exc}\n{frames}" + def _generate_object_display_list(self, obj, descriptor): + lines = [f"\n{descriptor.title()}"] + for key, value in obj.items(): + display = self.dumps(value) + lines.append(f"{self.SPACER * 2}{key}: {display}") + return lines + class JSONRenderer(BaseRenderer): """ @@ -280,11 +329,11 @@ class JSONRenderer(BaseRenderer): def full(self) -> HTTPResponse: output = self._generate_output(full=True) - return json(output, status=self.status, dumps=dumps) + return json(output, status=self.status, dumps=self.dumps) def minimal(self) -> HTTPResponse: output = self._generate_output(full=False) - return json(output, status=self.status, dumps=dumps) + return json(output, status=self.status, dumps=self.dumps) def _generate_output(self, *, full): output = { @@ -293,6 +342,11 @@ class JSONRenderer(BaseRenderer): "message": self.text, } + for attr, display in (("context", True), ("extra", bool(full))): + info = getattr(self.exception, attr, None) + if info and display: + output[attr] = info + if full: _, exc_value, __ = sys.exc_info() exceptions = [] @@ -393,7 +447,8 @@ def exception_response( # from the route if request.route: try: - render_format = request.route.ctx.error_format + if request.route.ctx.error_format: + render_format = request.route.ctx.error_format except AttributeError: ... diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 1bb06f1d..6459f15a 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Any, Dict, Optional, Union from sanic.helpers import STATUS_CODES @@ -11,7 +11,11 @@ class SanicException(Exception): message: Optional[Union[str, bytes]] = None, status_code: Optional[int] = None, quiet: Optional[bool] = None, + context: Optional[Dict[str, Any]] = None, + extra: Optional[Dict[str, Any]] = None, ) -> None: + self.context = context + self.extra = extra if message is None: if self.message: message = self.message diff --git a/sanic/handlers.py b/sanic/handlers.py index af667c9a..046e56e1 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -38,7 +38,14 @@ class ErrorHandler: self.base = base @classmethod - def finalize(cls, error_handler): + def finalize(cls, error_handler, fallback: Optional[str] = None): + if ( + fallback + and fallback != "auto" + and error_handler.fallback == "auto" + ): + error_handler.fallback = fallback + if not isinstance(error_handler, cls): error_logger.warning( f"Error handler is non-conforming: {type(error_handler)}" diff --git a/sanic/http.py b/sanic/http.py index d30e4c82..6f59ef25 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta): self.keep_alive = True self.stage: Stage = Stage.IDLE self.dispatch = self.protocol.app.dispatch - self.init_for_request() def init_for_request(self): """Init/reset all per-request variables.""" @@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta): """ HTTP 1.1 connection handler """ - while True: # As long as connection stays keep-alive + # Handle requests while the connection stays reusable + while self.keep_alive and self.stage is Stage.IDLE: + self.init_for_request() + # Wait for incoming bytes (in IDLE stage) + if not self.recv_buffer: + await self._receive_more() + self.stage = Stage.REQUEST try: # Receive and handle a request - self.stage = Stage.REQUEST self.response_func = self.http1_response_header await self.http1_request_header() + self.stage = Stage.HANDLER self.request.conn_info = self.protocol.conn_info await self.protocol.request_handler(self.request) @@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta): if self.response: self.response.stream = None - # Exit and disconnect if no more requests can be taken - if self.stage is not Stage.IDLE or not self.keep_alive: - break - - self.init_for_request() - - # Wait for the next request - if not self.recv_buffer: - await self._receive_more() - async def http1_request_header(self): # no cov """ Receive and parse request header into self.request. @@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta): # Remove header and its trailing CRLF del buf[: pos + 4] - self.stage = Stage.HANDLER self.request, request.stream = request, self self.protocol.state["requests_count"] += 1 diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 8467a2e3..7139cd3c 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -918,7 +918,7 @@ class RouteMixin: return route - def _determine_error_format(self, handler) -> str: + def _determine_error_format(self, handler) -> Optional[str]: if not isinstance(handler, CompositionView): try: src = dedent(getsource(handler)) @@ -930,7 +930,7 @@ class RouteMixin: except (OSError, TypeError): ... - return "auto" + return None def _get_response_types(self, node): types = set() diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index 2be9fee2..57b01b46 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Optional, Set +from enum import Enum +from typing import Any, Callable, Dict, Optional, Set, Union from sanic.models.futures import FutureSignal from sanic.models.handler_types import SignalHandler @@ -19,7 +20,7 @@ class SignalMixin: def signal( self, - event: str, + event: Union[str, Enum], *, apply: bool = True, condition: Dict[str, Any] = None, @@ -41,13 +42,11 @@ class SignalMixin: filtering, defaults to None :type condition: Dict[str, Any], optional """ + event_value = str(event.value) if isinstance(event, Enum) else event def decorator(handler: SignalHandler): - nonlocal event - nonlocal apply - future_signal = FutureSignal( - handler, event, HashableDict(condition or {}) + handler, event_value, HashableDict(condition or {}) ) self._future_signals.add(future_signal) diff --git a/sanic/models/futures.py b/sanic/models/futures.py index fe7d77eb..74ee92b9 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -60,3 +60,7 @@ class FutureSignal(NamedTuple): handler: SignalHandler event: str condition: Optional[Dict[str, str]] + + +class FutureRegistry(set): + ... diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index 3a91a8f0..3c726edb 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -47,16 +47,18 @@ def _get_args_for_reloading(): return [sys.executable] + sys.argv -def restart_with_reloader(): +def restart_with_reloader(changed=None): """Create a new process and a subprocess in it with the same arguments as this one. """ + reloaded = ",".join(changed) if changed else "" return subprocess.Popen( _get_args_for_reloading(), env={ **os.environ, "SANIC_SERVER_RUNNING": "true", "SANIC_RELOADER_PROCESS": "true", + "SANIC_RELOADED_FILES": reloaded, }, ) @@ -94,24 +96,27 @@ def watchdog(sleep_interval, app): try: while True: - need_reload = False + changed = set() for filename in itertools.chain( _iter_module_files(), *(d.glob("**/*") for d in app.reload_dirs), ): try: - check = _check_file(filename, mtimes) + if _check_file(filename, mtimes): + path = ( + filename + if isinstance(filename, str) + else filename.resolve() + ) + changed.add(str(path)) except OSError: continue - if check: - need_reload = True - - if need_reload: + if changed: worker_process.terminate() worker_process.wait() - worker_process = restart_with_reloader() + worker_process = restart_with_reloader(changed) sleep(sleep_interval) except KeyboardInterrupt: diff --git a/sanic/router.py b/sanic/router.py index 6995ed6d..b15c2a3e 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -139,11 +139,10 @@ class Router(BaseRouter): route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static - route.ctx.error_format = ( - error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT - ) + route.ctx.error_format = error_format - check_error_format(route.ctx.error_format) + if error_format: + check_error_format(route.ctx.error_format) routes.append(route) diff --git a/sanic/signals.py b/sanic/signals.py index 9da7eccd..7bb510fa 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from enum import Enum from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union @@ -14,29 +15,47 @@ from sanic.log import error_logger, logger from sanic.models.handler_types import SignalHandler +class Event(Enum): + SERVER_INIT_AFTER = "server.init.after" + SERVER_INIT_BEFORE = "server.init.before" + SERVER_SHUTDOWN_AFTER = "server.shutdown.after" + SERVER_SHUTDOWN_BEFORE = "server.shutdown.before" + HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin" + HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete" + HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception" + HTTP_LIFECYCLE_HANDLE = "http.lifecycle.handle" + HTTP_LIFECYCLE_READ_BODY = "http.lifecycle.read_body" + HTTP_LIFECYCLE_READ_HEAD = "http.lifecycle.read_head" + HTTP_LIFECYCLE_REQUEST = "http.lifecycle.request" + HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response" + HTTP_ROUTING_AFTER = "http.routing.after" + HTTP_ROUTING_BEFORE = "http.routing.before" + HTTP_LIFECYCLE_SEND = "http.lifecycle.send" + HTTP_MIDDLEWARE_AFTER = "http.middleware.after" + HTTP_MIDDLEWARE_BEFORE = "http.middleware.before" + + RESERVED_NAMESPACES = { "server": ( - # "server.main.start", - # "server.main.stop", - "server.init.before", - "server.init.after", - "server.shutdown.before", - "server.shutdown.after", + Event.SERVER_INIT_AFTER.value, + Event.SERVER_INIT_BEFORE.value, + Event.SERVER_SHUTDOWN_AFTER.value, + Event.SERVER_SHUTDOWN_BEFORE.value, ), "http": ( - "http.lifecycle.begin", - "http.lifecycle.complete", - "http.lifecycle.exception", - "http.lifecycle.handle", - "http.lifecycle.read_body", - "http.lifecycle.read_head", - "http.lifecycle.request", - "http.lifecycle.response", - "http.routing.after", - "http.routing.before", - "http.lifecycle.send", - "http.middleware.after", - "http.middleware.before", + Event.HTTP_LIFECYCLE_BEGIN.value, + Event.HTTP_LIFECYCLE_COMPLETE.value, + Event.HTTP_LIFECYCLE_EXCEPTION.value, + Event.HTTP_LIFECYCLE_HANDLE.value, + Event.HTTP_LIFECYCLE_READ_BODY.value, + Event.HTTP_LIFECYCLE_READ_HEAD.value, + Event.HTTP_LIFECYCLE_REQUEST.value, + Event.HTTP_LIFECYCLE_RESPONSE.value, + Event.HTTP_ROUTING_AFTER.value, + Event.HTTP_ROUTING_BEFORE.value, + Event.HTTP_LIFECYCLE_SEND.value, + Event.HTTP_MIDDLEWARE_AFTER.value, + Event.HTTP_MIDDLEWARE_BEFORE.value, ), } diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py index 033e2e20..ca8cd67e 100644 --- a/tests/test_blueprint_copy.py +++ b/tests/test_blueprint_copy.py @@ -1,6 +1,4 @@ -from copy import deepcopy - -from sanic import Blueprint, Sanic, blueprints, response +from sanic import Blueprint, Sanic from sanic.response import text diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index b6a23151..3aa4487a 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -1088,3 +1088,31 @@ def test_bp_set_attribute_warning(): "and will be removed in version 21.12. You should change your " "Blueprint instance to use instance.ctx.foo instead." ) + + +def test_early_registration(app): + assert len(app.router.routes) == 0 + + bp = Blueprint("bp") + + @bp.get("/one") + async def one(_): + return text("one") + + app.blueprint(bp) + + assert len(app.router.routes) == 1 + + @bp.get("/two") + async def two(_): + return text("two") + + @bp.get("/three") + async def three(_): + return text("three") + + assert len(app.router.routes) == 3 + + for path in ("one", "two", "three"): + _, response = app.test_client.get(f"/{path}") + assert response.text == path diff --git a/tests/test_coffee.py b/tests/test_coffee.py new file mode 100644 index 00000000..6143f17f --- /dev/null +++ b/tests/test_coffee.py @@ -0,0 +1,48 @@ +import logging + +from unittest.mock import patch + +import pytest + +from sanic.application.logo import COFFEE_LOGO, get_logo +from sanic.exceptions import SanicException + + +def has_sugar(value): + if value: + raise SanicException("I said no sugar please") + + return False + + +@pytest.mark.parametrize("sugar", (True, False)) +def test_no_sugar(sugar): + if sugar: + with pytest.raises(SanicException): + assert has_sugar(sugar) + else: + assert not has_sugar(sugar) + + +def test_get_logo_returns_expected_logo(): + with patch("sys.stdout.isatty") as isatty: + isatty.return_value = True + logo = get_logo(coffee=True) + assert logo is COFFEE_LOGO + + +def test_logo_true(app, caplog): + @app.after_server_start + async def shutdown(*_): + app.stop() + + with patch("sys.stdout.isatty") as isatty: + isatty.return_value = True + with caplog.at_level(logging.DEBUG): + app.make_coffee() + + # Only in the regular logo + assert " ▄███ █████ ██ " not in caplog.text + + # Only in the coffee logo + assert " ██ ██▀▀▄ " in caplog.text diff --git a/tests/test_config.py b/tests/test_config.py index f3447666..67324f1e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,9 @@ from contextlib import contextmanager -from email import message from os import environ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent +from unittest.mock import Mock import pytest @@ -360,3 +360,31 @@ def test_deprecation_notice_when_setting_logo(app): ) with pytest.warns(DeprecationWarning, match=message): app.config.LOGO = "My Custom Logo" + + +def test_config_set_methods(app, monkeypatch): + post_set = Mock() + monkeypatch.setattr(Config, "_post_set", post_set) + + app.config.FOO = 1 + post_set.assert_called_once_with("FOO", 1) + post_set.reset_mock() + + app.config["FOO"] = 2 + post_set.assert_called_once_with("FOO", 2) + post_set.reset_mock() + + app.config.update({"FOO": 3}) + post_set.assert_called_once_with("FOO", 3) + post_set.reset_mock() + + app.config.update([("FOO", 4)]) + post_set.assert_called_once_with("FOO", 4) + post_set.reset_mock() + + app.config.update(FOO=5) + post_set.assert_called_once_with("FOO", 5) + post_set.reset_mock() + + app.config.update_config({"FOO": 6}) + post_set.assert_called_once_with("FOO", 6) diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 5af4ca5f..1843f6a7 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,8 +1,10 @@ import pytest from sanic import Sanic +from sanic.config import Config from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException +from sanic.handlers import ErrorHandler from sanic.request import Request from sanic.response import HTTPResponse, html, json, text @@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected): ) assert response.content_type == expected + + +def test_allow_fallback_error_format_set_main_process_start(app): + @app.main_process_start + async def start(app, _): + app.config.FALLBACK_ERROR_FORMAT = "text" + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_setting_fallback_to_non_default_raise_warning(app): + app.error_handler = ErrorHandler(fallback="text") + + assert app.error_handler.fallback == "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to auto." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + assert app.error_handler.fallback == "auto" + + app.config.FALLBACK_ERROR_FORMAT = "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to json." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "json" + + assert app.error_handler.fallback == "json" + + +def test_allow_fallback_error_format_in_config_injection(): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app = Sanic("test", config=MyConfig()) + + @app.route("/error", methods=["GET", "POST"]) + def err(request): + raise Exception("something went wrong") + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_allow_fallback_error_format_in_config_replacement(app): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app.config = MyConfig() + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 0485137a..eea97935 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -18,6 +18,16 @@ from sanic.exceptions import ( from sanic.response import text +def dl_to_dict(soup, css_class): + keys, values = [], [] + for dl in soup.find_all("dl", {"class": css_class}): + for dt in dl.find_all("dt"): + keys.append(dt.text.strip()) + for dd in dl.find_all("dd"): + values.append(dd.text.strip()) + return dict(zip(keys, values)) + + class SanicExceptionTestException(Exception): pass @@ -264,3 +274,110 @@ def test_exception_in_ws_logged(caplog): error_logs = [r for r in caplog.record_tuples if r[0] == "sanic.error"] assert error_logs[1][1] == logging.ERROR assert "Exception occurred while handling uri:" in error_logs[1][2] + + +@pytest.mark.parametrize("debug", (True, False)) +def test_contextual_exception_context(debug): + app = Sanic(__name__) + + class TeapotError(SanicException): + status_code = 418 + message = "Sorry, I cannot brew coffee" + + def fail(): + raise TeapotError(context={"foo": "bar"}) + + app.post("/coffee/json", error_format="json")(lambda _: fail()) + app.post("/coffee/html", error_format="html")(lambda _: fail()) + app.post("/coffee/text", error_format="text")(lambda _: fail()) + + _, response = app.test_client.post("/coffee/json", debug=debug) + assert response.status == 418 + assert response.json["message"] == "Sorry, I cannot brew coffee" + assert response.json["context"] == {"foo": "bar"} + + _, response = app.test_client.post("/coffee/html", debug=debug) + soup = BeautifulSoup(response.body, "html.parser") + dl = dl_to_dict(soup, "context") + assert response.status == 418 + assert "Sorry, I cannot brew coffee" in soup.find("p").text + assert dl == {"foo": "bar"} + + _, response = app.test_client.post("/coffee/text", debug=debug) + lines = list(map(lambda x: x.decode(), response.body.split(b"\n"))) + idx = lines.index("Context") + 1 + assert response.status == 418 + assert lines[2] == "Sorry, I cannot brew coffee" + assert lines[idx] == ' foo: "bar"' + + +@pytest.mark.parametrize("debug", (True, False)) +def test_contextual_exception_extra(debug): + app = Sanic(__name__) + + class TeapotError(SanicException): + status_code = 418 + + @property + def message(self): + return f"Found {self.extra['foo']}" + + def fail(): + raise TeapotError(extra={"foo": "bar"}) + + app.post("/coffee/json", error_format="json")(lambda _: fail()) + app.post("/coffee/html", error_format="html")(lambda _: fail()) + app.post("/coffee/text", error_format="text")(lambda _: fail()) + + _, response = app.test_client.post("/coffee/json", debug=debug) + assert response.status == 418 + assert response.json["message"] == "Found bar" + if debug: + assert response.json["extra"] == {"foo": "bar"} + else: + assert "extra" not in response.json + + _, response = app.test_client.post("/coffee/html", debug=debug) + soup = BeautifulSoup(response.body, "html.parser") + dl = dl_to_dict(soup, "extra") + assert response.status == 418 + assert "Found bar" in soup.find("p").text + if debug: + assert dl == {"foo": "bar"} + else: + assert not dl + + _, response = app.test_client.post("/coffee/text", debug=debug) + lines = list(map(lambda x: x.decode(), response.body.split(b"\n"))) + assert response.status == 418 + assert lines[2] == "Found bar" + if debug: + idx = lines.index("Extra") + 1 + assert lines[idx] == ' foo: "bar"' + else: + assert "Extra" not in lines + + +@pytest.mark.parametrize("override", (True, False)) +def test_contextual_exception_functional_message(override): + app = Sanic(__name__) + + class TeapotError(SanicException): + status_code = 418 + + @property + def message(self): + return f"Received foo={self.context['foo']}" + + @app.post("/coffee", error_format="json") + async def make_coffee(_): + error_args = {"context": {"foo": "bar"}} + if override: + error_args["message"] = "override" + raise TeapotError(**error_args) + + _, response = app.test_client.post("/coffee", debug=True) + error_message = "override" if override else "Received foo=bar" + assert response.status == 418 + assert response.json["message"] == error_message + assert response.json["context"] == {"foo": "bar"} diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py deleted file mode 100644 index 48e23f1d..00000000 --- a/tests/test_request_timeout.py +++ /dev/null @@ -1,109 +0,0 @@ -import asyncio - -import httpcore -import httpx -import pytest - -from sanic_testing.testing import SanicTestClient - -from sanic import Sanic -from sanic.response import text - - -class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection): - async def arequest(self, *args, **kwargs): - await asyncio.sleep(2) - return await super().arequest(*args, **kwargs) - - async def _open_socket(self, *args, **kwargs): - retval = await super()._open_socket(*args, **kwargs) - if self._request_delay: - await asyncio.sleep(self._request_delay) - return retval - - -class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool): - def __init__(self, request_delay=None, *args, **kwargs): - self._request_delay = request_delay - super().__init__(*args, **kwargs) - - async def _add_to_pool(self, connection, timeout): - connection.__class__ = DelayableHTTPConnection - connection._request_delay = self._request_delay - await super()._add_to_pool(connection, timeout) - - -class DelayableSanicSession(httpx.AsyncClient): - def __init__(self, request_delay=None, *args, **kwargs) -> None: - transport = DelayableSanicConnectionPool(request_delay=request_delay) - super().__init__(transport=transport, *args, **kwargs) - - -class DelayableSanicTestClient(SanicTestClient): - def __init__(self, app, request_delay=None): - super().__init__(app) - self._request_delay = request_delay - self._loop = None - - def get_new_session(self): - return DelayableSanicSession(request_delay=self._request_delay) - - -@pytest.fixture -def request_no_timeout_app(): - app = Sanic("test_request_no_timeout") - app.config.REQUEST_TIMEOUT = 0.6 - - @app.route("/1") - async def handler2(request): - return text("OK") - - return app - - -@pytest.fixture -def request_timeout_default_app(): - app = Sanic("test_request_timeout_default") - app.config.REQUEST_TIMEOUT = 0.6 - - @app.route("/1") - async def handler1(request): - return text("OK") - - @app.websocket("/ws1") - async def ws_handler1(request, ws): - await ws.send("OK") - - return app - - -def test_default_server_error_request_timeout(request_timeout_default_app): - client = DelayableSanicTestClient(request_timeout_default_app, 2) - _, response = client.get("/1") - assert response.status == 408 - assert "Request Timeout" in response.text - - -def test_default_server_error_request_dont_timeout(request_no_timeout_app): - client = DelayableSanicTestClient(request_no_timeout_app, 0.2) - _, response = client.get("/1") - assert response.status == 200 - assert response.text == "OK" - - -def test_default_server_error_websocket_request_timeout( - request_timeout_default_app, -): - - headers = { - "Upgrade": "websocket", - "Connection": "upgrade", - "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version": "13", - } - - client = DelayableSanicTestClient(request_timeout_default_app, 2) - _, response = client.get("/ws1", headers=headers) - - assert response.status == 408 - assert "Request Timeout" in response.text diff --git a/tests/test_signals.py b/tests/test_signals.py index 9b8a9495..51aea3c8 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,5 +1,6 @@ import asyncio +from enum import Enum from inspect import isawaitable import pytest @@ -50,6 +51,25 @@ def test_invalid_signal(app, signal): ... +@pytest.mark.asyncio +async def test_dispatch_signal_with_enum_event(app): + counter = 0 + + class FooEnum(Enum): + FOO_BAR_BAZ = "foo.bar.baz" + + @app.signal(FooEnum.FOO_BAR_BAZ) + def sync_signal(*_): + nonlocal counter + + counter += 1 + + app.signal_router.finalize() + + await app.dispatch("foo.bar.baz") + assert counter == 1 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_multiple_handlers(app): counter = 0 diff --git a/tests/test_timeout_logic.py b/tests/test_timeout_logic.py index 05249f11..497deda9 100644 --- a/tests/test_timeout_logic.py +++ b/tests/test_timeout_logic.py @@ -26,6 +26,7 @@ def protocol(app, mock_transport): protocol = HttpProtocol(loop=loop, app=app) protocol.connection_made(mock_transport) protocol._setup_connection() + protocol._http.init_for_request() protocol._task = Mock(spec=asyncio.Task) protocol._task.cancel = Mock() return protocol