Merge branch 'main' of github.com:sanic-org/sanic into feat/optional-uvloop-use

This commit is contained in:
prryplatypus 2021-11-18 20:06:49 +01:00
commit 12ecf52878
No known key found for this signature in database
GPG Key ID: 6687E128FB70819B
24 changed files with 673 additions and 272 deletions

View File

@ -72,6 +72,7 @@ from sanic.models.futures import (
FutureException, FutureException,
FutureListener, FutureListener,
FutureMiddleware, FutureMiddleware,
FutureRegistry,
FutureRoute, FutureRoute,
FutureSignal, FutureSignal,
FutureStatic, FutureStatic,
@ -115,6 +116,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"_future_exceptions", "_future_exceptions",
"_future_listeners", "_future_listeners",
"_future_middleware", "_future_middleware",
"_future_registry",
"_future_routes", "_future_routes",
"_future_signals", "_future_signals",
"_future_statics", "_future_statics",
@ -187,17 +189,18 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self._test_manager: Any = None self._test_manager: Any = None
self._blueprint_order: List[Blueprint] = [] self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = [] self._delayed_tasks: List[str] = []
self._future_registry: FutureRegistry = FutureRegistry()
self._state: ApplicationState = ApplicationState(app=self) self._state: ApplicationState = ApplicationState(app=self)
self.blueprints: Dict[str, Blueprint] = {} self.blueprints: Dict[str, Blueprint] = {}
self.config: Config = config or Config( self.config: Config = config or Config(
load_env=load_env, env_prefix=env_prefix load_env=load_env,
env_prefix=env_prefix,
app=self,
) )
self.configure_logging: bool = configure_logging self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace() self.ctx: Any = ctx or SimpleNamespace()
self.debug = False self.debug = False
self.error_handler: ErrorHandler = error_handler or ErrorHandler( self.error_handler: ErrorHandler = error_handler or ErrorHandler()
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}
@ -957,6 +960,10 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
# Execution # Execution
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
def make_coffee(self, *args, **kwargs):
self.state.coffee = True
self.run(*args, **kwargs)
def run( def run(
self, self,
host: Optional[str] = None, host: Optional[str] = None,
@ -1569,7 +1576,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
extra.update(self.config.MOTD_DISPLAY) extra.update(self.config.MOTD_DISPLAY)
logo = ( logo = (
get_logo() get_logo(coffee=self.state.coffee)
if self.config.LOGO == "" or self.config.LOGO is True if self.config.LOGO == "" or self.config.LOGO is True
else self.config.LOGO else self.config.LOGO
) )
@ -1635,9 +1642,12 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
raise e raise e
async def _startup(self): async def _startup(self):
self._future_registry.clear()
self.signalize() self.signalize()
self.finalize() self.finalize()
ErrorHandler.finalize(self.error_handler) ErrorHandler.finalize(
self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT
)
TouchUp.run(self) TouchUp.run(self)
async def _server_event( async def _server_event(

View File

@ -10,6 +10,15 @@ BASE_LOGO = """
Build Fast. Run Fast. Build Fast. Run Fast.
""" """
COFFEE_LOGO = """\033[48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m
\033[48;2;255;13;104m \033[0m
Dark roast. No sugar."""
COLOR_LOGO = """\033[48;2;255;13;104m \033[0m COLOR_LOGO = """\033[48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m \033[38;2;255;255;255;48;2;255;13;104m \033[0m
\033[38;2;255;255;255;48;2;255;13;104m \033[0m \033[38;2;255;255;255;48;2;255;13;104m \033[0m
@ -32,9 +41,9 @@ FULL_COLOR_LOGO = """
ansi_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") ansi_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
def get_logo(full=False): def get_logo(full=False, coffee=False):
logo = ( logo = (
(FULL_COLOR_LOGO if full else COLOR_LOGO) (FULL_COLOR_LOGO if full else (COFFEE_LOGO if coffee else COLOR_LOGO))
if sys.stdout.isatty() if sys.stdout.isatty()
else BASE_LOGO else BASE_LOGO
) )

View File

@ -34,6 +34,7 @@ class Mode(StrEnum):
class ApplicationState: class ApplicationState:
app: Sanic app: Sanic
asgi: bool = field(default=False) asgi: bool = field(default=False)
coffee: bool = field(default=False)
fast: bool = field(default=False) fast: bool = field(default=False)
host: str = field(default="") host: str = field(default="")
mode: Mode = field(default=Mode.PRODUCTION) mode: Mode = field(default=Mode.PRODUCTION)

View File

@ -4,6 +4,9 @@ import asyncio
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from functools import wraps
from inspect import isfunction
from itertools import chain
from types import SimpleNamespace from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -12,7 +15,9 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Sequence,
Set, Set,
Tuple,
Union, Union,
) )
@ -35,6 +40,32 @@ if TYPE_CHECKING:
from sanic import Sanic # noqa from sanic import Sanic # noqa
def lazy(func, as_decorator=True):
@wraps(func)
def decorator(bp, *args, **kwargs):
nonlocal as_decorator
kwargs["apply"] = False
pass_handler = None
if args and isfunction(args[0]):
as_decorator = False
def wrapper(handler):
future = func(bp, *args, **kwargs)
if as_decorator:
future = future(handler)
if bp.registered:
for app in bp.apps:
bp.register(app, {})
return future
return wrapper if as_decorator else wrapper(pass_handler)
return decorator
class Blueprint(BaseSanic): class Blueprint(BaseSanic):
""" """
In *Sanic* terminology, a **Blueprint** is a logical collection of In *Sanic* terminology, a **Blueprint** is a logical collection of
@ -124,29 +155,16 @@ class Blueprint(BaseSanic):
) )
return self._apps return self._apps
def route(self, *args, **kwargs): @property
kwargs["apply"] = False def registered(self) -> bool:
return super().route(*args, **kwargs) return bool(self._apps)
def static(self, *args, **kwargs): exception = lazy(BaseSanic.exception)
kwargs["apply"] = False listener = lazy(BaseSanic.listener)
return super().static(*args, **kwargs) middleware = lazy(BaseSanic.middleware)
route = lazy(BaseSanic.route)
def middleware(self, *args, **kwargs): signal = lazy(BaseSanic.signal)
kwargs["apply"] = False static = lazy(BaseSanic.static, as_decorator=False)
return super().middleware(*args, **kwargs)
def listener(self, *args, **kwargs):
kwargs["apply"] = False
return super().listener(*args, **kwargs)
def exception(self, *args, **kwargs):
kwargs["apply"] = False
return super().exception(*args, **kwargs)
def signal(self, event: str, *args, **kwargs):
kwargs["apply"] = False
return super().signal(event, *args, **kwargs)
def reset(self): def reset(self):
self._apps: Set[Sanic] = set() self._apps: Set[Sanic] = set()
@ -283,6 +301,7 @@ class Blueprint(BaseSanic):
middleware = [] middleware = []
exception_handlers = [] exception_handlers = []
listeners = defaultdict(list) listeners = defaultdict(list)
registered = set()
# Routes # Routes
for future in self._future_routes: for future in self._future_routes:
@ -309,12 +328,15 @@ class Blueprint(BaseSanic):
) )
name = app._generate_name(future.name) name = app._generate_name(future.name)
host = future.host or self.host
if isinstance(host, list):
host = tuple(host)
apply_route = FutureRoute( apply_route = FutureRoute(
future.handler, future.handler,
uri[1:] if uri.startswith("//") else uri, uri[1:] if uri.startswith("//") else uri,
future.methods, future.methods,
future.host or self.host, host,
strict_slashes, strict_slashes,
future.stream, future.stream,
version, version,
@ -328,6 +350,10 @@ class Blueprint(BaseSanic):
error_format, error_format,
) )
if (self, apply_route) in app._future_registry:
continue
registered.add(apply_route)
route = app._apply_route(apply_route) route = app._apply_route(apply_route)
operation = ( operation = (
routes.extend if isinstance(route, list) else routes.append routes.extend if isinstance(route, list) else routes.append
@ -339,6 +365,11 @@ class Blueprint(BaseSanic):
# Prepend the blueprint URI prefix if available # Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri uri = url_prefix + future.uri if url_prefix else future.uri
apply_route = FutureStatic(uri, *future[1:]) apply_route = FutureStatic(uri, *future[1:])
if (self, apply_route) in app._future_registry:
continue
registered.add(apply_route)
route = app._apply_static(apply_route) route = app._apply_static(apply_route)
routes.append(route) routes.append(route)
@ -347,30 +378,51 @@ class Blueprint(BaseSanic):
if route_names: if route_names:
# Middleware # Middleware
for future in self._future_middleware: for future in self._future_middleware:
if (self, future) in app._future_registry:
continue
middleware.append(app._apply_middleware(future, route_names)) middleware.append(app._apply_middleware(future, route_names))
# Exceptions # Exceptions
for future in self._future_exceptions: for future in self._future_exceptions:
if (self, future) in app._future_registry:
continue
exception_handlers.append( exception_handlers.append(
app._apply_exception_handler(future, route_names) app._apply_exception_handler(future, route_names)
) )
# Event listeners # Event listeners
for listener in self._future_listeners: for future in self._future_listeners:
listeners[listener.event].append(app._apply_listener(listener)) if (self, future) in app._future_registry:
continue
listeners[future.event].append(app._apply_listener(future))
# Signals # Signals
for signal in self._future_signals: for future in self._future_signals:
signal.condition.update({"blueprint": self.name}) if (self, future) in app._future_registry:
app._apply_signal(signal) continue
future.condition.update({"blueprint": self.name})
app._apply_signal(future)
self.routes = [route for route in routes if isinstance(route, Route)] self.routes += [route for route in routes if isinstance(route, Route)]
self.websocket_routes = [ self.websocket_routes += [
route for route in self.routes if route.ctx.websocket route for route in self.routes if route.ctx.websocket
] ]
self.middlewares = middleware self.middlewares += middleware
self.exceptions = exception_handlers self.exceptions += exception_handlers
self.listeners = dict(listeners) self.listeners.update(dict(listeners))
if self.registered:
self.register_futures(
self.apps,
self,
chain(
registered,
self._future_middleware,
self._future_exceptions,
self._future_listeners,
self._future_signals,
),
)
async def dispatch(self, *args, **kwargs): async def dispatch(self, *args, **kwargs):
condition = kwargs.pop("condition", {}) condition = kwargs.pop("condition", {})
@ -402,3 +454,10 @@ class Blueprint(BaseSanic):
value = v value = v
break break
return value return value
@staticmethod
def register_futures(
apps: Set[Sanic], bp: Blueprint, futures: Sequence[Tuple[Any, ...]]
):
for app in apps:
app._future_registry.update(set((bp, item) for item in futures))

View File

@ -1,7 +1,9 @@
from __future__ import annotations
from inspect import isclass from inspect import isclass
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from warnings import warn from warnings import warn
from sanic.errorpages import check_error_format from sanic.errorpages import check_error_format
@ -9,6 +11,10 @@ from sanic.http import Http
from sanic.utils import load_module_from_file_location, str_to_bool from sanic.utils import load_module_from_file_location, str_to_bool
if TYPE_CHECKING: # no cov
from sanic import Sanic
SANIC_PREFIX = "SANIC_" SANIC_PREFIX = "SANIC_"
@ -75,10 +81,13 @@ class Config(dict):
load_env: Optional[Union[bool, str]] = True, load_env: Optional[Union[bool, str]] = True,
env_prefix: Optional[str] = SANIC_PREFIX, env_prefix: Optional[str] = SANIC_PREFIX,
keep_alive: Optional[bool] = None, keep_alive: Optional[bool] = None,
*,
app: Optional[Sanic] = None,
): ):
defaults = defaults or {} defaults = defaults or {}
super().__init__({**DEFAULT_CONFIG, **defaults}) super().__init__({**DEFAULT_CONFIG, **defaults})
self._app = app
self._LOGO = "" self._LOGO = ""
if keep_alive is not None: if keep_alive is not None:
@ -101,6 +110,7 @@ class Config(dict):
self._configure_header_size() self._configure_header_size()
self._check_error_format() self._check_error_format()
self._init = True
def __getattr__(self, attr): def __getattr__(self, attr):
try: try:
@ -108,23 +118,47 @@ class Config(dict):
except KeyError as ke: except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'") raise AttributeError(f"Config has no '{ke.args[0]}'")
def __setattr__(self, attr, value): def __setattr__(self, attr, value) -> None:
self[attr] = value self.update({attr: value})
if attr in (
"REQUEST_MAX_HEADER_SIZE", def __setitem__(self, attr, value) -> None:
"REQUEST_BUFFER_SIZE", self.update({attr: value})
"REQUEST_MAX_SIZE",
): def update(self, *other, **kwargs) -> None:
self._configure_header_size() other_mapping = {k: v for item in other for k, v in dict(item).items()}
elif attr == "FALLBACK_ERROR_FORMAT": super().update(*other, **kwargs)
self._check_error_format() for attr, value in {**other_mapping, **kwargs}.items():
elif attr == "LOGO": self._post_set(attr, value)
self._LOGO = value
warn( def _post_set(self, attr, value) -> None:
"Setting the config.LOGO is deprecated and will no longer " if self.get("_init"):
"be supported starting in v22.6.", if attr in (
DeprecationWarning, "REQUEST_MAX_HEADER_SIZE",
) "REQUEST_BUFFER_SIZE",
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
if self.app and value != self.app.error_handler.fallback:
if self.app.error_handler.fallback != "auto":
warn(
"Overriding non-default ErrorHandler fallback "
"value. Changing from "
f"{self.app.error_handler.fallback} to {value}."
)
self.app.error_handler.fallback = value
elif attr == "LOGO":
self._LOGO = value
warn(
"Setting the config.LOGO is deprecated and will no longer "
"be supported starting in v22.6.",
DeprecationWarning,
)
@property
def app(self):
return self._app
@property @property
def LOGO(self): def LOGO(self):

View File

@ -25,12 +25,13 @@ from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
dumps: t.Callable[..., str]
try: try:
from ujson import dumps from ujson import dumps
dumps = partial(dumps, escape_forward_slashes=False) dumps = partial(dumps, escape_forward_slashes=False)
except ImportError: # noqa except ImportError: # noqa
from json import dumps # type: ignore from json import dumps
FALLBACK_TEXT = ( FALLBACK_TEXT = (
@ -45,6 +46,8 @@ class BaseRenderer:
Base class that all renderers must inherit from. Base class that all renderers must inherit from.
""" """
dumps = staticmethod(dumps)
def __init__(self, request, exception, debug): def __init__(self, request, exception, debug):
self.request = request self.request = request
self.exception = exception self.exception = exception
@ -112,14 +115,16 @@ class HTMLRenderer(BaseRenderer):
TRACEBACK_STYLE = """ TRACEBACK_STYLE = """
html { font-family: sans-serif } html { font-family: sans-serif }
h2 { color: #888; } h2 { color: #888; }
.tb-wrapper p { margin: 0 } .tb-wrapper p, dl, dd { margin: 0 }
.frame-border { margin: 1rem } .frame-border { margin: 1rem }
.frame-line > * { padding: 0.3rem 0.6rem } .frame-line > *, dt, dd { padding: 0.3rem 0.6rem }
.frame-line { margin-bottom: 0.3rem } .frame-line, dl { margin-bottom: 0.3rem }
.frame-code { font-size: 16px; padding-left: 4ch } .frame-code, dd { font-size: 16px; padding-left: 4ch }
.tb-wrapper { border: 1px solid #eee } .tb-wrapper, dl { border: 1px solid #eee }
.tb-header { background: #eee; padding: 0.3rem; font-weight: bold } .tb-header,.obj-header {
.frame-descriptor { background: #e2eafb; font-size: 14px } background: #eee; padding: 0.3rem; font-weight: bold
}
.frame-descriptor, dt { background: #e2eafb; font-size: 14px }
""" """
TRACEBACK_WRAPPER_HTML = ( TRACEBACK_WRAPPER_HTML = (
"<div class=tb-header>{exc_name}: {exc_value}</div>" "<div class=tb-header>{exc_name}: {exc_value}</div>"
@ -138,6 +143,11 @@ class HTMLRenderer(BaseRenderer):
"<p class=frame-code><code>{0.line}</code>" "<p class=frame-code><code>{0.line}</code>"
"</div>" "</div>"
) )
OBJECT_WRAPPER_HTML = (
"<div class=obj-header>{title}</div>"
"<dl class={obj_type}>{display_html}</dl>"
)
OBJECT_DISPLAY_HTML = "<dt>{key}</dt><dd><code>{value}</code></dd>"
OUTPUT_HTML = ( OUTPUT_HTML = (
"<!DOCTYPE html><html lang=en>" "<!DOCTYPE html><html lang=en>"
"<meta charset=UTF-8><title>{title}</title>\n" "<meta charset=UTF-8><title>{title}</title>\n"
@ -152,7 +162,7 @@ class HTMLRenderer(BaseRenderer):
title=self.title, title=self.title,
text=self.text, text=self.text,
style=self.TRACEBACK_STYLE, style=self.TRACEBACK_STYLE,
body=self._generate_body(), body=self._generate_body(full=True),
), ),
status=self.status, status=self.status,
) )
@ -163,7 +173,7 @@ class HTMLRenderer(BaseRenderer):
title=self.title, title=self.title,
text=self.text, text=self.text,
style=self.TRACEBACK_STYLE, style=self.TRACEBACK_STYLE,
body="", body=self._generate_body(full=False),
), ),
status=self.status, status=self.status,
headers=self.headers, headers=self.headers,
@ -177,27 +187,49 @@ class HTMLRenderer(BaseRenderer):
def title(self): def title(self):
return escape(f"⚠️ {super().title}") return escape(f"⚠️ {super().title}")
def _generate_body(self): def _generate_body(self, *, full):
_, exc_value, __ = sys.exc_info() lines = []
exceptions = [] if full:
while exc_value: _, exc_value, __ = sys.exc_info()
exceptions.append(self._format_exc(exc_value)) exceptions = []
exc_value = exc_value.__cause__ while exc_value:
exceptions.append(self._format_exc(exc_value))
exc_value = exc_value.__cause__
traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
appname = escape(self.request.app.name)
name = escape(self.exception.__class__.__name__)
value = escape(self.exception)
path = escape(self.request.path)
lines += [
f"<h2>Traceback of {appname} " "(most recent call last):</h2>",
f"{traceback_html}",
"<div class=summary><p>",
f"<b>{name}: {value}</b> "
f"while handling path <code>{path}</code>",
"</div>",
]
for attr, display in (("context", True), ("extra", bool(full))):
info = getattr(self.exception, attr, None)
if info and display:
lines.append(self._generate_object_display(info, attr))
traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions))
appname = escape(self.request.app.name)
name = escape(self.exception.__class__.__name__)
value = escape(self.exception)
path = escape(self.request.path)
lines = [
f"<h2>Traceback of {appname} (most recent call last):</h2>",
f"{traceback_html}",
"<div class=summary><p>",
f"<b>{name}: {value}</b> while handling path <code>{path}</code>",
"</div>",
]
return "\n".join(lines) return "\n".join(lines)
def _generate_object_display(
self, obj: t.Dict[str, t.Any], descriptor: str
) -> str:
display = "".join(
self.OBJECT_DISPLAY_HTML.format(key=key, value=value)
for key, value in obj.items()
)
return self.OBJECT_WRAPPER_HTML.format(
title=descriptor.title(),
display_html=display,
obj_type=descriptor.lower(),
)
def _format_exc(self, exc): def _format_exc(self, exc):
frames = extract_tb(exc.__traceback__) frames = extract_tb(exc.__traceback__)
frame_html = "".join( frame_html = "".join(
@ -224,7 +256,7 @@ class TextRenderer(BaseRenderer):
title=self.title, title=self.title,
text=self.text, text=self.text,
bar=("=" * len(self.title)), bar=("=" * len(self.title)),
body=self._generate_body(), body=self._generate_body(full=True),
), ),
status=self.status, status=self.status,
) )
@ -235,7 +267,7 @@ class TextRenderer(BaseRenderer):
title=self.title, title=self.title,
text=self.text, text=self.text,
bar=("=" * len(self.title)), bar=("=" * len(self.title)),
body="", body=self._generate_body(full=False),
), ),
status=self.status, status=self.status,
headers=self.headers, headers=self.headers,
@ -245,21 +277,31 @@ class TextRenderer(BaseRenderer):
def title(self): def title(self):
return f"⚠️ {super().title}" return f"⚠️ {super().title}"
def _generate_body(self): def _generate_body(self, *, full):
_, exc_value, __ = sys.exc_info() lines = []
exceptions = [] if full:
_, exc_value, __ = sys.exc_info()
exceptions = []
lines = [ lines += [
f"{self.exception.__class__.__name__}: {self.exception} while " f"{self.exception.__class__.__name__}: {self.exception} while "
f"handling path {self.request.path}", f"handling path {self.request.path}",
f"Traceback of {self.request.app.name} (most recent call last):\n", f"Traceback of {self.request.app.name} "
] "(most recent call last):\n",
]
while exc_value: while exc_value:
exceptions.append(self._format_exc(exc_value)) exceptions.append(self._format_exc(exc_value))
exc_value = exc_value.__cause__ exc_value = exc_value.__cause__
return "\n".join(lines + exceptions[::-1]) lines += exceptions[::-1]
for attr, display in (("context", True), ("extra", bool(full))):
info = getattr(self.exception, attr, None)
if info and display:
lines += self._generate_object_display_list(info, attr)
return "\n".join(lines)
def _format_exc(self, exc): def _format_exc(self, exc):
frames = "\n\n".join( frames = "\n\n".join(
@ -272,6 +314,13 @@ class TextRenderer(BaseRenderer):
) )
return f"{self.SPACER}{exc.__class__.__name__}: {exc}\n{frames}" return f"{self.SPACER}{exc.__class__.__name__}: {exc}\n{frames}"
def _generate_object_display_list(self, obj, descriptor):
lines = [f"\n{descriptor.title()}"]
for key, value in obj.items():
display = self.dumps(value)
lines.append(f"{self.SPACER * 2}{key}: {display}")
return lines
class JSONRenderer(BaseRenderer): class JSONRenderer(BaseRenderer):
""" """
@ -280,11 +329,11 @@ class JSONRenderer(BaseRenderer):
def full(self) -> HTTPResponse: def full(self) -> HTTPResponse:
output = self._generate_output(full=True) output = self._generate_output(full=True)
return json(output, status=self.status, dumps=dumps) return json(output, status=self.status, dumps=self.dumps)
def minimal(self) -> HTTPResponse: def minimal(self) -> HTTPResponse:
output = self._generate_output(full=False) output = self._generate_output(full=False)
return json(output, status=self.status, dumps=dumps) return json(output, status=self.status, dumps=self.dumps)
def _generate_output(self, *, full): def _generate_output(self, *, full):
output = { output = {
@ -293,6 +342,11 @@ class JSONRenderer(BaseRenderer):
"message": self.text, "message": self.text,
} }
for attr, display in (("context", True), ("extra", bool(full))):
info = getattr(self.exception, attr, None)
if info and display:
output[attr] = info
if full: if full:
_, exc_value, __ = sys.exc_info() _, exc_value, __ = sys.exc_info()
exceptions = [] exceptions = []
@ -393,7 +447,8 @@ def exception_response(
# from the route # from the route
if request.route: if request.route:
try: try:
render_format = request.route.ctx.error_format if request.route.ctx.error_format:
render_format = request.route.ctx.error_format
except AttributeError: except AttributeError:
... ...

View File

@ -1,4 +1,4 @@
from typing import Optional, Union from typing import Any, Dict, Optional, Union
from sanic.helpers import STATUS_CODES from sanic.helpers import STATUS_CODES
@ -11,7 +11,11 @@ class SanicException(Exception):
message: Optional[Union[str, bytes]] = None, message: Optional[Union[str, bytes]] = None,
status_code: Optional[int] = None, status_code: Optional[int] = None,
quiet: Optional[bool] = None, quiet: Optional[bool] = None,
context: Optional[Dict[str, Any]] = None,
extra: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
self.context = context
self.extra = extra
if message is None: if message is None:
if self.message: if self.message:
message = self.message message = self.message

View File

@ -38,7 +38,14 @@ class ErrorHandler:
self.base = base self.base = base
@classmethod @classmethod
def finalize(cls, error_handler): def finalize(cls, error_handler, fallback: Optional[str] = None):
if (
fallback
and fallback != "auto"
and error_handler.fallback == "auto"
):
error_handler.fallback = fallback
if not isinstance(error_handler, cls): if not isinstance(error_handler, cls):
error_logger.warning( error_logger.warning(
f"Error handler is non-conforming: {type(error_handler)}" f"Error handler is non-conforming: {type(error_handler)}"

View File

@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta):
self.keep_alive = True self.keep_alive = True
self.stage: Stage = Stage.IDLE self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch self.dispatch = self.protocol.app.dispatch
self.init_for_request()
def init_for_request(self): def init_for_request(self):
"""Init/reset all per-request variables.""" """Init/reset all per-request variables."""
@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta):
""" """
HTTP 1.1 connection handler HTTP 1.1 connection handler
""" """
while True: # As long as connection stays keep-alive # Handle requests while the connection stays reusable
while self.keep_alive and self.stage is Stage.IDLE:
self.init_for_request()
# Wait for incoming bytes (in IDLE stage)
if not self.recv_buffer:
await self._receive_more()
self.stage = Stage.REQUEST
try: try:
# Receive and handle a request # Receive and handle a request
self.stage = Stage.REQUEST
self.response_func = self.http1_response_header self.response_func = self.http1_response_header
await self.http1_request_header() await self.http1_request_header()
self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request) await self.protocol.request_handler(self.request)
@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta):
if self.response: if self.response:
self.response.stream = None self.response.stream = None
# Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive:
break
self.init_for_request()
# Wait for the next request
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self): # no cov async def http1_request_header(self): # no cov
""" """
Receive and parse request header into self.request. Receive and parse request header into self.request.
@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta):
# Remove header and its trailing CRLF # Remove header and its trailing CRLF
del buf[: pos + 4] del buf[: pos + 4]
self.stage = Stage.HANDLER
self.request, request.stream = request, self self.request, request.stream = request, self
self.protocol.state["requests_count"] += 1 self.protocol.state["requests_count"] += 1

