Signals Integration (#2160)

* Update some tests

* Resolve #2122 route decorator returning tuple

* Use rc sanic-routing version

* Update unit tests to <:str>

* Minimal working version with some signals implemented

* Add more http signals

* Update ASGI and change listeners to signals

* Allow for dynamic ODE signals

* Allow signals to be stacked

* Begin tests

* Prioritize match_info on keyword argument injection

* WIP on tests

* Compat with signals

* Work through some test coverage

* Passing tests

* Post linting

* Setup proper resets

* coverage reporting

* Fixes from vltr comments

* clear delayed tasks

* Fix bad test

* rm pycache
This commit is contained in:
Adam Hopkins
2021-08-05 22:55:42 +03:00
committed by GitHub
parent 0ba57d4701
commit b1b12e004e
31 changed files with 823 additions and 513 deletions

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
import logging
import logging.config
import os
import re
from asyncio import (
AbstractEventLoop,
CancelledError,
Protocol,
ensure_future,
@@ -72,19 +75,27 @@ from sanic.server import AsyncioServer, HttpProtocol
from sanic.server import Signal as ServerSignal
from sanic.server import serve, serve_multiple, serve_single
from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.websocket import ConnectionClosed, WebSocketProtocol
class Sanic(BaseSanic):
class Sanic(BaseSanic, metaclass=TouchUpMeta):
"""
The main application instance
"""
__touchup__ = (
"handle_request",
"handle_exception",
"_run_response_middleware",
"_run_request_middleware",
)
__fake_slots__ = (
"_asgi_app",
"_app_registry",
"_asgi_client",
"_blueprint_order",
"_delayed_tasks",
"_future_routes",
"_future_statics",
"_future_middleware",
@@ -155,6 +166,7 @@ class Sanic(BaseSanic):
self._asgi_client = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
self._test_client = None
self._test_manager = None
self.asgi = False
@@ -192,6 +204,7 @@ class Sanic(BaseSanic):
self.__class__.register_app(self)
self.router.ctx.app = self
self.signal_router.ctx.app = self
if dumps:
BaseHTTPResponse._dumps = dumps # type: ignore
@@ -232,9 +245,12 @@ class Sanic(BaseSanic):
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:
self.listener("before_server_start")(
partial(self._loop_add_task, task)
)
task_name = f"sanic.delayed_task.{hash(task)}"
if not self._delayed_tasks:
self.after_server_start(partial(self.dispatch_delayed_tasks))
self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)
def register_listener(self, listener: Callable, event: str) -> Any:
"""
@@ -246,12 +262,20 @@ class Sanic(BaseSanic):
"""
try:
_event = ListenerEvent(event)
except ValueError:
valid = ", ".join(ListenerEvent.__members__.values())
_event = ListenerEvent[event.upper()]
except (ValueError, AttributeError):
valid = ", ".join(
map(lambda x: x.lower(), ListenerEvent.__members__.keys())
)
raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}")
self.listeners[_event].append(listener)
if "." in _event:
self.signal(_event.value)(
partial(self._listener, listener=listener)
)
else:
self.listeners[_event.value].append(listener)
return listener
def register_middleware(self, middleware, attach_to: str = "request"):
@@ -379,11 +403,17 @@ class Sanic(BaseSanic):
*,
condition: Optional[Dict[str, str]] = None,
context: Optional[Dict[str, Any]] = None,
fail_not_found: bool = True,
inline: bool = False,
reverse: bool = False,
) -> Coroutine[Any, Any, Awaitable[Any]]:
return self.signal_router.dispatch(
event,
context=context,
condition=condition,
inline=inline,
reverse=reverse,
fail_not_found=fail_not_found,
)
async def event(
@@ -659,7 +689,7 @@ class Sanic(BaseSanic):
async def handle_exception(
self, request: Request, exception: BaseException
):
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
@@ -669,6 +699,12 @@ class Sanic(BaseSanic):
:type exception: BaseException
:raises ServerError: response 500
"""
await self.dispatch(
"http.lifecycle.exception",
inline=True,
context={"request": request, "exception": exception},
)
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
@@ -715,7 +751,7 @@ class Sanic(BaseSanic):
f"Invalid response type {response!r} (need HTTPResponse)"
)
async def handle_request(self, request: Request):
async def handle_request(self, request: Request): # no cov
"""Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so
exception handling must be done here
@@ -723,10 +759,22 @@ class Sanic(BaseSanic):
:param request: HTTP Request object
:return: Nothing
"""
await self.dispatch(
"http.lifecycle.handle",
inline=True,
context={"request": request},
)
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
try:
await self.dispatch(
"http.routing.before",
inline=True,
context={"request": request},
)
# Fetch handler from router
route, handler, kwargs = self.router.get(
request.path,
@@ -734,9 +782,20 @@ class Sanic(BaseSanic):
request.headers.getone("host", None),
)
request._match_info = kwargs
request._match_info = {**kwargs}
request.route = route
await self.dispatch(
"http.routing.after",
inline=True,
context={
"request": request,
"route": route,
"kwargs": kwargs,
"handler": handler,
},
)
if (
request.stream
and request.stream.request_body
@@ -772,7 +831,7 @@ class Sanic(BaseSanic):
)
# Run response handler
response = handler(request, **kwargs)
response = handler(request, **request.match_info)
if isawaitable(response):
response = await response
@@ -783,6 +842,14 @@ class Sanic(BaseSanic):
# Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
await response.send(end_stream=True)
else:
if not hasattr(handler, "is_websocket"):
@@ -1078,11 +1145,6 @@ class Sanic(BaseSanic):
run_async=return_asyncio_server,
)
# Trigger before_start events
await self.trigger_events(
server_settings.get("before_start", []),
server_settings.get("loop"),
)
main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None)
if main_start or main_stop:
@@ -1095,17 +1157,9 @@ class Sanic(BaseSanic):
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
)
async def trigger_events(self, events, loop):
"""Trigger events (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
for event in events:
result = event(loop)
if isawaitable(result):
await result
async def _run_request_middleware(self, request, request_name=None):
async def _run_request_middleware(
self, request, request_name=None
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
@@ -1118,25 +1172,67 @@ class Sanic(BaseSanic):
request.request_middleware_started = True
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request)
if isawaitable(response):
response = await response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
return None
async def _run_response_middleware(
self, request, response, request_name=None
):
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
@@ -1162,10 +1258,6 @@ class Sanic(BaseSanic):
):
"""Helper function used by `run` and `create_server`."""
self.listeners["before_server_start"] = [
self.finalize
] + self.listeners["before_server_start"]
if isinstance(ssl, dict):
# try common aliaseses
cert = ssl.get("cert") or ssl.get("certificate")
@@ -1202,10 +1294,6 @@ class Sanic(BaseSanic):
# Register start/stop events
for event_name, settings_name, reverse in (
("before_server_start", "before_start", False),
("after_server_start", "after_start", False),
("before_server_stop", "before_stop", True),
("after_server_stop", "after_stop", True),
("main_process_start", "main_start", False),
("main_process_stop", "main_stop", True),
):
@@ -1253,20 +1341,44 @@ class Sanic(BaseSanic):
return ".".join(parts)
@classmethod
def _loop_add_task(cls, task, app, loop):
def _prep_task(cls, task, app, loop):
if callable(task):
try:
loop.create_task(task(app))
task = task(app)
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
task = task()
return task
@classmethod
def _loop_add_task(cls, task, app, loop):
prepped = cls._prep_task(task, app, loop)
loop.create_task(prepped)
@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()
@staticmethod
async def dispatch_delayed_tasks(app, loop):
for name in app._delayed_tasks:
await app.dispatch(name, context={"app": app, "loop": loop})
app._delayed_tasks.clear()
@staticmethod
async def run_delayed_task(app, loop, task):
prepped = app._prep_task(task, app, loop)
await prepped
@staticmethod
async def _listener(
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
):
maybe_coro = listener(app, loop)
if maybe_coro and isawaitable(maybe_coro):
await maybe_coro
# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #
@@ -1340,15 +1452,51 @@ class Sanic(BaseSanic):
raise SanicException(f'Sanic app name "{name}" not found.')
# -------------------------------------------------------------------- #
# Static methods
# Lifecycle
# -------------------------------------------------------------------- #
@staticmethod
async def finalize(app, _):
def finalize(self):
try:
app.router.finalize()
if app.signal_router.routes:
app.signal_router.finalize() # noqa
self.router.finalize()
except FinalizationError as e:
if not Sanic.test_mode:
raise e # noqa
raise e
def signalize(self):
try:
self.signal_router.finalize()
except FinalizationError as e:
if not Sanic.test_mode:
raise e
async def _startup(self):
self.signalize()
self.finalize()
TouchUp.run(self)
async def _server_event(
self,
concern: str,
action: str,
loop: Optional[AbstractEventLoop] = None,
) -> None:
event = f"server.{concern}.{action}"
if action not in ("before", "after") or concern not in (
"init",
"shutdown",
):
raise SanicException(f"Invalid server event: {event}")
logger.debug(f"Triggering server events: {event}")
reverse = concern == "shutdown"
if loop is None:
loop = self.loop
await self.dispatch(
event,
fail_not_found=False,
reverse=reverse,
inline=True,
context={
"app": self,
"loop": loop,
},
)

View File

@@ -1,6 +1,5 @@
import warnings
from inspect import isawaitable
from typing import Optional
from urllib.parse import quote
@@ -18,14 +17,20 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
if "before_server_start" in self.asgi_app.sanic_app.listeners:
if (
"server.init.before"
in self.asgi_app.sanic_app.signal_router.name_index
):
warnings.warn(
'You have set a listener for "before_server_start" '
"in ASGI mode. "
"It will be executed as early as possible, but not before "
"the ASGI server is started."
)
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
if (
"server.shutdown.after"
in self.asgi_app.sanic_app.signal_router.name_index
):
warnings.warn(
'You have set a listener for "after_server_stop" '
"in ASGI mode. "
@@ -42,19 +47,9 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single
startup event.
"""
self.asgi_app.sanic_app.router.finalize()
if self.asgi_app.sanic_app.signal_router.routes:
self.asgi_app.sanic_app.signal_router.finalize()
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_start", []
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if response and isawaitable(response):
await response
await self.asgi_app.sanic_app._startup()
await self.asgi_app.sanic_app._server_event("init", "before")
await self.asgi_app.sanic_app._server_event("init", "after")
async def shutdown(self) -> None:
"""
@@ -65,16 +60,8 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single
shutdown event.
"""
listeners = self.asgi_app.sanic_app.listeners.get(
"before_server_stop", []
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if response and isawaitable(response):
await response
await self.asgi_app.sanic_app._server_event("shutdown", "before")
await self.asgi_app.sanic_app._server_event("shutdown", "after")
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend

View File

@@ -21,6 +21,7 @@ from sanic.exceptions import (
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
from sanic.log import access_logger, error_logger, logger
from sanic.touchup import TouchUpMeta
class Stage(Enum):
@@ -45,7 +46,7 @@ class Stage(Enum):
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http:
class Http(metaclass=TouchUpMeta):
"""
Internal helper for managing the HTTP request/response cycle
@@ -67,9 +68,15 @@ class Http:
HEADER_CEILING = 16_384
HEADER_MAX_SIZE = 0
__touchup__ = (
"http1_request_header",
"http1_response_header",
"read",
)
__slots__ = [
"_send",
"_receive_more",
"dispatch",
"recv_buffer",
"protocol",
"expecting_continue",
@@ -97,6 +104,7 @@ class Http:
self.protocol = protocol
self.keep_alive = True
self.stage: Stage = Stage.IDLE
self.dispatch = self.protocol.app.dispatch
self.init_for_request()
def init_for_request(self):
@@ -183,7 +191,7 @@ class Http:
if not self.recv_buffer:
await self._receive_more()
async def http1_request_header(self):
async def http1_request_header(self): # no cov
"""
Receive and parse request header into self.request.
"""
@@ -212,6 +220,12 @@ class Http:
reqline, *split_headers = raw_headers.split("\r\n")
method, self.url, protocol = reqline.split(" ")
await self.dispatch(
"http.lifecycle.read_head",
inline=True,
context={"head": bytes(head)},
)
if protocol == "HTTP/1.1":
self.keep_alive = True
elif protocol == "HTTP/1.0":
@@ -250,6 +264,11 @@ class Http:
transport=self.protocol.transport,
app=self.protocol.app,
)
await self.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": request},
)
# Prepare for request body
self.request_bytes_left = self.request_bytes = 0
@@ -280,7 +299,7 @@ class Http:
async def http1_response_header(
self, data: bytes, end_stream: bool
) -> None:
) -> None: # no cov
res = self.response
# Compatibility with simple response body
@@ -469,7 +488,7 @@ class Http:
if data:
yield data
async def read(self) -> Optional[bytes]:
async def read(self) -> Optional[bytes]: # no cov
"""
Read some bytes of request body.
"""
@@ -543,6 +562,12 @@ class Http:
self.request_bytes_left -= size
await self.dispatch(
"http.lifecycle.read_body",
inline=True,
context={"body": data},
)
return data
# Response methods

