diff --git a/sanic/app.py b/sanic/app.py index 0d1238e3..7c740383 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import logging.config import os @@ -11,6 +12,7 @@ from asyncio import ( AbstractEventLoop, CancelledError, Protocol, + Task, ensure_future, get_event_loop, wait_for, @@ -125,6 +127,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "_future_signals", "_future_statics", "_state", + "_task_registry", "_test_client", "_test_manager", "asgi", @@ -188,17 +191,22 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "load_env or env_prefix" ) + self.config: Config = config or Config( + load_env=load_env, + env_prefix=env_prefix, + ) + self._asgi_client: Any = None - self._test_client: Any = None - self._test_manager: Any = None self._blueprint_order: List[Blueprint] = [] self._delayed_tasks: List[str] = [] self._future_registry: FutureRegistry = FutureRegistry() self._state: ApplicationState = ApplicationState(app=self) + self._task_registry: Dict[str, Task] = {} + self._test_client: Any = None + self._test_manager: Any = None + self.asgi = False + self.auto_reload = False self.blueprints: Dict[str, Blueprint] = {} - self.config: Config = config or Config( - load_env=load_env, env_prefix=env_prefix - ) self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() self.debug = False @@ -250,32 +258,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): # Registration # -------------------------------------------------------------------- # - def add_task( - self, - task: Union[Future[Any], Coroutine[Any, Any, Any], Awaitable[Any]], - ) -> None: - """ - Schedule a task to run later, after the loop has started. - Different from asyncio.ensure_future in that it does not - also return a future, and the actual ensure_future call - is delayed until before server start. - - `See user guide re: background tasks - `__ - - :param task: future, couroutine or awaitable - """ - try: - loop = self.loop # Will raise SanicError if loop is not started - self._loop_add_task(task, self, loop) - except SanicException: - task_name = f"sanic.delayed_task.{hash(task)}" - if not self._delayed_tasks: - self.after_server_start(partial(self.dispatch_delayed_tasks)) - - self.signal(task_name)(partial(self.run_delayed_task, task=task)) - self._delayed_tasks.append(task_name) - def register_listener( self, listener: ListenerType[SanicVar], event: str ) -> ListenerType[SanicVar]: @@ -1183,6 +1165,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): This kills the Sanic """ if not self.is_stopping: + self.shutdown_tasks(timeout=0) self.is_stopping = True get_event_loop().stop() @@ -1456,7 +1439,29 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): return ".".join(parts) @classmethod - def _prep_task(cls, task, app, loop): + def _cancel_websocket_tasks(cls, app, loop): + for task in app.websocket_tasks: + task.cancel() + + @staticmethod + async def _listener( + app: Sanic, loop: AbstractEventLoop, listener: ListenerType + ): + maybe_coro = listener(app, loop) + if maybe_coro and isawaitable(maybe_coro): + await maybe_coro + + # -------------------------------------------------------------------- # + # Task management + # -------------------------------------------------------------------- # + + @classmethod + def _prep_task( + cls, + task, + app, + loop, + ): if callable(task): try: task = task(app) @@ -1466,14 +1471,22 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): return task @classmethod - def _loop_add_task(cls, task, app, loop): + def _loop_add_task( + cls, + task, + app, + loop, + *, + name: Optional[str] = None, + register: bool = True, + ) -> Task: prepped = cls._prep_task(task, app, loop) - loop.create_task(prepped) + task = loop.create_task(prepped, name=name) - @classmethod - def _cancel_websocket_tasks(cls, app, loop): - for task in app.websocket_tasks: - task.cancel() + if name and register: + app._task_registry[name] = task + + return task @staticmethod async def dispatch_delayed_tasks(app, loop): @@ -1486,13 +1499,132 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): prepped = app._prep_task(task, app, loop) await prepped - @staticmethod - async def _listener( - app: Sanic, loop: AbstractEventLoop, listener: ListenerType + def add_task( + self, + task: Union[Future[Any], Coroutine[Any, Any, Any], Awaitable[Any]], + *, + name: Optional[str] = None, + register: bool = True, + ) -> Optional[Task]: + """ + Schedule a task to run later, after the loop has started. + Different from asyncio.ensure_future in that it does not + also return a future, and the actual ensure_future call + is delayed until before server start. + + `See user guide re: background tasks + `__ + + :param task: future, couroutine or awaitable + """ + if name and sys.version_info == (3, 7): + name = None + error_logger.warning( + "Cannot set a name for a task when using Python 3.7. Your " + "task will be created without a name." + ) + try: + loop = self.loop # Will raise SanicError if loop is not started + return self._loop_add_task( + task, self, loop, name=name, register=register + ) + except SanicException: + task_name = f"sanic.delayed_task.{hash(task)}" + if not self._delayed_tasks: + self.after_server_start(partial(self.dispatch_delayed_tasks)) + + if name: + raise RuntimeError( + "Cannot name task outside of a running application" + ) + + self.signal(task_name)(partial(self.run_delayed_task, task=task)) + self._delayed_tasks.append(task_name) + return None + + def get_task( + self, name: str, *, raise_exception: bool = True + ) -> Optional[Task]: + if sys.version_info == (3, 7): + raise RuntimeError( + "This feature is only supported on using Python 3.8+." + ) + try: + return self._task_registry[name] + except KeyError: + if raise_exception: + raise SanicException( + f'Registered task named "{name}" not found.' + ) + return None + + async def cancel_task( + self, + name: str, + msg: Optional[str] = None, + *, + raise_exception: bool = True, + ) -> None: + if sys.version_info == (3, 7): + raise RuntimeError( + "This feature is only supported on using Python 3.8+." + ) + task = self.get_task(name, raise_exception=raise_exception) + if task and not task.cancelled(): + args: Tuple[str, ...] = () + if msg: + if sys.version_info >= (3, 9): + args = (msg,) + else: + raise RuntimeError( + "Cancelling a task with a message is only supported " + "on Python 3.9+." + ) + task.cancel(*args) + try: + await task + except CancelledError: + ... + + def purge_tasks(self): + if sys.version_info == (3, 7): + raise RuntimeError( + "This feature is only supported on using Python 3.8+." + ) + for task in self.tasks: + if task.done() or task.cancelled(): + name = task.get_name() + self._task_registry[name] = None + + self._task_registry = { + k: v for k, v in self._task_registry.items() if v is not None + } + + def shutdown_tasks( + self, timeout: Optional[float] = None, increment: float = 0.1 ): - maybe_coro = listener(app, loop) - if maybe_coro and isawaitable(maybe_coro): - await maybe_coro + if sys.version_info == (3, 7): + raise RuntimeError( + "This feature is only supported on using Python 3.8+." + ) + for task in self.tasks: + task.cancel() + + if timeout is None: + timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT + + while len(self._task_registry) and timeout: + self.loop.run_until_complete(asyncio.sleep(increment)) + self.purge_tasks() + timeout -= increment + + @property + def tasks(self): + if sys.version_info == (3, 7): + raise RuntimeError( + "This feature is only supported on using Python 3.8+." + ) + return iter(self._task_registry.values()) # -------------------------------------------------------------------- # # ASGI diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 94a29328..21c7cdaa 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys + from ssl import SSLContext from typing import TYPE_CHECKING, Dict, Optional, Type, Union @@ -174,6 +176,9 @@ def serve( loop.run_until_complete(asyncio.sleep(0.1)) start_shutdown = start_shutdown + 0.1 + if sys.version_info > (3, 7): + app.shutdown_tasks(graceful - start_shutdown) + # Force close non-idle connection after waiting for # graceful_shutdown_timeout for conn in connections: diff --git a/tests/test_create_task.py b/tests/test_create_task.py index 99f724b5..c98666a9 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -1,7 +1,11 @@ import asyncio +import sys from threading import Event +import pytest + +from sanic.exceptions import SanicException from sanic.response import text @@ -48,3 +52,41 @@ def test_create_task_with_app_arg(app): _, response = app.test_client.get("/") assert response.text == "test_create_task_with_app_arg" + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +def test_create_named_task(app): + async def dummy(): + ... + + @app.before_server_start + async def setup(app, _): + app.add_task(dummy, name="dummy_task") + + @app.after_server_start + async def stop(app, _): + task = app.get_task("dummy_task") + + assert app._task_registry + assert isinstance(task, asyncio.Task) + + assert task.get_name() == "dummy_task" + + app.stop() + + app.run() + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +def test_create_named_task_fails_outside_app(app): + async def dummy(): + ... + + message = "Cannot name task outside of a running application" + with pytest.raises(RuntimeError, match=message): + app.add_task(dummy, name="dummy_task") + assert not app._task_registry + + message = 'Registered task named "dummy_task" not found.' + with pytest.raises(SanicException, match=message): + app.get_task("dummy_task") diff --git a/tests/test_tasks.py b/tests/test_tasks.py new file mode 100644 index 00000000..449bfd7f --- /dev/null +++ b/tests/test_tasks.py @@ -0,0 +1,91 @@ +import asyncio +import sys + +from asyncio.tasks import Task +from unittest.mock import Mock, call + +import pytest + +from sanic.app import Sanic +from sanic.response import empty + + +pytestmark = pytest.mark.asyncio + + +async def dummy(n=0): + for _ in range(n): + await asyncio.sleep(1) + return True + + +@pytest.fixture(autouse=True) +def mark_app_running(app): + app.is_running = True + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +async def test_add_task_returns_task(app: Sanic): + task = app.add_task(dummy()) + + assert isinstance(task, Task) + assert len(app._task_registry) == 0 + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +async def test_add_task_with_name(app: Sanic): + task = app.add_task(dummy(), name="dummy") + + assert isinstance(task, Task) + assert len(app._task_registry) == 1 + assert task is app.get_task("dummy") + + for task in app.tasks: + assert task in app._task_registry.values() + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +async def test_cancel_task(app: Sanic): + task = app.add_task(dummy(3), name="dummy") + + assert task + assert not task.done() + assert not task.cancelled() + + await asyncio.sleep(0.1) + + assert not task.done() + assert not task.cancelled() + + await app.cancel_task("dummy") + + assert task.cancelled() + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +async def test_purge_tasks(app: Sanic): + app.add_task(dummy(3), name="dummy") + + await app.cancel_task("dummy") + + assert len(app._task_registry) == 1 + + app.purge_tasks() + + assert len(app._task_registry) == 0 + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") +def test_shutdown_tasks_on_app_stop(app: Sanic): + app.shutdown_tasks = Mock() + + @app.route("/") + async def handler(_): + return empty() + + app.test_client.get("/") + + app.shutdown_tasks.call_args == [ + call(timeout=0), + call(15.0), + ]