From b1b12e004eff1702d969ac711e5863b46b79e893 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Thu, 5 Aug 2021 22:55:42 +0300 Subject: [PATCH] Signals Integration (#2160) * Update some tests * Resolve #2122 route decorator returning tuple * Use rc sanic-routing version * Update unit tests to <:str> * Minimal working version with some signals implemented * Add more http signals * Update ASGI and change listeners to signals * Allow for dynamic ODE signals * Allow signals to be stacked * Begin tests * Prioritize match_info on keyword argument injection * WIP on tests * Compat with signals * Work through some test coverage * Passing tests * Post linting * Setup proper resets * coverage reporting * Fixes from vltr comments * clear delayed tasks * Fix bad test * rm pycache --- examples/run_async_advanced.py | 29 +++- sanic/app.py | 246 ++++++++++++++++++++++++------ sanic/asgi.py | 39 ++--- sanic/http.py | 33 +++- sanic/mixins/listeners.py | 8 +- sanic/mixins/signals.py | 4 +- sanic/request.py | 4 + sanic/server.py | 101 +++++++----- sanic/signals.py | 97 +++++++++--- sanic/touchup/__init__.py | 8 + sanic/touchup/meta.py | 22 +++ sanic/touchup/schemes/__init__.py | 5 + sanic/touchup/schemes/base.py | 20 +++ sanic/touchup/schemes/ode.py | 67 ++++++++ sanic/touchup/service.py | 33 ++++ sanic/worker.py | 33 ++-- setup.py | 2 +- tests/conftest.py | 56 +++++-- tests/test_cli.py | 6 +- tests/test_create_task.py | 21 +-- tests/test_exceptions_handler.py | 2 - tests/test_keep_alive_timeout.py | 230 +++++----------------------- tests/test_logo.py | 80 ++-------- tests/test_request_timeout.py | 56 ++++--- tests/test_routes.py | 47 +++--- tests/test_server_events.py | 31 +++- tests/test_signal_handlers.py | 5 + tests/test_signals.py | 4 +- tests/test_touchup.py | 21 +++ tests/test_url_for.py | 24 ++- tox.ini | 2 +- 31 files changed, 823 insertions(+), 513 deletions(-) create mode 100644 sanic/touchup/__init__.py create mode 100644 sanic/touchup/meta.py create mode 100644 sanic/touchup/schemes/__init__.py create mode 100644 sanic/touchup/schemes/base.py create mode 100644 sanic/touchup/schemes/ode.py create mode 100644 sanic/touchup/service.py create mode 100644 tests/test_touchup.py diff --git a/examples/run_async_advanced.py b/examples/run_async_advanced.py index 36027c2f..27f86f3f 100644 --- a/examples/run_async_advanced.py +++ b/examples/run_async_advanced.py @@ -1,29 +1,44 @@ -from sanic import Sanic -from sanic import response -from signal import signal, SIGINT import asyncio + +from signal import SIGINT, signal + import uvloop +from sanic import Sanic, response +from sanic.server import AsyncioServer + + app = Sanic(__name__) -@app.listener('after_server_start') + +@app.listener("after_server_start") async def after_start_test(app, loop): print("Async Server Started!") + @app.route("/") async def test(request): return response.json({"answer": "42"}) + asyncio.set_event_loop(uvloop.new_event_loop()) -serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True) +serv_coro = app.create_server( + host="0.0.0.0", port=8000, return_asyncio_server=True +) loop = asyncio.get_event_loop() serv_task = asyncio.ensure_future(serv_coro, loop=loop) signal(SIGINT, lambda s, f: loop.stop()) -server = loop.run_until_complete(serv_task) +server: AsyncioServer = loop.run_until_complete(serv_task) # type: ignore +server.startup() + +# When using app.run(), this actually triggers before the serv_coro. +# But, in this example, we are using the convenience method, even if it is +# out of order. +server.before_start() server.after_start() try: loop.run_forever() -except KeyboardInterrupt as e: +except KeyboardInterrupt: loop.stop() finally: server.before_stop() diff --git a/sanic/app.py b/sanic/app.py index ac346d73..69b564ea 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import logging import logging.config import os import re from asyncio import ( + AbstractEventLoop, CancelledError, Protocol, ensure_future, @@ -72,19 +75,27 @@ from sanic.server import AsyncioServer, HttpProtocol from sanic.server import Signal as ServerSignal from sanic.server import serve, serve_multiple, serve_single from sanic.signals import Signal, SignalRouter +from sanic.touchup import TouchUp, TouchUpMeta from sanic.websocket import ConnectionClosed, WebSocketProtocol -class Sanic(BaseSanic): +class Sanic(BaseSanic, metaclass=TouchUpMeta): """ The main application instance """ + __touchup__ = ( + "handle_request", + "handle_exception", + "_run_response_middleware", + "_run_request_middleware", + ) __fake_slots__ = ( "_asgi_app", "_app_registry", "_asgi_client", "_blueprint_order", + "_delayed_tasks", "_future_routes", "_future_statics", "_future_middleware", @@ -155,6 +166,7 @@ class Sanic(BaseSanic): self._asgi_client = None self._blueprint_order: List[Blueprint] = [] + self._delayed_tasks: List[str] = [] self._test_client = None self._test_manager = None self.asgi = False @@ -192,6 +204,7 @@ class Sanic(BaseSanic): self.__class__.register_app(self) self.router.ctx.app = self + self.signal_router.ctx.app = self if dumps: BaseHTTPResponse._dumps = dumps # type: ignore @@ -232,9 +245,12 @@ class Sanic(BaseSanic): loop = self.loop # Will raise SanicError if loop is not started self._loop_add_task(task, self, loop) except SanicException: - self.listener("before_server_start")( - partial(self._loop_add_task, task) - ) + task_name = f"sanic.delayed_task.{hash(task)}" + if not self._delayed_tasks: + self.after_server_start(partial(self.dispatch_delayed_tasks)) + + self.signal(task_name)(partial(self.run_delayed_task, task=task)) + self._delayed_tasks.append(task_name) def register_listener(self, listener: Callable, event: str) -> Any: """ @@ -246,12 +262,20 @@ class Sanic(BaseSanic): """ try: - _event = ListenerEvent(event) - except ValueError: - valid = ", ".join(ListenerEvent.__members__.values()) + _event = ListenerEvent[event.upper()] + except (ValueError, AttributeError): + valid = ", ".join( + map(lambda x: x.lower(), ListenerEvent.__members__.keys()) + ) raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") - self.listeners[_event].append(listener) + if "." in _event: + self.signal(_event.value)( + partial(self._listener, listener=listener) + ) + else: + self.listeners[_event.value].append(listener) + return listener def register_middleware(self, middleware, attach_to: str = "request"): @@ -379,11 +403,17 @@ class Sanic(BaseSanic): *, condition: Optional[Dict[str, str]] = None, context: Optional[Dict[str, Any]] = None, + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, ) -> Coroutine[Any, Any, Awaitable[Any]]: return self.signal_router.dispatch( event, context=context, condition=condition, + inline=inline, + reverse=reverse, + fail_not_found=fail_not_found, ) async def event( @@ -659,7 +689,7 @@ class Sanic(BaseSanic): async def handle_exception( self, request: Request, exception: BaseException - ): + ): # no cov """ A handler that catches specific exceptions and outputs a response. @@ -669,6 +699,12 @@ class Sanic(BaseSanic): :type exception: BaseException :raises ServerError: response 500 """ + await self.dispatch( + "http.lifecycle.exception", + inline=True, + context={"request": request, "exception": exception}, + ) + # -------------------------------------------- # # Request Middleware # -------------------------------------------- # @@ -715,7 +751,7 @@ class Sanic(BaseSanic): f"Invalid response type {response!r} (need HTTPResponse)" ) - async def handle_request(self, request: Request): + async def handle_request(self, request: Request): # no cov """Take a request from the HTTP Server and return a response object to be sent back The HTTP Server only expects a response object, so exception handling must be done here @@ -723,10 +759,22 @@ class Sanic(BaseSanic): :param request: HTTP Request object :return: Nothing """ + await self.dispatch( + "http.lifecycle.handle", + inline=True, + context={"request": request}, + ) + # Define `response` var here to remove warnings about # allocation before assignment below. response = None try: + + await self.dispatch( + "http.routing.before", + inline=True, + context={"request": request}, + ) # Fetch handler from router route, handler, kwargs = self.router.get( request.path, @@ -734,9 +782,20 @@ class Sanic(BaseSanic): request.headers.getone("host", None), ) - request._match_info = kwargs + request._match_info = {**kwargs} request.route = route + await self.dispatch( + "http.routing.after", + inline=True, + context={ + "request": request, + "route": route, + "kwargs": kwargs, + "handler": handler, + }, + ) + if ( request.stream and request.stream.request_body @@ -772,7 +831,7 @@ class Sanic(BaseSanic): ) # Run response handler - response = handler(request, **kwargs) + response = handler(request, **request.match_info) if isawaitable(response): response = await response @@ -783,6 +842,14 @@ class Sanic(BaseSanic): # Make sure that response is finished / run StreamingHTTP callback if isinstance(response, BaseHTTPResponse): + await self.dispatch( + "http.lifecycle.response", + inline=True, + context={ + "request": request, + "response": response, + }, + ) await response.send(end_stream=True) else: if not hasattr(handler, "is_websocket"): @@ -1078,11 +1145,6 @@ class Sanic(BaseSanic): run_async=return_asyncio_server, ) - # Trigger before_start events - await self.trigger_events( - server_settings.get("before_start", []), - server_settings.get("loop"), - ) main_start = server_settings.pop("main_start", None) main_stop = server_settings.pop("main_stop", None) if main_start or main_stop: @@ -1095,17 +1157,9 @@ class Sanic(BaseSanic): asyncio_server_kwargs=asyncio_server_kwargs, **server_settings ) - async def trigger_events(self, events, loop): - """Trigger events (functions or async) - :param events: one or more sync or async functions to execute - :param loop: event loop - """ - for event in events: - result = event(loop) - if isawaitable(result): - await result - - async def _run_request_middleware(self, request, request_name=None): + async def _run_request_middleware( + self, request, request_name=None + ): # no cov # The if improves speed. I don't know why named_middleware = self.named_request_middleware.get( request_name, deque() @@ -1118,25 +1172,67 @@ class Sanic(BaseSanic): request.request_middleware_started = True for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + response = middleware(request) if isawaitable(response): response = await response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + if response: return response return None async def _run_response_middleware( self, request, response, request_name=None - ): + ): # no cov named_middleware = self.named_response_middleware.get( request_name, deque() ) applicable_middleware = self.response_middleware + named_middleware if applicable_middleware: for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": response, + }, + condition={"attach_to": "response"}, + ) + _response = middleware(request, response) if isawaitable(_response): _response = await _response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": _response if _response else response, + }, + condition={"attach_to": "response"}, + ) + if _response: response = _response if isinstance(response, BaseHTTPResponse): @@ -1162,10 +1258,6 @@ class Sanic(BaseSanic): ): """Helper function used by `run` and `create_server`.""" - self.listeners["before_server_start"] = [ - self.finalize - ] + self.listeners["before_server_start"] - if isinstance(ssl, dict): # try common aliaseses cert = ssl.get("cert") or ssl.get("certificate") @@ -1202,10 +1294,6 @@ class Sanic(BaseSanic): # Register start/stop events for event_name, settings_name, reverse in ( - ("before_server_start", "before_start", False), - ("after_server_start", "after_start", False), - ("before_server_stop", "before_stop", True), - ("after_server_stop", "after_stop", True), ("main_process_start", "main_start", False), ("main_process_stop", "main_stop", True), ): @@ -1253,20 +1341,44 @@ class Sanic(BaseSanic): return ".".join(parts) @classmethod - def _loop_add_task(cls, task, app, loop): + def _prep_task(cls, task, app, loop): if callable(task): try: - loop.create_task(task(app)) + task = task(app) except TypeError: - loop.create_task(task()) - else: - loop.create_task(task) + task = task() + + return task + + @classmethod + def _loop_add_task(cls, task, app, loop): + prepped = cls._prep_task(task, app, loop) + loop.create_task(prepped) @classmethod def _cancel_websocket_tasks(cls, app, loop): for task in app.websocket_tasks: task.cancel() + @staticmethod + async def dispatch_delayed_tasks(app, loop): + for name in app._delayed_tasks: + await app.dispatch(name, context={"app": app, "loop": loop}) + app._delayed_tasks.clear() + + @staticmethod + async def run_delayed_task(app, loop, task): + prepped = app._prep_task(task, app, loop) + await prepped + + @staticmethod + async def _listener( + app: Sanic, loop: AbstractEventLoop, listener: ListenerType + ): + maybe_coro = listener(app, loop) + if maybe_coro and isawaitable(maybe_coro): + await maybe_coro + # -------------------------------------------------------------------- # # ASGI # -------------------------------------------------------------------- # @@ -1340,15 +1452,51 @@ class Sanic(BaseSanic): raise SanicException(f'Sanic app name "{name}" not found.') # -------------------------------------------------------------------- # - # Static methods + # Lifecycle # -------------------------------------------------------------------- # - @staticmethod - async def finalize(app, _): + def finalize(self): try: - app.router.finalize() - if app.signal_router.routes: - app.signal_router.finalize() # noqa + self.router.finalize() except FinalizationError as e: if not Sanic.test_mode: - raise e # noqa + raise e + + def signalize(self): + try: + self.signal_router.finalize() + except FinalizationError as e: + if not Sanic.test_mode: + raise e + + async def _startup(self): + self.signalize() + self.finalize() + TouchUp.run(self) + + async def _server_event( + self, + concern: str, + action: str, + loop: Optional[AbstractEventLoop] = None, + ) -> None: + event = f"server.{concern}.{action}" + if action not in ("before", "after") or concern not in ( + "init", + "shutdown", + ): + raise SanicException(f"Invalid server event: {event}") + logger.debug(f"Triggering server events: {event}") + reverse = concern == "shutdown" + if loop is None: + loop = self.loop + await self.dispatch( + event, + fail_not_found=False, + reverse=reverse, + inline=True, + context={ + "app": self, + "loop": loop, + }, + ) diff --git a/sanic/asgi.py b/sanic/asgi.py index 330ced5a..13d4f87c 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,6 +1,5 @@ import warnings -from inspect import isawaitable from typing import Optional from urllib.parse import quote @@ -18,14 +17,20 @@ class Lifespan: def __init__(self, asgi_app: "ASGIApp") -> None: self.asgi_app = asgi_app - if "before_server_start" in self.asgi_app.sanic_app.listeners: + if ( + "server.init.before" + in self.asgi_app.sanic_app.signal_router.name_index + ): warnings.warn( 'You have set a listener for "before_server_start" ' "in ASGI mode. " "It will be executed as early as possible, but not before " "the ASGI server is started." ) - if "after_server_stop" in self.asgi_app.sanic_app.listeners: + if ( + "server.shutdown.after" + in self.asgi_app.sanic_app.signal_router.name_index + ): warnings.warn( 'You have set a listener for "after_server_stop" ' "in ASGI mode. " @@ -42,19 +47,9 @@ class Lifespan: in sequence since the ASGI lifespan protocol only supports a single startup event. """ - self.asgi_app.sanic_app.router.finalize() - if self.asgi_app.sanic_app.signal_router.routes: - self.asgi_app.sanic_app.signal_router.finalize() - listeners = self.asgi_app.sanic_app.listeners.get( - "before_server_start", [] - ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) - - for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if response and isawaitable(response): - await response + await self.asgi_app.sanic_app._startup() + await self.asgi_app.sanic_app._server_event("init", "before") + await self.asgi_app.sanic_app._server_event("init", "after") async def shutdown(self) -> None: """ @@ -65,16 +60,8 @@ class Lifespan: in sequence since the ASGI lifespan protocol only supports a single shutdown event. """ - listeners = self.asgi_app.sanic_app.listeners.get( - "before_server_stop", [] - ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) - - for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if response and isawaitable(response): - await response + await self.asgi_app.sanic_app._server_event("shutdown", "before") + await self.asgi_app.sanic_app._server_event("shutdown", "after") async def __call__( self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend diff --git a/sanic/http.py b/sanic/http.py index 700dd432..7bdca6a2 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -21,6 +21,7 @@ from sanic.exceptions import ( from sanic.headers import format_http1_response from sanic.helpers import has_message_body from sanic.log import access_logger, error_logger, logger +from sanic.touchup import TouchUpMeta class Stage(Enum): @@ -45,7 +46,7 @@ class Stage(Enum): HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" -class Http: +class Http(metaclass=TouchUpMeta): """ Internal helper for managing the HTTP request/response cycle @@ -67,9 +68,15 @@ class Http: HEADER_CEILING = 16_384 HEADER_MAX_SIZE = 0 + __touchup__ = ( + "http1_request_header", + "http1_response_header", + "read", + ) __slots__ = [ "_send", "_receive_more", + "dispatch", "recv_buffer", "protocol", "expecting_continue", @@ -97,6 +104,7 @@ class Http: self.protocol = protocol self.keep_alive = True self.stage: Stage = Stage.IDLE + self.dispatch = self.protocol.app.dispatch self.init_for_request() def init_for_request(self): @@ -183,7 +191,7 @@ class Http: if not self.recv_buffer: await self._receive_more() - async def http1_request_header(self): + async def http1_request_header(self): # no cov """ Receive and parse request header into self.request. """ @@ -212,6 +220,12 @@ class Http: reqline, *split_headers = raw_headers.split("\r\n") method, self.url, protocol = reqline.split(" ") + await self.dispatch( + "http.lifecycle.read_head", + inline=True, + context={"head": bytes(head)}, + ) + if protocol == "HTTP/1.1": self.keep_alive = True elif protocol == "HTTP/1.0": @@ -250,6 +264,11 @@ class Http: transport=self.protocol.transport, app=self.protocol.app, ) + await self.dispatch( + "http.lifecycle.request", + inline=True, + context={"request": request}, + ) # Prepare for request body self.request_bytes_left = self.request_bytes = 0 @@ -280,7 +299,7 @@ class Http: async def http1_response_header( self, data: bytes, end_stream: bool - ) -> None: + ) -> None: # no cov res = self.response # Compatibility with simple response body @@ -469,7 +488,7 @@ class Http: if data: yield data - async def read(self) -> Optional[bytes]: + async def read(self) -> Optional[bytes]: # no cov """ Read some bytes of request body. """ @@ -543,6 +562,12 @@ class Http: self.request_bytes_left -= size + await self.dispatch( + "http.lifecycle.read_body", + inline=True, + context={"body": data}, + ) + return data # Response methods diff --git a/sanic/mixins/listeners.py b/sanic/mixins/listeners.py index c12326c4..71f11148 100644 --- a/sanic/mixins/listeners.py +++ b/sanic/mixins/listeners.py @@ -9,10 +9,10 @@ class ListenerEvent(str, Enum): def _generate_next_value_(name: str, *args) -> str: # type: ignore return name.lower() - BEFORE_SERVER_START = auto() - AFTER_SERVER_START = auto() - BEFORE_SERVER_STOP = auto() - AFTER_SERVER_STOP = auto() + BEFORE_SERVER_START = "server.init.before" + AFTER_SERVER_START = "server.init.after" + BEFORE_SERVER_STOP = "server.shutdown.before" + AFTER_SERVER_STOP = "server.shutdown.after" MAIN_PROCESS_START = auto() MAIN_PROCESS_STOP = auto() diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index e849e562..2be9fee2 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -23,7 +23,7 @@ class SignalMixin: *, apply: bool = True, condition: Dict[str, Any] = None, - ) -> Callable[[SignalHandler], FutureSignal]: + ) -> Callable[[SignalHandler], SignalHandler]: """ For creating a signal handler, used similar to a route handler: @@ -54,7 +54,7 @@ class SignalMixin: if apply: self._apply_signal(future_signal) - return future_signal + return handler return decorator diff --git a/sanic/request.py b/sanic/request.py index 177df637..e37de36a 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -497,6 +497,10 @@ class Request: """ return self._match_info + @match_info.setter + def match_info(self, value): + self._match_info = value + # Transport properties (obtained from local interface only) @property diff --git a/sanic/server.py b/sanic/server.py index 4ec83f9c..18eb79d2 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -13,7 +13,7 @@ from typing import ( Union, ) -from sanic.models.handler_types import ListenerType +from sanic.touchup.meta import TouchUpMeta if TYPE_CHECKING: @@ -37,7 +37,7 @@ from time import monotonic as current_time from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows from sanic.config import Config -from sanic.exceptions import RequestTimeout, ServiceUnavailable +from sanic.exceptions import RequestTimeout, SanicException, ServiceUnavailable from sanic.http import Http, Stage from sanic.log import error_logger, logger from sanic.models.protocol_types import TransportProtocol @@ -102,11 +102,15 @@ class ConnInfo: self.client_port = addr[1] -class HttpProtocol(asyncio.Protocol): +class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta): """ This class provides a basic HTTP implementation of the sanic framework. """ + __touchup__ = ( + "send", + "connection_task", + ) __slots__ = ( # app "app", @@ -185,7 +189,7 @@ class HttpProtocol(asyncio.Protocol): self._time = current_time() self.check_timeouts() - async def connection_task(self): + async def connection_task(self): # no cov """ Run a HTTP connection. @@ -194,6 +198,11 @@ class HttpProtocol(asyncio.Protocol): """ try: self._setup_connection() + await self.app.dispatch( + "http.lifecycle.begin", + inline=True, + context={"conn_info": self.conn_info}, + ) await self._http.http1() except CancelledError: pass @@ -212,6 +221,13 @@ class HttpProtocol(asyncio.Protocol): self.close() except BaseException: error_logger.exception("Closing failed") + finally: + await self.app.dispatch( + "http.lifecycle.complete", + inline=True, + context={"conn_info": self.conn_info}, + ) + ... async def receive_more(self): """ @@ -259,13 +275,18 @@ class HttpProtocol(asyncio.Protocol): except Exception: error_logger.exception("protocol.check_timeouts") - async def send(self, data): + async def send(self, data): # no cov """ Writes data with backpressure control. """ await self._can_write.wait() if self.transport.is_closing(): raise CancelledError + await self.app.dispatch( + "http.lifecycle.send", + inline=True, + context={"data": data}, + ) self.transport.write(data) self._time = current_time() @@ -359,52 +380,54 @@ class AsyncioServer: a user who needs to manage the server lifecycle manually. """ - __slots__ = ( - "loop", - "serve_coro", - "_after_start", - "_before_stop", - "_after_stop", - "server", - "connections", - ) + __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init") def __init__( self, + app, loop, serve_coro, connections, - after_start: Optional[Iterable[ListenerType]], - before_stop: Optional[Iterable[ListenerType]], - after_stop: Optional[Iterable[ListenerType]], ): # Note, Sanic already called "before_server_start" events # before this helper was even created. So we don't need it here. + self.app = app + self.connections = connections self.loop = loop self.serve_coro = serve_coro - self._after_start = after_start - self._before_stop = before_stop - self._after_stop = after_stop self.server = None - self.connections = connections + self.init = False + + def startup(self): + """ + Trigger "before_server_start" events + """ + self.init = True + return self.app._startup() + + def before_start(self): + """ + Trigger "before_server_start" events + """ + return self._server_event("init", "before") def after_start(self): """ Trigger "after_server_start" events """ - trigger_events(self._after_start, self.loop) + return self._server_event("init", "after") def before_stop(self): """ Trigger "before_server_stop" events """ - trigger_events(self._before_stop, self.loop) + return self._server_event("shutdown", "before") def after_stop(self): """ Trigger "after_server_stop" events """ - trigger_events(self._after_stop, self.loop) + return self._server_event("shutdown", "after") def is_serving(self) -> bool: if self.server: @@ -442,6 +465,14 @@ class AsyncioServer: "of asyncio or uvloop." ) + def _server_event(self, concern: str, action: str): + if not self.init: + raise SanicException( + "Cannot dispatch server event without " + "first running server.startup()" + ) + return self.app._server_event(concern, action, loop=self.loop) + def __await__(self): """ Starts the asyncio server, returns AsyncServerCoro @@ -456,11 +487,7 @@ class AsyncioServer: def serve( host, port, - app, - before_start: Optional[Iterable[ListenerType]] = None, - after_start: Optional[Iterable[ListenerType]] = None, - before_stop: Optional[Iterable[ListenerType]] = None, - after_stop: Optional[Iterable[ListenerType]] = None, + app: Sanic, ssl: Optional[SSLContext] = None, sock: Optional[socket.socket] = None, unix: Optional[str] = None, @@ -542,15 +569,14 @@ def serve( if run_async: return AsyncioServer( + app=app, loop=loop, serve_coro=server_coroutine, connections=connections, - after_start=after_start, - before_stop=before_stop, - after_stop=after_stop, ) - trigger_events(before_start, loop) + loop.run_until_complete(app._startup()) + loop.run_until_complete(app._server_event("init", "before")) try: http_server = loop.run_until_complete(server_coroutine) @@ -558,8 +584,6 @@ def serve( error_logger.exception("Unable to start server") return - trigger_events(after_start, loop) - # Ignore SIGINT when run_multiple if run_multiple: signal_func(SIGINT, SIG_IGN) @@ -571,6 +595,8 @@ def serve( else: for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: loop.add_signal_handler(_signal, app.stop) + + loop.run_until_complete(app._server_event("init", "after")) pid = os.getpid() try: logger.info("Starting worker [%s]", pid) @@ -579,7 +605,7 @@ def serve( logger.info("Stopping worker [%s]", pid) # Run the on_stop function if provided - trigger_events(before_stop, loop) + loop.run_until_complete(app._server_event("shutdown", "before")) # Wait for event loop to finish and all connections to drain http_server.close() @@ -611,8 +637,7 @@ def serve( _shutdown = asyncio.gather(*coros) loop.run_until_complete(_shutdown) - - trigger_events(after_stop, loop) + loop.run_until_complete(app._server_event("shutdown", "after")) remove_unix_socket(unix) diff --git a/sanic/signals.py b/sanic/signals.py index eec2a438..2c1a704c 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.utils import path_to_parts # type: ignore from sanic.exceptions import InvalidSignal +from sanic.log import error_logger, logger from sanic.models.handler_types import SignalHandler -RESERVED_NAMESPACES = ( - "server", - "http", -) +RESERVED_NAMESPACES = { + "server": ( + # "server.main.start", + # "server.main.stop", + "server.init.before", + "server.init.after", + "server.shutdown.before", + "server.shutdown.after", + ), + "http": ( + "http.lifecycle.begin", + "http.lifecycle.complete", + "http.lifecycle.exception", + "http.lifecycle.handle", + "http.lifecycle.read_body", + "http.lifecycle.read_head", + "http.lifecycle.request", + "http.lifecycle.response", + "http.routing.after", + "http.routing.before", + "http.lifecycle.send", + "http.middleware.after", + "http.middleware.before", + ), +} + + +def _blank(): + ... class Signal(Route): @@ -59,8 +85,13 @@ class SignalRouter(BaseRouter): terms.append(extra) raise NotFound(message % tuple(terms)) + # Regex routes evaluate and can extract params directly. They are set + # on param_basket["__params__"] params = param_basket["__params__"] if not params: + # If param_basket["__params__"] does not exist, we might have + # param_basket["__matches__"], which are indexed based matches + # on path segments. They should already be cast types. params = { param.name: param_basket["__matches__"][idx] for idx, param in group.params.items() @@ -73,8 +104,18 @@ class SignalRouter(BaseRouter): event: str, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, - ) -> None: - group, handlers, params = self.get(event, condition=condition) + fail_not_found: bool = True, + reverse: bool = False, + ) -> Any: + try: + group, handlers, params = self.get(event, condition=condition) + except NotFound as e: + if fail_not_found: + raise e + else: + if self.ctx.app.debug: + error_logger.warning(str(e)) + return None events = [signal.ctx.event for signal in group] for signal_event in events: @@ -82,12 +123,19 @@ class SignalRouter(BaseRouter): if context: params.update(context) + if not reverse: + handlers = handlers[::-1] try: for handler in handlers: if condition is None or condition == handler.__requirements__: maybe_coroutine = handler(**params) if isawaitable(maybe_coroutine): - await maybe_coroutine + retval = await maybe_coroutine + if retval: + return retval + elif maybe_coroutine: + return maybe_coroutine + return None finally: for signal_event in events: signal_event.clear() @@ -98,14 +146,23 @@ class SignalRouter(BaseRouter): *, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, - ) -> asyncio.Task: - task = self.ctx.loop.create_task( - self._dispatch( - event, - context=context, - condition=condition, - ) + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, + ) -> Union[asyncio.Task, Any]: + dispatch = self._dispatch( + event, + context=context, + condition=condition, + fail_not_found=fail_not_found and inline, + reverse=reverse, ) + logger.debug(f"Dispatching signal: {event}") + + if inline: + return await dispatch + + task = asyncio.get_running_loop().create_task(dispatch) await asyncio.sleep(0) return task @@ -131,7 +188,9 @@ class SignalRouter(BaseRouter): append=True, ) # type: ignore - def finalize(self, do_compile: bool = True): + def finalize(self, do_compile: bool = True, do_optimize: bool = False): + self.add(_blank, "sanic.__signal__.__init__") + try: self.ctx.loop = asyncio.get_running_loop() except RuntimeError: @@ -140,7 +199,7 @@ class SignalRouter(BaseRouter): for signal in self.routes: signal.ctx.event = asyncio.Event() - return super().finalize(do_compile=do_compile) + return super().finalize(do_compile=do_compile, do_optimize=do_optimize) def _build_event_parts(self, event: str) -> Tuple[str, str, str]: parts = path_to_parts(event, self.delimiter) @@ -151,7 +210,11 @@ class SignalRouter(BaseRouter): ): raise InvalidSignal("Invalid signal event: %s" % event) - if parts[0] in RESERVED_NAMESPACES: + if ( + parts[0] in RESERVED_NAMESPACES + and event not in RESERVED_NAMESPACES[parts[0]] + and not (parts[2].startswith("<") and parts[2].endswith(">")) + ): raise InvalidSignal( "Cannot declare reserved signal event: %s" % event ) diff --git a/sanic/touchup/__init__.py b/sanic/touchup/__init__.py new file mode 100644 index 00000000..6fe208ab --- /dev/null +++ b/sanic/touchup/__init__.py @@ -0,0 +1,8 @@ +from .meta import TouchUpMeta +from .service import TouchUp + + +__all__ = ( + "TouchUp", + "TouchUpMeta", +) diff --git a/sanic/touchup/meta.py b/sanic/touchup/meta.py new file mode 100644 index 00000000..9f60af38 --- /dev/null +++ b/sanic/touchup/meta.py @@ -0,0 +1,22 @@ +from sanic.exceptions import SanicException + +from .service import TouchUp + + +class TouchUpMeta(type): + def __new__(cls, name, bases, attrs, **kwargs): + gen_class = super().__new__(cls, name, bases, attrs, **kwargs) + + methods = attrs.get("__touchup__") + attrs["__touched__"] = False + if methods: + + for method in methods: + if method not in attrs: + raise SanicException( + "Cannot perform touchup on non-existent method: " + f"{name}.{method}" + ) + TouchUp.register(gen_class, method) + + return gen_class diff --git a/sanic/touchup/schemes/__init__.py b/sanic/touchup/schemes/__init__.py new file mode 100644 index 00000000..87057a5f --- /dev/null +++ b/sanic/touchup/schemes/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseScheme +from .ode import OptionalDispatchEvent # noqa + + +__all__ = ("BaseScheme",) diff --git a/sanic/touchup/schemes/base.py b/sanic/touchup/schemes/base.py new file mode 100644 index 00000000..d16619b2 --- /dev/null +++ b/sanic/touchup/schemes/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Set, Type + + +class BaseScheme(ABC): + ident: str + _registry: Set[Type] = set() + + def __init__(self, app) -> None: + self.app = app + + @abstractmethod + def run(self, method, module_globals) -> None: + ... + + def __init_subclass__(cls): + BaseScheme._registry.add(cls) + + def __call__(self, method, module_globals): + return self.run(method, module_globals) diff --git a/sanic/touchup/schemes/ode.py b/sanic/touchup/schemes/ode.py new file mode 100644 index 00000000..357f748c --- /dev/null +++ b/sanic/touchup/schemes/ode.py @@ -0,0 +1,67 @@ +from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse +from inspect import getsource +from textwrap import dedent +from typing import Any + +from sanic.log import logger + +from .base import BaseScheme + + +class OptionalDispatchEvent(BaseScheme): + ident = "ODE" + + def __init__(self, app) -> None: + super().__init__(app) + + self._registered_events = [ + signal.path for signal in app.signal_router.routes + ] + + def run(self, method, module_globals): + raw_source = getsource(method) + src = dedent(raw_source) + tree = parse(src) + node = RemoveDispatch(self._registered_events).visit(tree) + compiled_src = compile(node, method.__name__, "exec") + exec_locals: Dict[str, Any] = {} + exec(compiled_src, module_globals, exec_locals) # nosec + + return exec_locals[method.__name__] + + +class RemoveDispatch(NodeTransformer): + def __init__(self, registered_events) -> None: + self._registered_events = registered_events + + def visit_Expr(self, node: Expr) -> Any: + call = node.value + if isinstance(call, Await): + call = call.value + + func = getattr(call, "func", None) + args = getattr(call, "args", None) + if not func or not args: + return node + + if isinstance(func, Attribute) and func.attr == "dispatch": + event = args[0] + if hasattr(event, "s"): + event_name = getattr(event, "value", event.s) + if self._not_registered(event_name): + logger.debug(f"Disabling event: {event_name}") + return None + return node + + def _not_registered(self, event_name): + dynamic = [] + for event in self._registered_events: + if event.endswith(">"): + namespace_concern, _ = event.rsplit(".", 1) + dynamic.append(namespace_concern) + + namespace_concern, _ = event_name.rsplit(".", 1) + return ( + event_name not in self._registered_events + and namespace_concern not in dynamic + ) diff --git a/sanic/touchup/service.py b/sanic/touchup/service.py new file mode 100644 index 00000000..95792dca --- /dev/null +++ b/sanic/touchup/service.py @@ -0,0 +1,33 @@ +from inspect import getmembers, getmodule +from typing import Set, Tuple, Type + +from .schemes import BaseScheme + + +class TouchUp: + _registry: Set[Tuple[Type, str]] = set() + + @classmethod + def run(cls, app): + for target, method_name in cls._registry: + method = getattr(target, method_name) + + if app.test_mode: + placeholder = f"_{method_name}" + if hasattr(target, placeholder): + method = getattr(target, placeholder) + else: + setattr(target, placeholder, method) + + module = getmodule(target) + module_globals = dict(getmembers(module)) + + for scheme in BaseScheme._registry: + modified = scheme(app)(method, module_globals) + setattr(target, method_name, modified) + + target.__touched__ = True + + @classmethod + def register(cls, target, method_name): + cls._registry.add((target, method_name)) diff --git a/sanic/worker.py b/sanic/worker.py index 342900e6..51bee6c2 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -8,7 +8,7 @@ import traceback from gunicorn.workers import base # type: ignore from sanic.log import logger -from sanic.server import HttpProtocol, Signal, serve, trigger_events +from sanic.server import HttpProtocol, Signal, serve from sanic.websocket import WebSocketProtocol @@ -68,10 +68,10 @@ class GunicornWorker(base.Worker): ) self._server_settings["signal"] = self.signal self._server_settings.pop("sock") - trigger_events( - self._server_settings.get("before_start", []), self.loop + self._await(self.app.callable._startup()) + self._await( + self.app.callable._server_event("init", "before", loop=self.loop) ) - self._server_settings["before_start"] = () main_start = self._server_settings.pop("main_start", None) main_stop = self._server_settings.pop("main_stop", None) @@ -82,24 +82,29 @@ class GunicornWorker(base.Worker): "with GunicornWorker" ) - self._runner = asyncio.ensure_future(self._run(), loop=self.loop) try: - self.loop.run_until_complete(self._runner) + self._await(self._run()) self.app.callable.is_running = True - trigger_events( - self._server_settings.get("after_start", []), self.loop + self._await( + self.app.callable._server_event( + "init", "after", loop=self.loop + ) ) self.loop.run_until_complete(self._check_alive()) - trigger_events( - self._server_settings.get("before_stop", []), self.loop + self._await( + self.app.callable._server_event( + "shutdown", "before", loop=self.loop + ) ) self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: try: - trigger_events( - self._server_settings.get("after_stop", []), self.loop + self._await( + self.app.callable._server_event( + "shutdown", "after", loop=self.loop + ) ) except BaseException: traceback.print_exc() @@ -238,3 +243,7 @@ class GunicornWorker(base.Worker): self.exit_code = 1 self.cfg.worker_abort(self) sys.exit(1) + + def _await(self, coro): + fut = asyncio.ensure_future(coro, loop=self.loop) + self.loop.run_until_complete(fut) diff --git a/setup.py b/setup.py index 6371fe46..9341eae8 100644 --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ requirements = [ ] tests_require = [ - "sanic-testing==0.7.0b1", + "sanic-testing>=0.7.0b2", "pytest==5.2.1", "coverage==5.3", "gunicorn==20.0.4", diff --git a/tests/conftest.py b/tests/conftest.py index 65b218cf..d24066c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import asyncio +import logging import random import re import string @@ -9,10 +11,12 @@ from typing import Tuple import pytest from sanic_routing.exceptions import RouteExists +from sanic_testing.testing import PORT from sanic import Sanic from sanic.constants import HTTP_METHODS from sanic.router import Router +from sanic.touchup.service import TouchUp slugify = re.compile(r"[^a-zA-Z0-9_\-]") @@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]: collect_ignore = ["test_worker.py"] -@pytest.fixture -def caplog(caplog): - yield caplog - - async def _handler(request): """ Dummy placeholder method used for route resolver when creating a new @@ -41,33 +40,32 @@ async def _handler(request): TYPE_TO_GENERATOR_MAP = { - "string": lambda: "".join( + "str": lambda: "".join( [random.choice(string.ascii_lowercase) for _ in range(4)] ), "int": lambda: random.choice(range(1000000)), - "number": lambda: random.random(), + "float": lambda: random.random(), "alpha": lambda: "".join( [random.choice(string.ascii_lowercase) for _ in range(4)] ), "uuid": lambda: str(uuid.uuid1()), } +CACHE = {} + class RouteStringGenerator: ROUTE_COUNT_PER_DEPTH = 100 HTTP_METHODS = HTTP_METHODS - ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"] + ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"] def generate_random_direct_route(self, max_route_depth=4): routes = [] for depth in range(1, max_route_depth + 1): for _ in range(self.ROUTE_COUNT_PER_DEPTH): route = "/".join( - [ - TYPE_TO_GENERATOR_MAP.get("string")() - for _ in range(depth) - ] + [TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)] ) route = route.replace(".", "", -1) route_detail = (random.choice(self.HTTP_METHODS), route) @@ -83,7 +81,7 @@ class RouteStringGenerator: new_route_part = "/".join( [ "<{}:{}>".format( - TYPE_TO_GENERATOR_MAP.get("string")(), + TYPE_TO_GENERATOR_MAP.get("str")(), random.choice(self.ROUTE_PARAM_TYPES), ) for _ in range(max_route_depth - current_length) @@ -98,7 +96,7 @@ class RouteStringGenerator: def generate_url_for_template(template): url = template for pattern, param_type in re.findall( - re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"), + re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"), template, ): value = TYPE_TO_GENERATOR_MAP.get(param_type)() @@ -141,5 +139,33 @@ def url_param_generator(): @pytest.fixture(scope="function") def app(request): + if not CACHE: + for target, method_name in TouchUp._registry: + CACHE[method_name] = getattr(target, method_name) app = Sanic(slugify.sub("-", request.node.name)) - return app + yield app + for target, method_name in TouchUp._registry: + setattr(target, method_name, CACHE[method_name]) + + +@pytest.fixture(scope="function") +def run_startup(caplog): + def run(app): + nonlocal caplog + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + with caplog.at_level(logging.DEBUG): + server = app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + loop._stopping = False + + _server = loop.run_until_complete(server) + + _server.close() + loop.run_until_complete(_server.wait_closed()) + app.stop() + + return caplog.record_tuples + + return run diff --git a/tests/test_cli.py b/tests/test_cli.py index 5f69dd95..908a91a3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -89,7 +89,7 @@ def test_debug(cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO @@ -103,7 +103,7 @@ def test_auto_reload(cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert info["debug"] is False @@ -118,7 +118,7 @@ def test_access_logs(cmd, expected): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert info["access_log"] is expected diff --git a/tests/test_create_task.py b/tests/test_create_task.py index e128263b..99f724b5 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -1,6 +1,5 @@ import asyncio -from queue import Queue from threading import Event from sanic.response import text @@ -13,8 +12,6 @@ def test_create_task(app): await asyncio.sleep(0.05) e.set() - app.add_task(coro) - @app.route("/early") def not_set(request): return text(str(e.is_set())) @@ -24,24 +21,30 @@ def test_create_task(app): await asyncio.sleep(0.1) return text(str(e.is_set())) + app.add_task(coro) + request, response = app.test_client.get("/early") assert response.body == b"False" + app.signal_router.reset() + app.add_task(coro) request, response = app.test_client.get("/late") assert response.body == b"True" def test_create_task_with_app_arg(app): - q = Queue() + @app.after_server_start + async def setup_q(app, _): + app.ctx.q = asyncio.Queue() @app.route("/") - def not_set(request): - return "hello" + async def not_set(request): + return text(await request.app.ctx.q.get()) async def coro(app): - q.put(app.name) + await app.ctx.q.put(app.name) app.add_task(coro) - request, response = app.test_client.get("/") - assert q.get() == "test_create_task_with_app_arg" + _, response = app.test_client.get("/") + assert response.text == "test_create_task_with_app_arg" diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index e6fd42eb..44beb278 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -127,7 +127,6 @@ def test_html_traceback_output_in_debug_mode(): soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - assert "response = handler(request, **kwargs)" in html assert "handler_4" in html assert "foo = bar" in html @@ -151,7 +150,6 @@ def test_chained_exception_handler(): soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - assert "response = handler(request, **kwargs)" in html assert "handler_6" in html assert "foo = 1 / arg" in html assert "ValueError" in html diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index e777de2e..e30761ed 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -2,16 +2,13 @@ import asyncio import platform from asyncio import sleep as aio_sleep -from json import JSONDecodeError from os import environ -import httpcore -import httpx import pytest -from sanic_testing.testing import HOST, SanicTestClient +from sanic_testing.reusable import ReusableClient -from sanic import Sanic, server +from sanic import Sanic from sanic.compat import OS_IS_WINDOWS from sanic.response import text @@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port -class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool): - last_reused_connection = None - - async def _get_connection_from_pool(self, *args, **kwargs): - conn = await super()._get_connection_from_pool(*args, **kwargs) - self.__class__.last_reused_connection = conn - return conn - - -class ResusableSanicSession(httpx.AsyncClient): - def __init__(self, *args, **kwargs) -> None: - transport = ReusableSanicConnectionPool() - super().__init__(transport=transport, *args, **kwargs) - - -class ReuseableSanicTestClient(SanicTestClient): - def __init__(self, app, loop=None): - super().__init__(app) - if loop is None: - loop = asyncio.get_event_loop() - self._loop = loop - self._server = None - self._tcp_connector = None - self._session = None - - def get_new_session(self): - return ResusableSanicSession() - - # Copied from SanicTestClient, but with some changes to reuse the - # same loop for the same app. - def _sanic_endpoint_test( - self, - method="get", - uri="/", - gather_request=True, - debug=False, - server_kwargs=None, - *request_args, - **request_kwargs, - ): - loop = self._loop - results = [None, None] - exceptions = [] - server_kwargs = server_kwargs or {"return_asyncio_server": True} - if gather_request: - - def _collect_request(request): - if results[0] is None: - results[0] = request - - self.app.request_middleware.appendleft(_collect_request) - - if uri.startswith( - ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") - ): - url = uri - else: - uri = uri if uri.startswith("/") else f"/{uri}" - scheme = "http" - url = f"{scheme}://{HOST}:{PORT}{uri}" - - @self.app.listener("after_server_start") - async def _collect_response(loop): - try: - response = await self._local_request( - method, url, *request_args, **request_kwargs - ) - results[-1] = response - except Exception as e2: - exceptions.append(e2) - - if self._server is not None: - _server = self._server - else: - _server_co = self.app.create_server( - host=HOST, debug=debug, port=PORT, **server_kwargs - ) - - server.trigger_events( - self.app.listeners["before_server_start"], loop - ) - - try: - loop._stopping = False - _server = loop.run_until_complete(_server_co) - except Exception as e1: - raise e1 - self._server = _server - server.trigger_events(self.app.listeners["after_server_start"], loop) - self.app.listeners["after_server_start"].pop() - - if exceptions: - raise ValueError(f"Exception during request: {exceptions}") - - if gather_request: - self.app.request_middleware.pop() - try: - request, response = results - return request, response - except Exception: - raise ValueError( - f"Request and response object expected, got ({results})" - ) - else: - try: - return results[-1] - except Exception: - raise ValueError(f"Request object expected, got ({results})") - - def kill_server(self): - try: - if self._server: - self._server.close() - self._loop.run_until_complete(self._server.wait_closed()) - self._server = None - - if self._session: - self._loop.run_until_complete(self._session.aclose()) - self._session = None - - except Exception as e3: - raise e3 - - # Copied from SanicTestClient, but with some changes to reuse the - # same TCPConnection and the sane ClientSession more than once. - # Note, you cannot use the same session if you are in a _different_ - # loop, so the changes above are required too. - async def _local_request(self, method, url, *args, **kwargs): - raw_cookies = kwargs.pop("raw_cookies", None) - request_keepalive = kwargs.pop( - "request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"] - ) - if not self._session: - self._session = self.get_new_session() - try: - response = await getattr(self._session, method.lower())( - url, timeout=request_keepalive, *args, **kwargs - ) - except NameError: - raise Exception(response.status_code) - - try: - response.json = response.json() - except (JSONDecodeError, UnicodeDecodeError): - response.json = None - - response.body = await response.aread() - response.status = response.status_code - response.content_type = response.headers.get("content-type") - - if raw_cookies: - response.raw_cookies = {} - for cookie in response.cookies: - response.raw_cookies[cookie.name] = cookie - - return response - - keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") @@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse(): """If the server keep-alive timeout and client keep-alive timeout are both longer than the delay, the client _and_ server will successfully reuse the existing connection.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT) + with client: headers = {"Connection": "keep-alive"} request, response = client.get("/1", headers=headers) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 + loop.run_until_complete(aio_sleep(1)) + request, response = client.get("/1") assert response.status == 200 assert response.text == "OK" - assert ReusableSanicConnectionPool.last_reused_connection - finally: - client.kill_server() + assert request.protocol.state["requests_count"] == 2 @pytest.mark.skipif( @@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse(): def test_keep_alive_client_timeout(): """If the server keep-alive timeout is longer than the client keep-alive timeout, client will try to create a new connection here.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient( + keep_alive_app_client_timeout, loop=loop, port=PORT + ) + with client: headers = {"Connection": "keep-alive"} - _, response = client.get("/1", headers=headers, request_keepalive=1) + request, response = client.get("/1", headers=headers, timeout=1) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 loop.run_until_complete(aio_sleep(2)) - _, response = client.get("/1", request_keepalive=1) - - assert ReusableSanicConnectionPool.last_reused_connection is None - finally: - client.kill_server() + request, response = client.get("/1", timeout=1) + assert request.protocol.state["requests_count"] == 1 @pytest.mark.skipif( @@ -277,22 +117,23 @@ def test_keep_alive_server_timeout(): keep-alive timeout, the client will either a 'Connection reset' error _or_ a new connection. Depending on how the event-loop handles the broken server connection.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient( + keep_alive_app_server_timeout, loop=loop, port=PORT + ) + with client: headers = {"Connection": "keep-alive"} - _, response = client.get("/1", headers=headers, request_keepalive=60) + request, response = client.get("/1", headers=headers, timeout=60) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 loop.run_until_complete(aio_sleep(3)) - _, response = client.get("/1", request_keepalive=60) + request, response = client.get("/1", timeout=60) - assert ReusableSanicConnectionPool.last_reused_connection is None - finally: - client.kill_server() + assert request.protocol.state["requests_count"] == 1 @pytest.mark.skipif( @@ -300,10 +141,10 @@ def test_keep_alive_server_timeout(): reason="Not testable with current client", ) def test_keep_alive_connection_context(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_context, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT) + with client: headers = {"Connection": "keep-alive"} request1, _ = client.post("/ctx", headers=headers) @@ -315,5 +156,4 @@ def test_keep_alive_connection_context(): assert ( request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" ) - finally: - client.kill_server() + assert request2.protocol.state["requests_count"] == 2 diff --git a/tests/test_logo.py b/tests/test_logo.py index 3fff32db..e59975c3 100644 --- a/tests/test_logo.py +++ b/tests/test_logo.py @@ -6,85 +6,37 @@ from sanic_testing.testing import PORT from sanic.config import BASE_LOGO -def test_logo_base(app, caplog): - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False +def test_logo_base(app, run_startup): + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == BASE_LOGO + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == BASE_LOGO -def test_logo_false(app, caplog): +def test_logo_false(app, caplog, run_startup): app.config.LOGO = False - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - banner, port = caplog.record_tuples[0][2].rsplit(":", 1) - assert caplog.record_tuples[0][1] == logging.INFO + banner, port = logs[0][2].rsplit(":", 1) + assert logs[0][1] == logging.INFO assert banner == "Goin' Fast @ http://127.0.0.1" assert int(port) > 0 -def test_logo_true(app, caplog): +def test_logo_true(app, run_startup): app.config.LOGO = True - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == BASE_LOGO + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == BASE_LOGO -def test_logo_custom(app, caplog): +def test_logo_custom(app, run_startup): app.config.LOGO = "My Custom Logo" - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == "My Custom Logo" + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == "My Custom Logo" diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 89cb46df..48e23f1d 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -2,6 +2,7 @@ import asyncio import httpcore import httpx +import pytest from sanic_testing.testing import SanicTestClient @@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient): return DelayableSanicSession(request_delay=self._request_delay) -request_timeout_default_app = Sanic("test_request_timeout_default") -request_no_timeout_app = Sanic("test_request_no_timeout") -request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6 -request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6 +@pytest.fixture +def request_no_timeout_app(): + app = Sanic("test_request_no_timeout") + app.config.REQUEST_TIMEOUT = 0.6 + + @app.route("/1") + async def handler2(request): + return text("OK") + + return app -@request_timeout_default_app.route("/1") -async def handler1(request): - return text("OK") +@pytest.fixture +def request_timeout_default_app(): + app = Sanic("test_request_timeout_default") + app.config.REQUEST_TIMEOUT = 0.6 + + @app.route("/1") + async def handler1(request): + return text("OK") + + @app.websocket("/ws1") + async def ws_handler1(request, ws): + await ws.send("OK") + + return app -@request_no_timeout_app.route("/1") -async def handler2(request): - return text("OK") - - -@request_timeout_default_app.websocket("/ws1") -async def ws_handler1(request, ws): - await ws.send("OK") - - -def test_default_server_error_request_timeout(): +def test_default_server_error_request_timeout(request_timeout_default_app): client = DelayableSanicTestClient(request_timeout_default_app, 2) - request, response = client.get("/1") + _, response = client.get("/1") assert response.status == 408 assert "Request Timeout" in response.text -def test_default_server_error_request_dont_timeout(): +def test_default_server_error_request_dont_timeout(request_no_timeout_app): client = DelayableSanicTestClient(request_no_timeout_app, 0.2) - request, response = client.get("/1") + _, response = client.get("/1") assert response.status == 200 assert response.text == "OK" -def test_default_server_error_websocket_request_timeout(): +def test_default_server_error_websocket_request_timeout( + request_timeout_default_app, +): headers = { "Upgrade": "websocket", @@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout(): } client = DelayableSanicTestClient(request_timeout_default_app, 2) - request, response = client.get("/ws1", headers=headers) + _, response = client.get("/ws1", headers=headers) assert response.status == 408 assert "Request Timeout" in response.text diff --git a/tests/test_routes.py b/tests/test_routes.py index 0f4980f6..9af615b5 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app): @pytest.mark.asyncio @pytest.mark.parametrize("url", ["/ws", "ws"]) async def test_websocket_route_asgi(app, url): - ev = asyncio.Event() + @app.after_server_start + async def setup_ev(app, _): + app.ctx.ev = asyncio.Event() @app.websocket(url) async def handler(request, ws): - ev.set() + request.app.ctx.ev.set() - request, response = await app.asgi_client.websocket(url) - assert ev.is_set() + @app.get("/ev") + async def check(request): + return json({"set": request.app.ctx.ev.is_set()}) + + _, response = await app.asgi_client.websocket(url) + _, response = await app.asgi_client.get("/") + assert response.json["set"] -def test_websocket_route_with_subprotocols(app): - results = [] +@pytest.mark.parametrize( + "subprotocols,expected", + ( + (["bar"], "bar"), + (["bar", "foo"], "bar"), + (["baz"], None), + (None, None), + ), +) +def test_websocket_route_with_subprotocols(app, subprotocols, expected): + results = "unset" @app.websocket("/ws", subprotocols=["foo", "bar"]) async def handler(request, ws): - results.append(ws.subprotocol) + nonlocal results + results = ws.subprotocol assert ws.subprotocol is not None - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"]) - assert response.opened is True - assert results == ["bar"] - _, response = SanicTestClient(app).websocket( - "/ws", subprotocols=["bar", "foo"] + "/ws", subprotocols=subprotocols ) assert response.opened is True - assert results == ["bar", "bar"] - - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"]) - assert response.opened is True - assert results == ["bar", "bar", None] - - _, response = SanicTestClient(app).websocket("/ws") - assert response.opened is True - assert results == ["bar", "bar", None, None] + assert results == expected @pytest.mark.parametrize("strict_slashes", [True, False, None]) diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 2e48f408..7ce1859c 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -8,7 +8,7 @@ import pytest from sanic_testing.testing import HOST, PORT -from sanic.exceptions import InvalidUsage +from sanic.exceptions import InvalidUsage, SanicException AVAILABLE_LISTENERS = [ @@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app): async def init_db(app, loop): app.db = MySanicDb() - await app.create_server(debug=True, return_asyncio_server=True, port=PORT) + srv = await app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + await srv.startup() + await srv.before_start() assert hasattr(app, "db") assert isinstance(app.db, MySanicDb) @@ -157,14 +161,15 @@ def test_create_server_trigger_events(app): serv_coro = app.create_server(return_asyncio_server=True, sock=sock) serv_task = asyncio.ensure_future(serv_coro, loop=loop) server = loop.run_until_complete(serv_task) - server.after_start() + loop.run_until_complete(server.startup()) + loop.run_until_complete(server.after_start()) try: loop.run_forever() - except KeyboardInterrupt as e: + except KeyboardInterrupt: loop.stop() finally: # Run the on_stop function if provided - server.before_stop() + loop.run_until_complete(server.before_stop()) # Wait for server to close close_task = server.close() @@ -174,5 +179,19 @@ def test_create_server_trigger_events(app): signal.stopped = True for connection in server.connections: connection.close_if_idle() - server.after_stop() + loop.run_until_complete(server.after_stop()) assert flag1 and flag2 and flag3 + + +@pytest.mark.asyncio +async def test_missing_startup_raises_exception(app): + @app.listener("before_server_start") + async def init_db(app, loop): + ... + + srv = await app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + + with pytest.raises(SanicException): + await srv.before_start() diff --git a/tests/test_signal_handlers.py b/tests/test_signal_handlers.py index 857b5283..c51cdd21 100644 --- a/tests/test_signal_handlers.py +++ b/tests/test_signal_handlers.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock import pytest +from sanic_testing.reusable import ReusableClient from sanic_testing.testing import HOST, PORT from sanic.compat import ctrlc_workaround_for_windows @@ -28,9 +29,13 @@ def set_loop(app, loop): signal.signal = mock else: loop.add_signal_handler = mock + print(">>>>>>>>>>>>>>>1", id(loop)) + print(">>>>>>>>>>>>>>>1", loop.add_signal_handler) def after(app, loop): + print(">>>>>>>>>>>>>>>2", id(loop)) + print(">>>>>>>>>>>>>>>2", loop.add_signal_handler) calledq.put(mock.called) diff --git a/tests/test_signals.py b/tests/test_signals.py index 5d116f90..9b8a9495 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app): app.signal_router.finalize() + assert len(app.signal_router.routes) == 3 await app.dispatch("foo.bar.baz") assert counter == 2 @@ -331,7 +332,8 @@ def test_event_on_bp_not_registered(): "event,expected", ( ("foo.bar.baz", True), - ("server.init.before", False), + ("server.init.before", True), + ("server.init.somethingelse", False), ("http.request.start", False), ("sanic.notice.anything", True), ), diff --git a/tests/test_touchup.py b/tests/test_touchup.py new file mode 100644 index 00000000..3079aa1b --- /dev/null +++ b/tests/test_touchup.py @@ -0,0 +1,21 @@ +import logging + +from sanic.signals import RESERVED_NAMESPACES +from sanic.touchup import TouchUp + + +def test_touchup_methods(app): + assert len(TouchUp._registry) == 9 + + +async def test_ode_removes_dispatch_events(app, caplog): + with caplog.at_level(logging.DEBUG, logger="sanic.root"): + await app._startup() + logs = caplog.record_tuples + + for signal in RESERVED_NAMESPACES["http"]: + assert ( + "sanic.root", + logging.DEBUG, + f"Disabling event: {signal}", + ) in logs diff --git a/tests/test_url_for.py b/tests/test_url_for.py index d623cc4a..6ec6a93f 100644 --- a/tests/test_url_for.py +++ b/tests/test_url_for.py @@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app): ) -def test_websocket_bp_route_name(app): +@pytest.mark.parametrize( + "name,expected", + ( + ("test_route", "/bp/route"), + ("test_route2", "/bp/route2"), + ("foobar_3", "/bp/route3"), + ), +) +def test_websocket_bp_route_name(app, name, expected): """Tests that blueprint websocket route is named.""" event = asyncio.Event() bp = Blueprint("test_bp", url_prefix="/bp") @@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app): uri = app.url_for("test_bp.main") assert uri == "/bp/main" - uri = app.url_for("test_bp.test_route") - assert uri == "/bp/route" + uri = app.url_for(f"test_bp.{name}") + assert uri == expected request, response = SanicTestClient(app).websocket(uri) assert response.opened is True assert event.is_set() - event.clear() - uri = app.url_for("test_bp.test_route2") - assert uri == "/bp/route2" - request, response = SanicTestClient(app).websocket(uri) - assert response.opened is True - assert event.is_set() - - uri = app.url_for("test_bp.foobar_3") - assert uri == "/bp/route3" - # TODO: add test with a route with multiple hosts # TODO: add test with a route with _host in url_for diff --git a/tox.ini b/tox.ini index 3d4e35b8..5612f6de 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ extras = test commands = pytest {posargs:tests --cov sanic} - coverage combine --append - coverage report -m + coverage report -m -i coverage html -i [testenv:lint]