diff --git a/sanic/app.py b/sanic/app.py
index 5ad99982..88ae1c6e 100644
--- a/sanic/app.py
+++ b/sanic/app.py
@@ -72,6 +72,7 @@ from sanic.models.futures import (
FutureException,
FutureListener,
FutureMiddleware,
+ FutureRegistry,
FutureRoute,
FutureSignal,
FutureStatic,
@@ -115,6 +116,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"_future_exceptions",
"_future_listeners",
"_future_middleware",
+ "_future_registry",
"_future_routes",
"_future_signals",
"_future_statics",
@@ -187,17 +189,18 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self._test_manager: Any = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
+ self._future_registry: FutureRegistry = FutureRegistry()
self._state: ApplicationState = ApplicationState(app=self)
self.blueprints: Dict[str, Blueprint] = {}
self.config: Config = config or Config(
- load_env=load_env, env_prefix=env_prefix
+ load_env=load_env,
+ env_prefix=env_prefix,
+ app=self,
)
self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace()
self.debug = False
- self.error_handler: ErrorHandler = error_handler or ErrorHandler(
- fallback=self.config.FALLBACK_ERROR_FORMAT,
- )
+ self.error_handler: ErrorHandler = error_handler or ErrorHandler()
self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
@@ -957,6 +960,10 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
# Execution
# -------------------------------------------------------------------- #
+ def make_coffee(self, *args, **kwargs):
+ self.state.coffee = True
+ self.run(*args, **kwargs)
+
def run(
self,
host: Optional[str] = None,
@@ -1569,7 +1576,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
extra.update(self.config.MOTD_DISPLAY)
logo = (
- get_logo()
+ get_logo(coffee=self.state.coffee)
if self.config.LOGO == "" or self.config.LOGO is True
else self.config.LOGO
)
@@ -1635,9 +1642,12 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
raise e
async def _startup(self):
+ self._future_registry.clear()
self.signalize()
self.finalize()
- ErrorHandler.finalize(self.error_handler)
+ ErrorHandler.finalize(
+ self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT
+ )
TouchUp.run(self)
async def _server_event(
diff --git a/sanic/application/logo.py b/sanic/application/logo.py
index 9e3bb2fa..56b8c0b1 100644
--- a/sanic/application/logo.py
+++ b/sanic/application/logo.py
@@ -10,6 +10,15 @@ BASE_LOGO = """
Build Fast. Run Fast.
"""
+COFFEE_LOGO = """\033[48;2;255;13;104m \033[0m
+\033[38;2;255;255;255;48;2;255;13;104m ▄████████▄ \033[0m
+\033[38;2;255;255;255;48;2;255;13;104m ██ ██▀▀▄ \033[0m
+\033[38;2;255;255;255;48;2;255;13;104m ███████████ █ \033[0m
+\033[38;2;255;255;255;48;2;255;13;104m ███████████▄▄▀ \033[0m
+\033[38;2;255;255;255;48;2;255;13;104m ▀███████▀ \033[0m
+\033[48;2;255;13;104m \033[0m
+Dark roast. No sugar."""
+
COLOR_LOGO = """\033[48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m ▄███ █████ ██ \033[0m
\033[38;2;255;255;255;48;2;255;13;104m ██ \033[0m
@@ -32,9 +41,9 @@ FULL_COLOR_LOGO = """
ansi_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
-def get_logo(full=False):
+def get_logo(full=False, coffee=False):
logo = (
- (FULL_COLOR_LOGO if full else COLOR_LOGO)
+ (FULL_COLOR_LOGO if full else (COFFEE_LOGO if coffee else COLOR_LOGO))
if sys.stdout.isatty()
else BASE_LOGO
)
diff --git a/sanic/application/state.py b/sanic/application/state.py
index b03c30da..eb180708 100644
--- a/sanic/application/state.py
+++ b/sanic/application/state.py
@@ -34,6 +34,7 @@ class Mode(StrEnum):
class ApplicationState:
app: Sanic
asgi: bool = field(default=False)
+ coffee: bool = field(default=False)
fast: bool = field(default=False)
host: str = field(default="")
mode: Mode = field(default=Mode.PRODUCTION)
diff --git a/sanic/blueprints.py b/sanic/blueprints.py
index e5e1d333..290773fa 100644
--- a/sanic/blueprints.py
+++ b/sanic/blueprints.py
@@ -4,6 +4,9 @@ import asyncio
from collections import defaultdict
from copy import deepcopy
+from functools import wraps
+from inspect import isfunction
+from itertools import chain
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
@@ -12,7 +15,9 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
+ Tuple,
Union,
)
@@ -35,6 +40,32 @@ if TYPE_CHECKING:
from sanic import Sanic # noqa
+def lazy(func, as_decorator=True):
+ @wraps(func)
+ def decorator(bp, *args, **kwargs):
+ nonlocal as_decorator
+ kwargs["apply"] = False
+ pass_handler = None
+
+ if args and isfunction(args[0]):
+ as_decorator = False
+
+ def wrapper(handler):
+ future = func(bp, *args, **kwargs)
+ if as_decorator:
+ future = future(handler)
+
+ if bp.registered:
+ for app in bp.apps:
+ bp.register(app, {})
+
+ return future
+
+ return wrapper if as_decorator else wrapper(pass_handler)
+
+ return decorator
+
+
class Blueprint(BaseSanic):
"""
In *Sanic* terminology, a **Blueprint** is a logical collection of
@@ -124,29 +155,16 @@ class Blueprint(BaseSanic):
)
return self._apps
- def route(self, *args, **kwargs):
- kwargs["apply"] = False
- return super().route(*args, **kwargs)
+ @property
+ def registered(self) -> bool:
+ return bool(self._apps)
- def static(self, *args, **kwargs):
- kwargs["apply"] = False
- return super().static(*args, **kwargs)
-
- def middleware(self, *args, **kwargs):
- kwargs["apply"] = False
- return super().middleware(*args, **kwargs)
-
- def listener(self, *args, **kwargs):
- kwargs["apply"] = False
- return super().listener(*args, **kwargs)
-
- def exception(self, *args, **kwargs):
- kwargs["apply"] = False
- return super().exception(*args, **kwargs)
-
- def signal(self, event: str, *args, **kwargs):
- kwargs["apply"] = False
- return super().signal(event, *args, **kwargs)
+ exception = lazy(BaseSanic.exception)
+ listener = lazy(BaseSanic.listener)
+ middleware = lazy(BaseSanic.middleware)
+ route = lazy(BaseSanic.route)
+ signal = lazy(BaseSanic.signal)
+ static = lazy(BaseSanic.static, as_decorator=False)
def reset(self):
self._apps: Set[Sanic] = set()
@@ -283,6 +301,7 @@ class Blueprint(BaseSanic):
middleware = []
exception_handlers = []
listeners = defaultdict(list)
+ registered = set()
# Routes
for future in self._future_routes:
@@ -309,12 +328,15 @@ class Blueprint(BaseSanic):
)
name = app._generate_name(future.name)
+ host = future.host or self.host
+ if isinstance(host, list):
+ host = tuple(host)
apply_route = FutureRoute(
future.handler,
uri[1:] if uri.startswith("//") else uri,
future.methods,
- future.host or self.host,
+ host,
strict_slashes,
future.stream,
version,
@@ -328,6 +350,10 @@ class Blueprint(BaseSanic):
error_format,
)
+ if (self, apply_route) in app._future_registry:
+ continue
+
+ registered.add(apply_route)
route = app._apply_route(apply_route)
operation = (
routes.extend if isinstance(route, list) else routes.append
@@ -339,6 +365,11 @@ class Blueprint(BaseSanic):
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
apply_route = FutureStatic(uri, *future[1:])
+
+ if (self, apply_route) in app._future_registry:
+ continue
+
+ registered.add(apply_route)
route = app._apply_static(apply_route)
routes.append(route)
@@ -347,30 +378,51 @@ class Blueprint(BaseSanic):
if route_names:
# Middleware
for future in self._future_middleware:
+ if (self, future) in app._future_registry:
+ continue
middleware.append(app._apply_middleware(future, route_names))
# Exceptions
for future in self._future_exceptions:
+ if (self, future) in app._future_registry:
+ continue
exception_handlers.append(
app._apply_exception_handler(future, route_names)
)
# Event listeners
- for listener in self._future_listeners:
- listeners[listener.event].append(app._apply_listener(listener))
+ for future in self._future_listeners:
+ if (self, future) in app._future_registry:
+ continue
+ listeners[future.event].append(app._apply_listener(future))
# Signals
- for signal in self._future_signals:
- signal.condition.update({"blueprint": self.name})
- app._apply_signal(signal)
+ for future in self._future_signals:
+ if (self, future) in app._future_registry:
+ continue
+ future.condition.update({"blueprint": self.name})
+ app._apply_signal(future)
- self.routes = [route for route in routes if isinstance(route, Route)]
- self.websocket_routes = [
+ self.routes += [route for route in routes if isinstance(route, Route)]
+ self.websocket_routes += [
route for route in self.routes if route.ctx.websocket
]
- self.middlewares = middleware
- self.exceptions = exception_handlers
- self.listeners = dict(listeners)
+ self.middlewares += middleware
+ self.exceptions += exception_handlers
+ self.listeners.update(dict(listeners))
+
+ if self.registered:
+ self.register_futures(
+ self.apps,
+ self,
+ chain(
+ registered,
+ self._future_middleware,
+ self._future_exceptions,
+ self._future_listeners,
+ self._future_signals,
+ ),
+ )
async def dispatch(self, *args, **kwargs):
condition = kwargs.pop("condition", {})
@@ -402,3 +454,10 @@ class Blueprint(BaseSanic):
value = v
break
return value
+
+ @staticmethod
+ def register_futures(
+ apps: Set[Sanic], bp: Blueprint, futures: Sequence[Tuple[Any, ...]]
+ ):
+ for app in apps:
+ app._future_registry.update(set((bp, item) for item in futures))
diff --git a/sanic/config.py b/sanic/config.py
index 3961d91d..e08b3f60 100644
--- a/sanic/config.py
+++ b/sanic/config.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
from inspect import isclass
from os import environ
from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from warnings import warn
from sanic.errorpages import check_error_format
@@ -9,6 +11,10 @@ from sanic.http import Http
from sanic.utils import load_module_from_file_location, str_to_bool
+if TYPE_CHECKING: # no cov
+ from sanic import Sanic
+
+
SANIC_PREFIX = "SANIC_"
@@ -75,10 +81,13 @@ class Config(dict):
load_env: Optional[Union[bool, str]] = True,
env_prefix: Optional[str] = SANIC_PREFIX,
keep_alive: Optional[bool] = None,
+ *,
+ app: Optional[Sanic] = None,
):
defaults = defaults or {}
super().__init__({**DEFAULT_CONFIG, **defaults})
+ self._app = app
self._LOGO = ""
if keep_alive is not None:
@@ -101,6 +110,7 @@ class Config(dict):
self._configure_header_size()
self._check_error_format()
+ self._init = True
def __getattr__(self, attr):
try:
@@ -108,23 +118,47 @@ class Config(dict):
except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'")
- def __setattr__(self, attr, value):
- self[attr] = value
- if attr in (
- "REQUEST_MAX_HEADER_SIZE",
- "REQUEST_BUFFER_SIZE",
- "REQUEST_MAX_SIZE",
- ):
- self._configure_header_size()
- elif attr == "FALLBACK_ERROR_FORMAT":
- self._check_error_format()
- elif attr == "LOGO":
- self._LOGO = value
- warn(
- "Setting the config.LOGO is deprecated and will no longer "
- "be supported starting in v22.6.",
- DeprecationWarning,
- )
+ def __setattr__(self, attr, value) -> None:
+ self.update({attr: value})
+
+ def __setitem__(self, attr, value) -> None:
+ self.update({attr: value})
+
+ def update(self, *other, **kwargs) -> None:
+ other_mapping = {k: v for item in other for k, v in dict(item).items()}
+ super().update(*other, **kwargs)
+ for attr, value in {**other_mapping, **kwargs}.items():
+ self._post_set(attr, value)
+
+ def _post_set(self, attr, value) -> None:
+ if self.get("_init"):
+ if attr in (
+ "REQUEST_MAX_HEADER_SIZE",
+ "REQUEST_BUFFER_SIZE",
+ "REQUEST_MAX_SIZE",
+ ):
+ self._configure_header_size()
+ elif attr == "FALLBACK_ERROR_FORMAT":
+ self._check_error_format()
+ if self.app and value != self.app.error_handler.fallback:
+ if self.app.error_handler.fallback != "auto":
+ warn(
+ "Overriding non-default ErrorHandler fallback "
+ "value. Changing from "
+ f"{self.app.error_handler.fallback} to {value}."
+ )
+ self.app.error_handler.fallback = value
+ elif attr == "LOGO":
+ self._LOGO = value
+ warn(
+ "Setting the config.LOGO is deprecated and will no longer "
+ "be supported starting in v22.6.",
+ DeprecationWarning,
+ )
+
+ @property
+ def app(self):
+ return self._app
@property
def LOGO(self):
diff --git a/sanic/errorpages.py b/sanic/errorpages.py
index 82cdd57a..66ff6c95 100644
--- a/sanic/errorpages.py
+++ b/sanic/errorpages.py
@@ -25,12 +25,13 @@ from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text
+dumps: t.Callable[..., str]
try:
from ujson import dumps
dumps = partial(dumps, escape_forward_slashes=False)
except ImportError: # noqa
- from json import dumps # type: ignore
+ from json import dumps
FALLBACK_TEXT = (
@@ -45,6 +46,8 @@ class BaseRenderer:
Base class that all renderers must inherit from.
"""
+ dumps = staticmethod(dumps)
+
def __init__(self, request, exception, debug):
self.request = request
self.exception = exception
@@ -112,14 +115,16 @@ class HTMLRenderer(BaseRenderer):
TRACEBACK_STYLE = """
html { font-family: sans-serif }
h2 { color: #888; }
- .tb-wrapper p { margin: 0 }
+ .tb-wrapper p, dl, dd { margin: 0 }
.frame-border { margin: 1rem }
- .frame-line > * { padding: 0.3rem 0.6rem }
- .frame-line { margin-bottom: 0.3rem }
- .frame-code { font-size: 16px; padding-left: 4ch }
- .tb-wrapper { border: 1px solid #eee }
- .tb-header { background: #eee; padding: 0.3rem; font-weight: bold }
- .frame-descriptor { background: #e2eafb; font-size: 14px }
+ .frame-line > *, dt, dd { padding: 0.3rem 0.6rem }
+ .frame-line, dl { margin-bottom: 0.3rem }
+ .frame-code, dd { font-size: 16px; padding-left: 4ch }
+ .tb-wrapper, dl { border: 1px solid #eee }
+ .tb-header,.obj-header {
+ background: #eee; padding: 0.3rem; font-weight: bold
+ }
+ .frame-descriptor, dt { background: #e2eafb; font-size: 14px }
"""
TRACEBACK_WRAPPER_HTML = (
"
"
@@ -138,6 +143,11 @@ class HTMLRenderer(BaseRenderer):
"{0.line}
"
""
)
+ OBJECT_WRAPPER_HTML = (
+ "
"
+ "{display_html}
"
+ )
+ OBJECT_DISPLAY_HTML = "{key}{value}
"
OUTPUT_HTML = (
""
"{title}\n"
@@ -152,7 +162,7 @@ class HTMLRenderer(BaseRenderer):
title=self.title,
text=self.text,
style=self.TRACEBACK_STYLE,
- body=self._generate_body(),
+ body=self._generate_body(full=True),
),
status=self.status,
)
@@ -163,7 +173,7 @@ class HTMLRenderer(BaseRenderer):
title=self.title,
text=self.text,
style=self.TRACEBACK_STYLE,
- body="",
+ body=self._generate_body(full=False),
),
status=self.status,
headers=self.headers,
@@ -177,27 +187,49 @@ class HTMLRenderer(BaseRenderer):
def title(self):
return escape(f"⚠️ {super().title}")
- def _generate_body(self):
- _, exc_value, __ = sys.exc_info()
- exceptions = []
- while exc_value:
- exceptions.append(self._format_exc(exc_value))
- exc_value = exc_value.__cause__
+ def _generate_body(self, *, full):
+ lines = []
+ if full:
+ _, exc_value, __ = sys.exc_info()
+ exceptions = []
+ while exc_value:
+ exceptions.append(self._format_exc(exc_value))
+ exc_value = exc_value.__cause__
+
+ traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
+ appname = escape(self.request.app.name)
+ name = escape(self.exception.__class__.__name__)
+ value = escape(self.exception)
+ path = escape(self.request.path)
+ lines += [
+ f"Traceback of {appname} " "(most recent call last):
",
+ f"{traceback_html}",
+ "",
+ f"{name}: {value} "
+ f"while handling path {path}
",
+ "
",
+ ]
+
+ for attr, display in (("context", True), ("extra", bool(full))):
+ info = getattr(self.exception, attr, None)
+ if info and display:
+ lines.append(self._generate_object_display(info, attr))
- traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
- appname = escape(self.request.app.name)
- name = escape(self.exception.__class__.__name__)
- value = escape(self.exception)
- path = escape(self.request.path)
- lines = [
- f"Traceback of {appname} (most recent call last):
",
- f"{traceback_html}",
- "",
- f"{name}: {value} while handling path {path}
",
- "
",
- ]
return "\n".join(lines)
+ def _generate_object_display(
+ self, obj: t.Dict[str, t.Any], descriptor: str
+ ) -> str:
+ display = "".join(
+ self.OBJECT_DISPLAY_HTML.format(key=key, value=value)
+ for key, value in obj.items()
+ )
+ return self.OBJECT_WRAPPER_HTML.format(
+ title=descriptor.title(),
+ display_html=display,
+ obj_type=descriptor.lower(),
+ )
+
def _format_exc(self, exc):
frames = extract_tb(exc.__traceback__)
frame_html = "".join(
@@ -224,7 +256,7 @@ class TextRenderer(BaseRenderer):
title=self.title,
text=self.text,
bar=("=" * len(self.title)),
- body=self._generate_body(),
+ body=self._generate_body(full=True),
),
status=self.status,
)
@@ -235,7 +267,7 @@ class TextRenderer(BaseRenderer):
title=self.title,
text=self.text,
bar=("=" * len(self.title)),
- body="",
+ body=self._generate_body(full=False),
),
status=self.status,
headers=self.headers,
@@ -245,21 +277,31 @@ class TextRenderer(BaseRenderer):
def title(self):
return f"⚠️ {super().title}"
- def _generate_body(self):
- _, exc_value, __ = sys.exc_info()
- exceptions = []
+ def _generate_body(self, *, full):
+ lines = []
+ if full:
+ _, exc_value, __ = sys.exc_info()
+ exceptions = []
- lines = [
- f"{self.exception.__class__.__name__}: {self.exception} while "
- f"handling path {self.request.path}",
- f"Traceback of {self.request.app.name} (most recent call last):\n",
- ]
+ lines += [
+ f"{self.exception.__class__.__name__}: {self.exception} while "
+ f"handling path {self.request.path}",
+ f"Traceback of {self.request.app.name} "
+ "(most recent call last):\n",
+ ]
- while exc_value:
- exceptions.append(self._format_exc(exc_value))
- exc_value = exc_value.__cause__
+ while exc_value:
+ exceptions.append(self._format_exc(exc_value))
+ exc_value = exc_value.__cause__
- return "\n".join(lines + exceptions[::-1])
+ lines += exceptions[::-1]
+
+ for attr, display in (("context", True), ("extra", bool(full))):
+ info = getattr(self.exception, attr, None)
+ if info and display:
+ lines += self._generate_object_display_list(info, attr)
+
+ return "\n".join(lines)
def _format_exc(self, exc):
frames = "\n\n".join(
@@ -272,6 +314,13 @@ class TextRenderer(BaseRenderer):
)
return f"{self.SPACER}{exc.__class__.__name__}: {exc}\n{frames}"
+ def _generate_object_display_list(self, obj, descriptor):
+ lines = [f"\n{descriptor.title()}"]
+ for key, value in obj.items():
+ display = self.dumps(value)
+ lines.append(f"{self.SPACER * 2}{key}: {display}")
+ return lines
+
class JSONRenderer(BaseRenderer):
"""
@@ -280,11 +329,11 @@ class JSONRenderer(BaseRenderer):
def full(self) -> HTTPResponse:
output = self._generate_output(full=True)
- return json(output, status=self.status, dumps=dumps)
+ return json(output, status=self.status, dumps=self.dumps)
def minimal(self) -> HTTPResponse:
output = self._generate_output(full=False)
- return json(output, status=self.status, dumps=dumps)
+ return json(output, status=self.status, dumps=self.dumps)
def _generate_output(self, *, full):
output = {
@@ -293,6 +342,11 @@ class JSONRenderer(BaseRenderer):
"message": self.text,
}
+ for attr, display in (("context", True), ("extra", bool(full))):
+ info = getattr(self.exception, attr, None)
+ if info and display:
+ output[attr] = info
+
if full:
_, exc_value, __ = sys.exc_info()
exceptions = []
@@ -393,7 +447,8 @@ def exception_response(
# from the route
if request.route:
try:
- render_format = request.route.ctx.error_format
+ if request.route.ctx.error_format:
+ render_format = request.route.ctx.error_format
except AttributeError:
...
diff --git a/sanic/exceptions.py b/sanic/exceptions.py
index 1bb06f1d..6459f15a 100644
--- a/sanic/exceptions.py
+++ b/sanic/exceptions.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
from sanic.helpers import STATUS_CODES
@@ -11,7 +11,11 @@ class SanicException(Exception):
message: Optional[Union[str, bytes]] = None,
status_code: Optional[int] = None,
quiet: Optional[bool] = None,
+ context: Optional[Dict[str, Any]] = None,
+ extra: Optional[Dict[str, Any]] = None,
) -> None:
+ self.context = context
+ self.extra = extra
if message is None:
if self.message:
message = self.message
diff --git a/sanic/handlers.py b/sanic/handlers.py
index af667c9a..046e56e1 100644
--- a/sanic/handlers.py
+++ b/sanic/handlers.py
@@ -38,7 +38,14 @@ class ErrorHandler:
self.base = base
@classmethod
- def finalize(cls, error_handler):
+ def finalize(cls, error_handler, fallback: Optional[str] = None):
+ if (
+ fallback
+ and fallback != "auto"
+ and error_handler.fallback == "auto"
+ ):
+ error_handler.fallback = fallback
+
if not isinstance(error_handler, cls):
error_logger.warning(
f"Error handler is non-conforming: {type(error_handler)}"
diff --git a/sanic/http.py b/sanic/http.py
index d30e4c82..6f59ef25 100644
--- a/sanic/http.py
+++ b/sanic/http.py
@@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta):
self.keep_alive = True
self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch
- self.init_for_request()
def init_for_request(self):
"""Init/reset all per-request variables."""
@@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta):
"""
HTTP 1.1 connection handler
"""
- while True: # As long as connection stays keep-alive
+ # Handle requests while the connection stays reusable
+ while self.keep_alive and self.stage is Stage.IDLE:
+ self.init_for_request()
+ # Wait for incoming bytes (in IDLE stage)
+ if not self.recv_buffer:
+ await self._receive_more()
+ self.stage = Stage.REQUEST
try:
# Receive and handle a request
- self.stage = Stage.REQUEST
self.response_func = self.http1_response_header
await self.http1_request_header()
+ self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
@@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta):
if self.response:
self.response.stream = None
- # Exit and disconnect if no more requests can be taken
- if self.stage is not Stage.IDLE or not self.keep_alive:
- break
-
- self.init_for_request()
-
- # Wait for the next request
- if not self.recv_buffer:
- await self._receive_more()
-
async def http1_request_header(self): # no cov
"""
Receive and parse request header into self.request.
@@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta):
# Remove header and its trailing CRLF
del buf[: pos + 4]
- self.stage = Stage.HANDLER
self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1
diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py
index 8467a2e3..7139cd3c 100644
--- a/sanic/mixins/routes.py
+++ b/sanic/mixins/routes.py
@@ -918,7 +918,7 @@ class RouteMixin:
return route
- def _determine_error_format(self, handler) -> str:
+ def _determine_error_format(self, handler) -> Optional[str]:
if not isinstance(handler, CompositionView):
try:
src = dedent(getsource(handler))
@@ -930,7 +930,7 @@ class RouteMixin:
except (OSError, TypeError):
...
- return "auto"
+ return None
def _get_response_types(self, node):
types = set()
diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py
index 2be9fee2..57b01b46 100644
--- a/sanic/mixins/signals.py
+++ b/sanic/mixins/signals.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, Optional, Set
+from enum import Enum
+from typing import Any, Callable, Dict, Optional, Set, Union
from sanic.models.futures import FutureSignal
from sanic.models.handler_types import SignalHandler
@@ -19,7 +20,7 @@ class SignalMixin:
def signal(
self,
- event: str,
+ event: Union[str, Enum],
*,
apply: bool = True,
condition: Dict[str, Any] = None,
@@ -41,13 +42,11 @@ class SignalMixin:
filtering, defaults to None
:type condition: Dict[str, Any], optional
"""
+ event_value = str(event.value) if isinstance(event, Enum) else event
def decorator(handler: SignalHandler):
- nonlocal event
- nonlocal apply
-
future_signal = FutureSignal(
- handler, event, HashableDict(condition or {})
+ handler, event_value, HashableDict(condition or {})
)
self._future_signals.add(future_signal)
diff --git a/sanic/models/futures.py b/sanic/models/futures.py
index fe7d77eb..74ee92b9 100644
--- a/sanic/models/futures.py
+++ b/sanic/models/futures.py
@@ -60,3 +60,7 @@ class FutureSignal(NamedTuple):
handler: SignalHandler
event: str
condition: Optional[Dict[str, str]]
+
+
+class FutureRegistry(set):
+ ...
diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py
index 3a91a8f0..3c726edb 100644
--- a/sanic/reloader_helpers.py
+++ b/sanic/reloader_helpers.py
@@ -47,16 +47,18 @@ def _get_args_for_reloading():
return [sys.executable] + sys.argv
-def restart_with_reloader():
+def restart_with_reloader(changed=None):
"""Create a new process and a subprocess in it with the same arguments as
this one.
"""
+ reloaded = ",".join(changed) if changed else ""
return subprocess.Popen(
_get_args_for_reloading(),
env={
**os.environ,
"SANIC_SERVER_RUNNING": "true",
"SANIC_RELOADER_PROCESS": "true",
+ "SANIC_RELOADED_FILES": reloaded,
},
)
@@ -94,24 +96,27 @@ def watchdog(sleep_interval, app):
try:
while True:
- need_reload = False
+ changed = set()
for filename in itertools.chain(
_iter_module_files(),
*(d.glob("**/*") for d in app.reload_dirs),
):
try:
- check = _check_file(filename, mtimes)
+ if _check_file(filename, mtimes):
+ path = (
+ filename
+ if isinstance(filename, str)
+ else filename.resolve()
+ )
+ changed.add(str(path))
except OSError:
continue
- if check:
- need_reload = True
-
- if need_reload:
+ if changed:
worker_process.terminate()
worker_process.wait()
- worker_process = restart_with_reloader()
+ worker_process = restart_with_reloader(changed)
sleep(sleep_interval)
except KeyboardInterrupt:
diff --git a/sanic/router.py b/sanic/router.py
index 6995ed6d..b15c2a3e 100644
--- a/sanic/router.py
+++ b/sanic/router.py
@@ -139,11 +139,10 @@ class Router(BaseRouter):
route.ctx.stream = stream
route.ctx.hosts = hosts
route.ctx.static = static
- route.ctx.error_format = (
- error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
- )
+ route.ctx.error_format = error_format
- check_error_format(route.ctx.error_format)
+ if error_format:
+ check_error_format(route.ctx.error_format)
routes.append(route)
diff --git a/sanic/signals.py b/sanic/signals.py
index 9da7eccd..7bb510fa 100644
--- a/sanic/signals.py
+++ b/sanic/signals.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
+from enum import Enum
from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -14,29 +15,47 @@ from sanic.log import error_logger, logger
from sanic.models.handler_types import SignalHandler
+class Event(Enum):
+ SERVER_INIT_AFTER = "server.init.after"
+ SERVER_INIT_BEFORE = "server.init.before"
+ SERVER_SHUTDOWN_AFTER = "server.shutdown.after"
+ SERVER_SHUTDOWN_BEFORE = "server.shutdown.before"
+ HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin"
+ HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete"
+ HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception"
+ HTTP_LIFECYCLE_HANDLE = "http.lifecycle.handle"
+ HTTP_LIFECYCLE_READ_BODY = "http.lifecycle.read_body"
+ HTTP_LIFECYCLE_READ_HEAD = "http.lifecycle.read_head"
+ HTTP_LIFECYCLE_REQUEST = "http.lifecycle.request"
+ HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response"
+ HTTP_ROUTING_AFTER = "http.routing.after"
+ HTTP_ROUTING_BEFORE = "http.routing.before"
+ HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
+ HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
+ HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
+
+
RESERVED_NAMESPACES = {
"server": (
- # "server.main.start",
- # "server.main.stop",
- "server.init.before",
- "server.init.after",
- "server.shutdown.before",
- "server.shutdown.after",
+ Event.SERVER_INIT_AFTER.value,
+ Event.SERVER_INIT_BEFORE.value,
+ Event.SERVER_SHUTDOWN_AFTER.value,
+ Event.SERVER_SHUTDOWN_BEFORE.value,
),
"http": (
- "http.lifecycle.begin",
- "http.lifecycle.complete",
- "http.lifecycle.exception",
- "http.lifecycle.handle",
- "http.lifecycle.read_body",
- "http.lifecycle.read_head",
- "http.lifecycle.request",
- "http.lifecycle.response",
- "http.routing.after",
- "http.routing.before",
- "http.lifecycle.send",
- "http.middleware.after",
- "http.middleware.before",
+ Event.HTTP_LIFECYCLE_BEGIN.value,
+ Event.HTTP_LIFECYCLE_COMPLETE.value,
+ Event.HTTP_LIFECYCLE_EXCEPTION.value,
+ Event.HTTP_LIFECYCLE_HANDLE.value,
+ Event.HTTP_LIFECYCLE_READ_BODY.value,
+ Event.HTTP_LIFECYCLE_READ_HEAD.value,
+ Event.HTTP_LIFECYCLE_REQUEST.value,
+ Event.HTTP_LIFECYCLE_RESPONSE.value,
+ Event.HTTP_ROUTING_AFTER.value,
+ Event.HTTP_ROUTING_BEFORE.value,
+ Event.HTTP_LIFECYCLE_SEND.value,
+ Event.HTTP_MIDDLEWARE_AFTER.value,
+ Event.HTTP_MIDDLEWARE_BEFORE.value,
),
}
diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py
index 033e2e20..ca8cd67e 100644
--- a/tests/test_blueprint_copy.py
+++ b/tests/test_blueprint_copy.py
@@ -1,6 +1,4 @@
-from copy import deepcopy
-
-from sanic import Blueprint, Sanic, blueprints, response
+from sanic import Blueprint, Sanic
from sanic.response import text
diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py
index b6a23151..3aa4487a 100644
--- a/tests/test_blueprints.py
+++ b/tests/test_blueprints.py
@@ -1088,3 +1088,31 @@ def test_bp_set_attribute_warning():
"and will be removed in version 21.12. You should change your "
"Blueprint instance to use instance.ctx.foo instead."
)
+
+
+def test_early_registration(app):
+ assert len(app.router.routes) == 0
+
+ bp = Blueprint("bp")
+
+ @bp.get("/one")
+ async def one(_):
+ return text("one")
+
+ app.blueprint(bp)
+
+ assert len(app.router.routes) == 1
+
+ @bp.get("/two")
+ async def two(_):
+ return text("two")
+
+ @bp.get("/three")
+ async def three(_):
+ return text("three")
+
+ assert len(app.router.routes) == 3
+
+ for path in ("one", "two", "three"):
+ _, response = app.test_client.get(f"/{path}")
+ assert response.text == path
diff --git a/tests/test_coffee.py b/tests/test_coffee.py
new file mode 100644
index 00000000..6143f17f
--- /dev/null
+++ b/tests/test_coffee.py
@@ -0,0 +1,48 @@
+import logging
+
+from unittest.mock import patch
+
+import pytest
+
+from sanic.application.logo import COFFEE_LOGO, get_logo
+from sanic.exceptions import SanicException
+
+
+def has_sugar(value):
+ if value:
+ raise SanicException("I said no sugar please")
+
+ return False
+
+
+@pytest.mark.parametrize("sugar", (True, False))
+def test_no_sugar(sugar):
+ if sugar:
+ with pytest.raises(SanicException):
+ assert has_sugar(sugar)
+ else:
+ assert not has_sugar(sugar)
+
+
+def test_get_logo_returns_expected_logo():
+ with patch("sys.stdout.isatty") as isatty:
+ isatty.return_value = True
+ logo = get_logo(coffee=True)
+ assert logo is COFFEE_LOGO
+
+
+def test_logo_true(app, caplog):
+ @app.after_server_start
+ async def shutdown(*_):
+ app.stop()
+
+ with patch("sys.stdout.isatty") as isatty:
+ isatty.return_value = True
+ with caplog.at_level(logging.DEBUG):
+ app.make_coffee()
+
+ # Only in the regular logo
+ assert " ▄███ █████ ██ " not in caplog.text
+
+ # Only in the coffee logo
+ assert " ██ ██▀▀▄ " in caplog.text
diff --git a/tests/test_config.py b/tests/test_config.py
index f3447666..67324f1e 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,9 +1,9 @@
from contextlib import contextmanager
-from email import message
from os import environ
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
+from unittest.mock import Mock
import pytest
@@ -360,3 +360,31 @@ def test_deprecation_notice_when_setting_logo(app):
)
with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo"
+
+
+def test_config_set_methods(app, monkeypatch):
+ post_set = Mock()
+ monkeypatch.setattr(Config, "_post_set", post_set)
+
+ app.config.FOO = 1
+ post_set.assert_called_once_with("FOO", 1)
+ post_set.reset_mock()
+
+ app.config["FOO"] = 2
+ post_set.assert_called_once_with("FOO", 2)
+ post_set.reset_mock()
+
+ app.config.update({"FOO": 3})
+ post_set.assert_called_once_with("FOO", 3)
+ post_set.reset_mock()
+
+ app.config.update([("FOO", 4)])
+ post_set.assert_called_once_with("FOO", 4)
+ post_set.reset_mock()
+
+ app.config.update(FOO=5)
+ post_set.assert_called_once_with("FOO", 5)
+ post_set.reset_mock()
+
+ app.config.update_config({"FOO": 6})
+ post_set.assert_called_once_with("FOO", 6)
diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py
index 5af4ca5f..1843f6a7 100644
--- a/tests/test_errorpages.py
+++ b/tests/test_errorpages.py
@@ -1,8 +1,10 @@
import pytest
from sanic import Sanic
+from sanic.config import Config
from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException
+from sanic.handlers import ErrorHandler
from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text
@@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected):
)
assert response.content_type == expected
+
+
+def test_allow_fallback_error_format_set_main_process_start(app):
+ @app.main_process_start
+ async def start(app, _):
+ app.config.FALLBACK_ERROR_FORMAT = "text"
+
+ request, response = app.test_client.get("/error")
+ assert request.app.error_handler.fallback == "text"
+ assert response.status == 500
+ assert response.content_type == "text/plain; charset=utf-8"
+
+
+def test_setting_fallback_to_non_default_raise_warning(app):
+ app.error_handler = ErrorHandler(fallback="text")
+
+ assert app.error_handler.fallback == "text"
+
+ with pytest.warns(
+ UserWarning,
+ match=(
+ "Overriding non-default ErrorHandler fallback value. "
+ "Changing from text to auto."
+ ),
+ ):
+ app.config.FALLBACK_ERROR_FORMAT = "auto"
+
+ assert app.error_handler.fallback == "auto"
+
+ app.config.FALLBACK_ERROR_FORMAT = "text"
+
+ with pytest.warns(
+ UserWarning,
+ match=(
+ "Overriding non-default ErrorHandler fallback value. "
+ "Changing from text to json."
+ ),
+ ):
+ app.config.FALLBACK_ERROR_FORMAT = "json"
+
+ assert app.error_handler.fallback == "json"
+
+
+def test_allow_fallback_error_format_in_config_injection():
+ class MyConfig(Config):
+ FALLBACK_ERROR_FORMAT = "text"
+
+ app = Sanic("test", config=MyConfig())
+
+ @app.route("/error", methods=["GET", "POST"])
+ def err(request):
+ raise Exception("something went wrong")
+
+ request, response = app.test_client.get("/error")
+ assert request.app.error_handler.fallback == "text"
+ assert response.status == 500
+ assert response.content_type == "text/plain; charset=utf-8"
+
+
+def test_allow_fallback_error_format_in_config_replacement(app):
+ class MyConfig(Config):
+ FALLBACK_ERROR_FORMAT = "text"
+
+ app.config = MyConfig()
+
+ request, response = app.test_client.get("/error")
+ assert request.app.error_handler.fallback == "text"
+ assert response.status == 500
+ assert response.content_type == "text/plain; charset=utf-8"
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 0485137a..eea97935 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -18,6 +18,16 @@ from sanic.exceptions import (
from sanic.response import text
+def dl_to_dict(soup, css_class):
+ keys, values = [], []
+ for dl in soup.find_all("dl", {"class": css_class}):
+ for dt in dl.find_all("dt"):
+ keys.append(dt.text.strip())
+ for dd in dl.find_all("dd"):
+ values.append(dd.text.strip())
+ return dict(zip(keys, values))
+
+
class SanicExceptionTestException(Exception):
pass
@@ -264,3 +274,110 @@ def test_exception_in_ws_logged(caplog):
error_logs = [r for r in caplog.record_tuples if r[0] == "sanic.error"]
assert error_logs[1][1] == logging.ERROR
assert "Exception occurred while handling uri:" in error_logs[1][2]
+
+
+@pytest.mark.parametrize("debug", (True, False))
+def test_contextual_exception_context(debug):
+ app = Sanic(__name__)
+
+ class TeapotError(SanicException):
+ status_code = 418
+ message = "Sorry, I cannot brew coffee"
+
+ def fail():
+ raise TeapotError(context={"foo": "bar"})
+
+ app.post("/coffee/json", error_format="json")(lambda _: fail())
+ app.post("/coffee/html", error_format="html")(lambda _: fail())
+ app.post("/coffee/text", error_format="text")(lambda _: fail())
+
+ _, response = app.test_client.post("/coffee/json", debug=debug)
+ assert response.status == 418
+ assert response.json["message"] == "Sorry, I cannot brew coffee"
+ assert response.json["context"] == {"foo": "bar"}
+
+ _, response = app.test_client.post("/coffee/html", debug=debug)
+ soup = BeautifulSoup(response.body, "html.parser")
+ dl = dl_to_dict(soup, "context")
+ assert response.status == 418
+ assert "Sorry, I cannot brew coffee" in soup.find("p").text
+ assert dl == {"foo": "bar"}
+
+ _, response = app.test_client.post("/coffee/text", debug=debug)
+ lines = list(map(lambda x: x.decode(), response.body.split(b"\n")))
+ idx = lines.index("Context") + 1
+ assert response.status == 418
+ assert lines[2] == "Sorry, I cannot brew coffee"
+ assert lines[idx] == ' foo: "bar"'
+
+
+@pytest.mark.parametrize("debug", (True, False))
+def test_contextual_exception_extra(debug):
+ app = Sanic(__name__)
+
+ class TeapotError(SanicException):
+ status_code = 418
+
+ @property
+ def message(self):
+ return f"Found {self.extra['foo']}"
+
+ def fail():
+ raise TeapotError(extra={"foo": "bar"})
+
+ app.post("/coffee/json", error_format="json")(lambda _: fail())
+ app.post("/coffee/html", error_format="html")(lambda _: fail())
+ app.post("/coffee/text", error_format="text")(lambda _: fail())
+
+ _, response = app.test_client.post("/coffee/json", debug=debug)
+ assert response.status == 418
+ assert response.json["message"] == "Found bar"
+ if debug:
+ assert response.json["extra"] == {"foo": "bar"}
+ else:
+ assert "extra" not in response.json
+
+ _, response = app.test_client.post("/coffee/html", debug=debug)
+ soup = BeautifulSoup(response.body, "html.parser")
+ dl = dl_to_dict(soup, "extra")
+ assert response.status == 418
+ assert "Found bar" in soup.find("p").text
+ if debug:
+ assert dl == {"foo": "bar"}
+ else:
+ assert not dl
+
+ _, response = app.test_client.post("/coffee/text", debug=debug)
+ lines = list(map(lambda x: x.decode(), response.body.split(b"\n")))
+ assert response.status == 418
+ assert lines[2] == "Found bar"
+ if debug:
+ idx = lines.index("Extra") + 1
+ assert lines[idx] == ' foo: "bar"'
+ else:
+ assert "Extra" not in lines
+
+
+@pytest.mark.parametrize("override", (True, False))
+def test_contextual_exception_functional_message(override):
+ app = Sanic(__name__)
+
+ class TeapotError(SanicException):
+ status_code = 418
+
+ @property
+ def message(self):
+ return f"Received foo={self.context['foo']}"
+
+ @app.post("/coffee", error_format="json")
+ async def make_coffee(_):
+ error_args = {"context": {"foo": "bar"}}
+ if override:
+ error_args["message"] = "override"
+ raise TeapotError(**error_args)
+
+ _, response = app.test_client.post("/coffee", debug=True)
+ error_message = "override" if override else "Received foo=bar"
+ assert response.status == 418
+ assert response.json["message"] == error_message
+ assert response.json["context"] == {"foo": "bar"}
diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py
deleted file mode 100644
index 48e23f1d..00000000
--- a/tests/test_request_timeout.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import asyncio
-
-import httpcore
-import httpx
-import pytest
-
-from sanic_testing.testing import SanicTestClient
-
-from sanic import Sanic
-from sanic.response import text
-
-
-class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection):
- async def arequest(self, *args, **kwargs):
- await asyncio.sleep(2)
- return await super().arequest(*args, **kwargs)
-
- async def _open_socket(self, *args, **kwargs):
- retval = await super()._open_socket(*args, **kwargs)
- if self._request_delay:
- await asyncio.sleep(self._request_delay)
- return retval
-
-
-class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool):
- def __init__(self, request_delay=None, *args, **kwargs):
- self._request_delay = request_delay
- super().__init__(*args, **kwargs)
-
- async def _add_to_pool(self, connection, timeout):
- connection.__class__ = DelayableHTTPConnection
- connection._request_delay = self._request_delay
- await super()._add_to_pool(connection, timeout)
-
-
-class DelayableSanicSession(httpx.AsyncClient):
- def __init__(self, request_delay=None, *args, **kwargs) -> None:
- transport = DelayableSanicConnectionPool(request_delay=request_delay)
- super().__init__(transport=transport, *args, **kwargs)
-
-
-class DelayableSanicTestClient(SanicTestClient):
- def __init__(self, app, request_delay=None):
- super().__init__(app)
- self._request_delay = request_delay
- self._loop = None
-
- def get_new_session(self):
- return DelayableSanicSession(request_delay=self._request_delay)
-
-
-@pytest.fixture
-def request_no_timeout_app():
- app = Sanic("test_request_no_timeout")
- app.config.REQUEST_TIMEOUT = 0.6
-
- @app.route("/1")
- async def handler2(request):
- return text("OK")
-
- return app
-
-
-@pytest.fixture
-def request_timeout_default_app():
- app = Sanic("test_request_timeout_default")
- app.config.REQUEST_TIMEOUT = 0.6
-
- @app.route("/1")
- async def handler1(request):
- return text("OK")
-
- @app.websocket("/ws1")
- async def ws_handler1(request, ws):
- await ws.send("OK")
-
- return app
-
-
-def test_default_server_error_request_timeout(request_timeout_default_app):
- client = DelayableSanicTestClient(request_timeout_default_app, 2)
- _, response = client.get("/1")
- assert response.status == 408
- assert "Request Timeout" in response.text
-
-
-def test_default_server_error_request_dont_timeout(request_no_timeout_app):
- client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
- _, response = client.get("/1")
- assert response.status == 200
- assert response.text == "OK"
-
-
-def test_default_server_error_websocket_request_timeout(
- request_timeout_default_app,
-):
-
- headers = {
- "Upgrade": "websocket",
- "Connection": "upgrade",
- "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
- "Sec-WebSocket-Version": "13",
- }
-
- client = DelayableSanicTestClient(request_timeout_default_app, 2)
- _, response = client.get("/ws1", headers=headers)
-
- assert response.status == 408
- assert "Request Timeout" in response.text
diff --git a/tests/test_signals.py b/tests/test_signals.py
index 9b8a9495..51aea3c8 100644
--- a/tests/test_signals.py
+++ b/tests/test_signals.py
@@ -1,5 +1,6 @@
import asyncio
+from enum import Enum
from inspect import isawaitable
import pytest
@@ -50,6 +51,25 @@ def test_invalid_signal(app, signal):
...
+@pytest.mark.asyncio
+async def test_dispatch_signal_with_enum_event(app):
+ counter = 0
+
+ class FooEnum(Enum):
+ FOO_BAR_BAZ = "foo.bar.baz"
+
+ @app.signal(FooEnum.FOO_BAR_BAZ)
+ def sync_signal(*_):
+ nonlocal counter
+
+ counter += 1
+
+ app.signal_router.finalize()
+
+ await app.dispatch("foo.bar.baz")
+ assert counter == 1
+
+
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_multiple_handlers(app):
counter = 0
diff --git a/tests/test_timeout_logic.py b/tests/test_timeout_logic.py
index 05249f11..497deda9 100644
--- a/tests/test_timeout_logic.py
+++ b/tests/test_timeout_logic.py
@@ -26,6 +26,7 @@ def protocol(app, mock_transport):
protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport)
protocol._setup_connection()
+ protocol._http.init_for_request()
protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock()
return protocol