View File

@@ -9,10 +9,10 @@ class ListenerEvent(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
BEFORE_SERVER_START = auto()
AFTER_SERVER_START = auto()
BEFORE_SERVER_STOP = auto()
AFTER_SERVER_STOP = auto()
BEFORE_SERVER_START = "server.init.before"
AFTER_SERVER_START = "server.init.after"
BEFORE_SERVER_STOP = "server.shutdown.before"
AFTER_SERVER_STOP = "server.shutdown.after"
MAIN_PROCESS_START = auto()
MAIN_PROCESS_STOP = auto()

View File

@@ -23,7 +23,7 @@ class SignalMixin:
*,
apply: bool = True,
condition: Dict[str, Any] = None,
) -> Callable[[SignalHandler], FutureSignal]:
) -> Callable[[SignalHandler], SignalHandler]:
"""
For creating a signal handler, used similar to a route handler:
@@ -54,7 +54,7 @@ class SignalMixin:
if apply:
self._apply_signal(future_signal)
return future_signal
return handler
return decorator

View File

@@ -497,6 +497,10 @@ class Request:
"""
return self._match_info
@match_info.setter
def match_info(self, value):
self._match_info = value
# Transport properties (obtained from local interface only)
@property

View File

@@ -13,7 +13,7 @@ from typing import (
Union,
)
from sanic.models.handler_types import ListenerType
from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING:
@@ -37,7 +37,7 @@ from time import monotonic as current_time
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.config import Config
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.exceptions import RequestTimeout, SanicException, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol
@@ -102,11 +102,15 @@ class ConnInfo:
self.client_port = addr[1]
class HttpProtocol(asyncio.Protocol):
class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
"""
This class provides a basic HTTP implementation of the sanic framework.
"""
__touchup__ = (
"send",
"connection_task",
)
__slots__ = (
# app
"app",
@@ -185,7 +189,7 @@ class HttpProtocol(asyncio.Protocol):
self._time = current_time()
self.check_timeouts()
async def connection_task(self):
async def connection_task(self): # no cov
"""
Run a HTTP connection.
@@ -194,6 +198,11 @@ class HttpProtocol(asyncio.Protocol):
"""
try:
self._setup_connection()
await self.app.dispatch(
"http.lifecycle.begin",
inline=True,
context={"conn_info": self.conn_info},
)
await self._http.http1()
except CancelledError:
pass
@@ -212,6 +221,13 @@ class HttpProtocol(asyncio.Protocol):
self.close()
except BaseException:
error_logger.exception("Closing failed")
finally:
await self.app.dispatch(
"http.lifecycle.complete",
inline=True,
context={"conn_info": self.conn_info},
)
...
async def receive_more(self):
"""
@@ -259,13 +275,18 @@ class HttpProtocol(asyncio.Protocol):
except Exception:
error_logger.exception("protocol.check_timeouts")
async def send(self, data):
async def send(self, data): # no cov
"""
Writes data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
await self.app.dispatch(
"http.lifecycle.send",
inline=True,
context={"data": data},
)
self.transport.write(data)
self._time = current_time()
@@ -359,52 +380,54 @@ class AsyncioServer:
a user who needs to manage the server lifecycle manually.
"""
__slots__ = (
"loop",
"serve_coro",
"_after_start",
"_before_stop",
"_after_stop",
"server",
"connections",
)
__slots__ = ("app", "connections", "loop", "serve_coro", "server", "init")
def __init__(
self,
app,
loop,
serve_coro,
connections,
after_start: Optional[Iterable[ListenerType]],
before_stop: Optional[Iterable[ListenerType]],
after_stop: Optional[Iterable[ListenerType]],
):
# Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here.
self.app = app
self.connections = connections
self.loop = loop
self.serve_coro = serve_coro
self._after_start = after_start
self._before_stop = before_stop
self._after_stop = after_stop
self.server = None
self.connections = connections
self.init = False
def startup(self):
"""
Trigger "before_server_start" events
"""
self.init = True
return self.app._startup()
def before_start(self):
"""
Trigger "before_server_start" events
"""
return self._server_event("init", "before")
def after_start(self):
"""
Trigger "after_server_start" events
"""
trigger_events(self._after_start, self.loop)
return self._server_event("init", "after")
def before_stop(self):
"""
Trigger "before_server_stop" events
"""
trigger_events(self._before_stop, self.loop)
return self._server_event("shutdown", "before")
def after_stop(self):
"""
Trigger "after_server_stop" events
"""
trigger_events(self._after_stop, self.loop)
return self._server_event("shutdown", "after")
def is_serving(self) -> bool:
if self.server:
@@ -442,6 +465,14 @@ class AsyncioServer:
"of asyncio or uvloop."
)
def _server_event(self, concern: str, action: str):
if not self.init:
raise SanicException(
"Cannot dispatch server event without "
"first running server.startup()"
)
return self.app._server_event(concern, action, loop=self.loop)
def __await__(self):
"""
Starts the asyncio server, returns AsyncServerCoro
@@ -456,11 +487,7 @@ class AsyncioServer:
def serve(
host,
port,
app,
before_start: Optional[Iterable[ListenerType]] = None,
after_start: Optional[Iterable[ListenerType]] = None,
before_stop: Optional[Iterable[ListenerType]] = None,
after_stop: Optional[Iterable[ListenerType]] = None,
app: Sanic,
ssl: Optional[SSLContext] = None,
sock: Optional[socket.socket] = None,
unix: Optional[str] = None,
@@ -542,15 +569,14 @@ def serve(
if run_async:
return AsyncioServer(
app=app,
loop=loop,
serve_coro=server_coroutine,
connections=connections,
after_start=after_start,
before_stop=before_stop,
after_stop=after_stop,
)
trigger_events(before_start, loop)
loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before"))
try:
http_server = loop.run_until_complete(server_coroutine)
@@ -558,8 +584,6 @@ def serve(
error_logger.exception("Unable to start server")
return
trigger_events(after_start, loop)
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
@@ -571,6 +595,8 @@ def serve(
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
@@ -579,7 +605,7 @@ def serve(
logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided
trigger_events(before_stop, loop)
loop.run_until_complete(app._server_event("shutdown", "before"))
# Wait for event loop to finish and all connections to drain
http_server.close()
@@ -611,8 +637,7 @@ def serve(
_shutdown = asyncio.gather(*coros)
loop.run_until_complete(_shutdown)
trigger_events(after_stop, loop)
loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix)

View File

@@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound # type: ignore
from sanic_routing.utils import path_to_parts # type: ignore
from sanic.exceptions import InvalidSignal
from sanic.log import error_logger, logger
from sanic.models.handler_types import SignalHandler
RESERVED_NAMESPACES = (
"server",
"http",
)
RESERVED_NAMESPACES = {
"server": (
# "server.main.start",
# "server.main.stop",
"server.init.before",
"server.init.after",
"server.shutdown.before",
"server.shutdown.after",
),
"http": (
"http.lifecycle.begin",
"http.lifecycle.complete",
"http.lifecycle.exception",
"http.lifecycle.handle",
"http.lifecycle.read_body",
"http.lifecycle.read_head",
"http.lifecycle.request",
"http.lifecycle.response",
"http.routing.after",
"http.routing.before",
"http.lifecycle.send",
"http.middleware.after",
"http.middleware.before",
),
}
def _blank():
...
class Signal(Route):
@@ -59,8 +85,13 @@ class SignalRouter(BaseRouter):
terms.append(extra)
raise NotFound(message % tuple(terms))
# Regex routes evaluate and can extract params directly. They are set
# on param_basket["__params__"]
params = param_basket["__params__"]
if not params:
# If param_basket["__params__"] does not exist, we might have
# param_basket["__matches__"], which are indexed based matches
# on path segments. They should already be cast types.
params = {
param.name: param_basket["__matches__"][idx]
for idx, param in group.params.items()
@@ -73,8 +104,18 @@ class SignalRouter(BaseRouter):
event: str,
context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None,
) -> None:
group, handlers, params = self.get(event, condition=condition)
fail_not_found: bool = True,
reverse: bool = False,
) -> Any:
try:
group, handlers, params = self.get(event, condition=condition)
except NotFound as e:
if fail_not_found:
raise e
else:
if self.ctx.app.debug:
error_logger.warning(str(e))
return None
events = [signal.ctx.event for signal in group]
for signal_event in events:
@@ -82,12 +123,19 @@ class SignalRouter(BaseRouter):
if context:
params.update(context)
if not reverse:
handlers = handlers[::-1]
try:
for handler in handlers:
if condition is None or condition == handler.__requirements__:
maybe_coroutine = handler(**params)
if isawaitable(maybe_coroutine):
await maybe_coroutine
retval = await maybe_coroutine
if retval:
return retval
elif maybe_coroutine:
return maybe_coroutine
return None
finally:
for signal_event in events:
signal_event.clear()
@@ -98,14 +146,23 @@ class SignalRouter(BaseRouter):
*,
context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None,
) -> asyncio.Task:
task = self.ctx.loop.create_task(
self._dispatch(
event,
context=context,
condition=condition,
)
fail_not_found: bool = True,
inline: bool = False,
reverse: bool = False,
) -> Union[asyncio.Task, Any]:
dispatch = self._dispatch(
event,
context=context,
condition=condition,
fail_not_found=fail_not_found and inline,
reverse=reverse,
)
logger.debug(f"Dispatching signal: {event}")
if inline:
return await dispatch
task = asyncio.get_running_loop().create_task(dispatch)
await asyncio.sleep(0)
return task
@@ -131,7 +188,9 @@ class SignalRouter(BaseRouter):
append=True,
) # type: ignore
def finalize(self, do_compile: bool = True):
def finalize(self, do_compile: bool = True, do_optimize: bool = False):
self.add(_blank, "sanic.__signal__.__init__")
try:
self.ctx.loop = asyncio.get_running_loop()
except RuntimeError:
@@ -140,7 +199,7 @@ class SignalRouter(BaseRouter):
for signal in self.routes:
signal.ctx.event = asyncio.Event()
return super().finalize(do_compile=do_compile)
return super().finalize(do_compile=do_compile, do_optimize=do_optimize)
def _build_event_parts(self, event: str) -> Tuple[str, str, str]:
parts = path_to_parts(event, self.delimiter)
@@ -151,7 +210,11 @@ class SignalRouter(BaseRouter):
):
raise InvalidSignal("Invalid signal event: %s" % event)
if parts[0] in RESERVED_NAMESPACES:
if (
parts[0] in RESERVED_NAMESPACES
and event not in RESERVED_NAMESPACES[parts[0]]
and not (parts[2].startswith("<") and parts[2].endswith(">"))
):
raise InvalidSignal(
"Cannot declare reserved signal event: %s" % event
)

View File

@@ -0,0 +1,8 @@
from .meta import TouchUpMeta
from .service import TouchUp
__all__ = (
"TouchUp",
"TouchUpMeta",
)

22
sanic/touchup/meta.py Normal file
View File

@@ -0,0 +1,22 @@
from sanic.exceptions import SanicException
from .service import TouchUp
class TouchUpMeta(type):
def __new__(cls, name, bases, attrs, **kwargs):
gen_class = super().__new__(cls, name, bases, attrs, **kwargs)
methods = attrs.get("__touchup__")
attrs["__touched__"] = False
if methods:
for method in methods:
if method not in attrs:
raise SanicException(
"Cannot perform touchup on non-existent method: "
f"{name}.{method}"
)
TouchUp.register(gen_class, method)
return gen_class

View File

@@ -0,0 +1,5 @@
from .base import BaseScheme
from .ode import OptionalDispatchEvent # noqa
__all__ = ("BaseScheme",)

View File

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from typing import Set, Type
class BaseScheme(ABC):
ident: str
_registry: Set[Type] = set()
def __init__(self, app) -> None:
self.app = app
@abstractmethod
def run(self, method, module_globals) -> None:
...
def __init_subclass__(cls):
BaseScheme._registry.add(cls)
def __call__(self, method, module_globals):
return self.run(method, module_globals)

View File

@@ -0,0 +1,67 @@
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any
from sanic.log import logger
from .base import BaseScheme
class OptionalDispatchEvent(BaseScheme):
ident = "ODE"
def __init__(self, app) -> None:
super().__init__(app)
self._registered_events = [
signal.path for signal in app.signal_router.routes
]
def run(self, method, module_globals):
raw_source = getsource(method)
src = dedent(raw_source)
tree = parse(src)
node = RemoveDispatch(self._registered_events).visit(tree)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]
class RemoveDispatch(NodeTransformer):
def __init__(self, registered_events) -> None:
self._registered_events = registered_events
def visit_Expr(self, node: Expr) -> Any:
call = node.value
if isinstance(call, Await):
call = call.value
func = getattr(call, "func", None)
args = getattr(call, "args", None)
if not func or not args:
return node
if isinstance(func, Attribute) and func.attr == "dispatch":
event = args[0]
if hasattr(event, "s"):
event_name = getattr(event, "value", event.s)
if self._not_registered(event_name):
logger.debug(f"Disabling event: {event_name}")
return None
return node
def _not_registered(self, event_name):
dynamic = []
for event in self._registered_events:
if event.endswith(">"):
namespace_concern, _ = event.rsplit(".", 1)
dynamic.append(namespace_concern)
namespace_concern, _ = event_name.rsplit(".", 1)
return (
event_name not in self._registered_events
and namespace_concern not in dynamic
)

33
sanic/touchup/service.py Normal file
View File

@@ -0,0 +1,33 @@
from inspect import getmembers, getmodule
from typing import Set, Tuple, Type
from .schemes import BaseScheme
class TouchUp:
_registry: Set[Tuple[Type, str]] = set()
@classmethod
def run(cls, app):
for target, method_name in cls._registry:
method = getattr(target, method_name)
if app.test_mode:
placeholder = f"_{method_name}"
if hasattr(target, placeholder):
method = getattr(target, placeholder)
else:
setattr(target, placeholder, method)
module = getmodule(target)
module_globals = dict(getmembers(module))
for scheme in BaseScheme._registry:
modified = scheme(app)(method, module_globals)
setattr(target, method_name, modified)
target.__touched__ = True
@classmethod
def register(cls, target, method_name):
cls._registry.add((target, method_name))

View File

@@ -8,7 +8,7 @@ import traceback
from gunicorn.workers import base # type: ignore
from sanic.log import logger
from sanic.server import HttpProtocol, Signal, serve, trigger_events
from sanic.server import HttpProtocol, Signal, serve
from sanic.websocket import WebSocketProtocol
@@ -68,10 +68,10 @@ class GunicornWorker(base.Worker):
)
self._server_settings["signal"] = self.signal
self._server_settings.pop("sock")
trigger_events(
self._server_settings.get("before_start", []), self.loop
self._await(self.app.callable._startup())
self._await(
self.app.callable._server_event("init", "before", loop=self.loop)
)
self._server_settings["before_start"] = ()
main_start = self._server_settings.pop("main_start", None)
main_stop = self._server_settings.pop("main_stop", None)
@@ -82,24 +82,29 @@ class GunicornWorker(base.Worker):
"with GunicornWorker"
)
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
try:
self.loop.run_until_complete(self._runner)
self._await(self._run())
self.app.callable.is_running = True
trigger_events(
self._server_settings.get("after_start", []), self.loop
self._await(
self.app.callable._server_event(
"init", "after", loop=self.loop
)
)
self.loop.run_until_complete(self._check_alive())
trigger_events(
self._server_settings.get("before_stop", []), self.loop
self._await(
self.app.callable._server_event(
"shutdown", "before", loop=self.loop
)
)
self.loop.run_until_complete(self.close())
except BaseException:
traceback.print_exc()
finally:
try:
trigger_events(
self._server_settings.get("after_stop", []), self.loop
self._await(
self.app.callable._server_event(
"shutdown", "after", loop=self.loop
)
)
except BaseException:
traceback.print_exc()
@@ -238,3 +243,7 @@ class GunicornWorker(base.Worker):
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)
def _await(self, coro):
fut = asyncio.ensure_future(coro, loop=self.loop)
self.loop.run_until_complete(fut)