View File

@ -918,7 +918,7 @@ class RouteMixin:
return route return route
def _determine_error_format(self, handler) -> str: def _determine_error_format(self, handler) -> Optional[str]:
if not isinstance(handler, CompositionView): if not isinstance(handler, CompositionView):
try: try:
src = dedent(getsource(handler)) src = dedent(getsource(handler))
@ -930,7 +930,7 @@ class RouteMixin:
except (OSError, TypeError): except (OSError, TypeError):
... ...
return "auto" return None
def _get_response_types(self, node): def _get_response_types(self, node):
types = set() types = set()

View File

@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, Optional, Set from enum import Enum
from typing import Any, Callable, Dict, Optional, Set, Union
from sanic.models.futures import FutureSignal from sanic.models.futures import FutureSignal
from sanic.models.handler_types import SignalHandler from sanic.models.handler_types import SignalHandler
@ -19,7 +20,7 @@ class SignalMixin:
def signal( def signal(
self, self,
event: str, event: Union[str, Enum],
*, *,
apply: bool = True, apply: bool = True,
condition: Dict[str, Any] = None, condition: Dict[str, Any] = None,
@ -41,13 +42,11 @@ class SignalMixin:
filtering, defaults to None filtering, defaults to None
:type condition: Dict[str, Any], optional :type condition: Dict[str, Any], optional
""" """
event_value = str(event.value) if isinstance(event, Enum) else event
def decorator(handler: SignalHandler): def decorator(handler: SignalHandler):
nonlocal event
nonlocal apply
future_signal = FutureSignal( future_signal = FutureSignal(
handler, event, HashableDict(condition or {}) handler, event_value, HashableDict(condition or {})
) )
self._future_signals.add(future_signal) self._future_signals.add(future_signal)

View File

@ -60,3 +60,7 @@ class FutureSignal(NamedTuple):
handler: SignalHandler handler: SignalHandler
event: str event: str
condition: Optional[Dict[str, str]] condition: Optional[Dict[str, str]]
class FutureRegistry(set):
...

View File

@ -47,16 +47,18 @@ def _get_args_for_reloading():
return [sys.executable] + sys.argv return [sys.executable] + sys.argv
def restart_with_reloader(): def restart_with_reloader(changed=None):
"""Create a new process and a subprocess in it with the same arguments as """Create a new process and a subprocess in it with the same arguments as
this one. this one.
""" """
reloaded = ",".join(changed) if changed else ""
return subprocess.Popen( return subprocess.Popen(
_get_args_for_reloading(), _get_args_for_reloading(),
env={ env={
**os.environ, **os.environ,
"SANIC_SERVER_RUNNING": "true", "SANIC_SERVER_RUNNING": "true",
"SANIC_RELOADER_PROCESS": "true", "SANIC_RELOADER_PROCESS": "true",
"SANIC_RELOADED_FILES": reloaded,
}, },
) )
@ -94,24 +96,27 @@ def watchdog(sleep_interval, app):
try: try:
while True: while True:
need_reload = False
changed = set()
for filename in itertools.chain( for filename in itertools.chain(
_iter_module_files(), _iter_module_files(),
*(d.glob("**/*") for d in app.reload_dirs), *(d.glob("**/*") for d in app.reload_dirs),
): ):
try: try:
check = _check_file(filename, mtimes) if _check_file(filename, mtimes):
path = (
filename
if isinstance(filename, str)
else filename.resolve()
)
changed.add(str(path))
except OSError: except OSError:
continue continue
if check: if changed:
need_reload = True
if need_reload:
worker_process.terminate() worker_process.terminate()
worker_process.wait() worker_process.wait()
worker_process = restart_with_reloader() worker_process = restart_with_reloader(changed)
sleep(sleep_interval) sleep(sleep_interval)
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@ -139,11 +139,10 @@ class Router(BaseRouter):
route.ctx.stream = stream route.ctx.stream = stream
route.ctx.hosts = hosts route.ctx.hosts = hosts
route.ctx.static = static route.ctx.static = static
route.ctx.error_format = ( route.ctx.error_format = error_format
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
check_error_format(route.ctx.error_format) if error_format:
check_error_format(route.ctx.error_format)
routes.append(route) routes.append(route)

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
from enum import Enum
from inspect import isawaitable from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
@ -14,29 +15,47 @@ from sanic.log import error_logger, logger
from sanic.models.handler_types import SignalHandler from sanic.models.handler_types import SignalHandler
class Event(Enum):
SERVER_INIT_AFTER = "server.init.after"
SERVER_INIT_BEFORE = "server.init.before"
SERVER_SHUTDOWN_AFTER = "server.shutdown.after"
SERVER_SHUTDOWN_BEFORE = "server.shutdown.before"
HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin"
HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete"
HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception"
HTTP_LIFECYCLE_HANDLE = "http.lifecycle.handle"
HTTP_LIFECYCLE_READ_BODY = "http.lifecycle.read_body"
HTTP_LIFECYCLE_READ_HEAD = "http.lifecycle.read_head"
HTTP_LIFECYCLE_REQUEST = "http.lifecycle.request"
HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response"
HTTP_ROUTING_AFTER = "http.routing.after"
HTTP_ROUTING_BEFORE = "http.routing.before"
HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
RESERVED_NAMESPACES = { RESERVED_NAMESPACES = {
"server": ( "server": (
# "server.main.start", Event.SERVER_INIT_AFTER.value,
# "server.main.stop", Event.SERVER_INIT_BEFORE.value,
"server.init.before", Event.SERVER_SHUTDOWN_AFTER.value,
"server.init.after", Event.SERVER_SHUTDOWN_BEFORE.value,
"server.shutdown.before",
"server.shutdown.after",
), ),
"http": ( "http": (
"http.lifecycle.begin", Event.HTTP_LIFECYCLE_BEGIN.value,
"http.lifecycle.complete", Event.HTTP_LIFECYCLE_COMPLETE.value,
"http.lifecycle.exception", Event.HTTP_LIFECYCLE_EXCEPTION.value,
"http.lifecycle.handle", Event.HTTP_LIFECYCLE_HANDLE.value,
"http.lifecycle.read_body", Event.HTTP_LIFECYCLE_READ_BODY.value,
"http.lifecycle.read_head", Event.HTTP_LIFECYCLE_READ_HEAD.value,
"http.lifecycle.request", Event.HTTP_LIFECYCLE_REQUEST.value,
"http.lifecycle.response", Event.HTTP_LIFECYCLE_RESPONSE.value,
"http.routing.after", Event.HTTP_ROUTING_AFTER.value,
"http.routing.before", Event.HTTP_ROUTING_BEFORE.value,
"http.lifecycle.send", Event.HTTP_LIFECYCLE_SEND.value,
"http.middleware.after", Event.HTTP_MIDDLEWARE_AFTER.value,
"http.middleware.before", Event.HTTP_MIDDLEWARE_BEFORE.value,
), ),
} }

View File

@ -1,6 +1,4 @@
from copy import deepcopy from sanic import Blueprint, Sanic
from sanic import Blueprint, Sanic, blueprints, response
from sanic.response import text from sanic.response import text

View File

@ -1088,3 +1088,31 @@ def test_bp_set_attribute_warning():
"and will be removed in version 21.12. You should change your " "and will be removed in version 21.12. You should change your "
"Blueprint instance to use instance.ctx.foo instead." "Blueprint instance to use instance.ctx.foo instead."
) )
def test_early_registration(app):
assert len(app.router.routes) == 0
bp = Blueprint("bp")
@bp.get("/one")
async def one(_):
return text("one")
app.blueprint(bp)
assert len(app.router.routes) == 1
@bp.get("/two")
async def two(_):
return text("two")
@bp.get("/three")
async def three(_):
return text("three")
assert len(app.router.routes) == 3
for path in ("one", "two", "three"):
_, response = app.test_client.get(f"/{path}")
assert response.text == path

48
tests/test_coffee.py Normal file
View File

@ -0,0 +1,48 @@
import logging
from unittest.mock import patch
import pytest
from sanic.application.logo import COFFEE_LOGO, get_logo
from sanic.exceptions import SanicException
def has_sugar(value):
if value:
raise SanicException("I said no sugar please")
return False
@pytest.mark.parametrize("sugar", (True, False))
def test_no_sugar(sugar):
if sugar:
with pytest.raises(SanicException):
assert has_sugar(sugar)
else:
assert not has_sugar(sugar)
def test_get_logo_returns_expected_logo():
with patch("sys.stdout.isatty") as isatty:
isatty.return_value = True
logo = get_logo(coffee=True)
assert logo is COFFEE_LOGO
def test_logo_true(app, caplog):
@app.after_server_start
async def shutdown(*_):
app.stop()
with patch("sys.stdout.isatty") as isatty:
isatty.return_value = True
with caplog.at_level(logging.DEBUG):
app.make_coffee()
# Only in the regular logo
assert " ▄███ █████ ██ " not in caplog.text
# Only in the coffee logo
assert " ██ ██▀▀▄ " in caplog.text

View File

@ -1,9 +1,9 @@
from contextlib import contextmanager from contextlib import contextmanager
from email import message
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from textwrap import dedent from textwrap import dedent
from unittest.mock import Mock
import pytest import pytest
@ -360,3 +360,31 @@ def test_deprecation_notice_when_setting_logo(app):
) )
with pytest.warns(DeprecationWarning, match=message): with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo" app.config.LOGO = "My Custom Logo"
def test_config_set_methods(app, monkeypatch):
post_set = Mock()
monkeypatch.setattr(Config, "_post_set", post_set)
app.config.FOO = 1
post_set.assert_called_once_with("FOO", 1)
post_set.reset_mock()
app.config["FOO"] = 2
post_set.assert_called_once_with("FOO", 2)
post_set.reset_mock()
app.config.update({"FOO": 3})
post_set.assert_called_once_with("FOO", 3)
post_set.reset_mock()
app.config.update([("FOO", 4)])
post_set.assert_called_once_with("FOO", 4)
post_set.reset_mock()
app.config.update(FOO=5)
post_set.assert_called_once_with("FOO", 5)
post_set.reset_mock()
app.config.update_config({"FOO": 6})
post_set.assert_called_once_with("FOO", 6)

View File

@ -1,8 +1,10 @@
import pytest import pytest
from sanic import Sanic from sanic import Sanic
from sanic.config import Config
from sanic.errorpages import HTMLRenderer, exception_response from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException from sanic.exceptions import NotFound, SanicException
from sanic.handlers import ErrorHandler
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected):
) )
assert response.content_type == expected assert response.content_type == expected
def test_allow_fallback_error_format_set_main_process_start(app):
@app.main_process_start
async def start(app, _):
app.config.FALLBACK_ERROR_FORMAT = "text"
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_setting_fallback_to_non_default_raise_warning(app):
app.error_handler = ErrorHandler(fallback="text")
assert app.error_handler.fallback == "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to auto."
),
):
app.config.FALLBACK_ERROR_FORMAT = "auto"
assert app.error_handler.fallback == "auto"
app.config.FALLBACK_ERROR_FORMAT = "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to json."
),
):
app.config.FALLBACK_ERROR_FORMAT = "json"
assert app.error_handler.fallback == "json"
def test_allow_fallback_error_format_in_config_injection():
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app = Sanic("test", config=MyConfig())
@app.route("/error", methods=["GET", "POST"])
def err(request):
raise Exception("something went wrong")
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_allow_fallback_error_format_in_config_replacement(app):
class MyConfig(Config):
FALLBACK_ERROR_FORMAT = "text"
app.config = MyConfig()
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"

View File

@ -18,6 +18,16 @@ from sanic.exceptions import (
from sanic.response import text from sanic.response import text
def dl_to_dict(soup, css_class):
keys, values = [], []
for dl in soup.find_all("dl", {"class": css_class}):
for dt in dl.find_all("dt"):
keys.append(dt.text.strip())
for dd in dl.find_all("dd"):
values.append(dd.text.strip())
return dict(zip(keys, values))
class SanicExceptionTestException(Exception): class SanicExceptionTestException(Exception):
pass pass
@ -264,3 +274,110 @@ def test_exception_in_ws_logged(caplog):
error_logs = [r for r in caplog.record_tuples if r[0] == "sanic.error"] error_logs = [r for r in caplog.record_tuples if r[0] == "sanic.error"]
assert error_logs[1][1] == logging.ERROR assert error_logs[1][1] == logging.ERROR
assert "Exception occurred while handling uri:" in error_logs[1][2] assert "Exception occurred while handling uri:" in error_logs[1][2]
@pytest.mark.parametrize("debug", (True, False))
def test_contextual_exception_context(debug):
app = Sanic(__name__)
class TeapotError(SanicException):
status_code = 418
message = "Sorry, I cannot brew coffee"
def fail():
raise TeapotError(context={"foo": "bar"})
app.post("/coffee/json", error_format="json")(lambda _: fail())
app.post("/coffee/html", error_format="html")(lambda _: fail())
app.post("/coffee/text", error_format="text")(lambda _: fail())
_, response = app.test_client.post("/coffee/json", debug=debug)
assert response.status == 418
assert response.json["message"] == "Sorry, I cannot brew coffee"
assert response.json["context"] == {"foo": "bar"}
_, response = app.test_client.post("/coffee/html", debug=debug)
soup = BeautifulSoup(response.body, "html.parser")
dl = dl_to_dict(soup, "context")
assert response.status == 418
assert "Sorry, I cannot brew coffee" in soup.find("p").text
assert dl == {"foo": "bar"}
_, response = app.test_client.post("/coffee/text", debug=debug)
lines = list(map(lambda x: x.decode(), response.body.split(b"\n")))
idx = lines.index("Context") + 1
assert response.status == 418
assert lines[2] == "Sorry, I cannot brew coffee"
assert lines[idx] == ' foo: "bar"'
@pytest.mark.parametrize("debug", (True, False))
def test_contextual_exception_extra(debug):
app = Sanic(__name__)
class TeapotError(SanicException):
status_code = 418
@property
def message(self):
return f"Found {self.extra['foo']}"
def fail():
raise TeapotError(extra={"foo": "bar"})
app.post("/coffee/json", error_format="json")(lambda _: fail())
app.post("/coffee/html", error_format="html")(lambda _: fail())
app.post("/coffee/text", error_format="text")(lambda _: fail())
_, response = app.test_client.post("/coffee/json", debug=debug)
assert response.status == 418
assert response.json["message"] == "Found bar"
if debug:
assert response.json["extra"] == {"foo": "bar"}
else:
assert "extra" not in response.json
_, response = app.test_client.post("/coffee/html", debug=debug)
soup = BeautifulSoup(response.body, "html.parser")
dl = dl_to_dict(soup, "extra")
assert response.status == 418
assert "Found bar" in soup.find("p").text
if debug:
assert dl == {"foo": "bar"}
else:
assert not dl
_, response = app.test_client.post("/coffee/text", debug=debug)
lines = list(map(lambda x: x.decode(), response.body.split(b"\n")))
assert response.status == 418
assert lines[2] == "Found bar"
if debug:
idx = lines.index("Extra") + 1
assert lines[idx] == ' foo: "bar"'
else:
assert "Extra" not in lines
@pytest.mark.parametrize("override", (True, False))
def test_contextual_exception_functional_message(override):
app = Sanic(__name__)
class TeapotError(SanicException):
status_code = 418
@property
def message(self):
return f"Received foo={self.context['foo']}"
@app.post("/coffee", error_format="json")
async def make_coffee(_):
error_args = {"context": {"foo": "bar"}}
if override:
error_args["message"] = "override"
raise TeapotError(**error_args)
_, response = app.test_client.post("/coffee", debug=True)
error_message = "override" if override else "Received foo=bar"
assert response.status == 418
assert response.json["message"] == error_message
assert response.json["context"] == {"foo": "bar"}

View File

@ -1,109 +0,0 @@
import asyncio
import httpcore
import httpx
import pytest
from sanic_testing.testing import SanicTestClient
from sanic import Sanic
from sanic.response import text
class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection):
async def arequest(self, *args, **kwargs):
await asyncio.sleep(2)
return await super().arequest(*args, **kwargs)
async def _open_socket(self, *args, **kwargs):
retval = await super()._open_socket(*args, **kwargs)
if self._request_delay:
await asyncio.sleep(self._request_delay)
return retval
class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool):
def __init__(self, request_delay=None, *args, **kwargs):
self._request_delay = request_delay
super().__init__(*args, **kwargs)
async def _add_to_pool(self, connection, timeout):
connection.__class__ = DelayableHTTPConnection
connection._request_delay = self._request_delay
await super()._add_to_pool(connection, timeout)
class DelayableSanicSession(httpx.AsyncClient):
def __init__(self, request_delay=None, *args, **kwargs) -> None:
transport = DelayableSanicConnectionPool(request_delay=request_delay)
super().__init__(transport=transport, *args, **kwargs)
class DelayableSanicTestClient(SanicTestClient):
def __init__(self, app, request_delay=None):
super().__init__(app)
self._request_delay = request_delay
self._loop = None
def get_new_session(self):
return DelayableSanicSession(request_delay=self._request_delay)
@pytest.fixture
def request_no_timeout_app():
app = Sanic("test_request_no_timeout")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@pytest.fixture
def request_timeout_default_app():
app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
def test_default_server_error_request_timeout(request_timeout_default_app):
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/1")
assert response.status == 408
assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
_, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = {
"Upgrade": "websocket",
"Connection": "upgrade",
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version": "13",
}
client = DelayableSanicTestClient(request_timeout_default_app, 2)
_, response = client.get("/ws1", headers=headers)
assert response.status == 408
assert "Request Timeout" in response.text

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
from enum import Enum
from inspect import isawaitable from inspect import isawaitable
import pytest import pytest
@ -50,6 +51,25 @@ def test_invalid_signal(app, signal):
... ...
@pytest.mark.asyncio
async def test_dispatch_signal_with_enum_event(app):
counter = 0
class FooEnum(Enum):
FOO_BAR_BAZ = "foo.bar.baz"
@app.signal(FooEnum.FOO_BAR_BAZ)
def sync_signal(*_):
nonlocal counter
counter += 1
app.signal_router.finalize()
await app.dispatch("foo.bar.baz")
assert counter == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dispatch_signal_triggers_multiple_handlers(app): async def test_dispatch_signal_triggers_multiple_handlers(app):
counter = 0 counter = 0

View File

@ -26,6 +26,7 @@ def protocol(app, mock_transport):
protocol = HttpProtocol(loop=loop, app=app) protocol = HttpProtocol(loop=loop, app=app)
protocol.connection_made(mock_transport) protocol.connection_made(mock_transport)
protocol._setup_connection() protocol._setup_connection()
protocol._http.init_for_request()
protocol._task = Mock(spec=asyncio.Task) protocol._task = Mock(spec=asyncio.Task)
protocol._task.cancel = Mock() protocol._task.cancel = Mock()
return protocol return protocol