Add named tasks (#2304)

This commit is contained in:
Adam Hopkins 2021-12-20 23:50:04 +02:00 committed by GitHub
parent abe062b371
commit d799c5f03c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 314 additions and 44 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import logging.config import logging.config
import os import os
@ -11,6 +12,7 @@ from asyncio import (
AbstractEventLoop, AbstractEventLoop,
CancelledError, CancelledError,
Protocol, Protocol,
Task,
ensure_future, ensure_future,
get_event_loop, get_event_loop,
wait_for, wait_for,
@ -125,6 +127,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"_future_signals", "_future_signals",
"_future_statics", "_future_statics",
"_state", "_state",
"_task_registry",
"_test_client", "_test_client",
"_test_manager", "_test_manager",
"asgi", "asgi",
@ -188,17 +191,22 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"load_env or env_prefix" "load_env or env_prefix"
) )
self.config: Config = config or Config(
load_env=load_env,
env_prefix=env_prefix,
)
self._asgi_client: Any = None self._asgi_client: Any = None
self._test_client: Any = None
self._test_manager: Any = None
self._blueprint_order: List[Blueprint] = [] self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = [] self._delayed_tasks: List[str] = []
self._future_registry: FutureRegistry = FutureRegistry() self._future_registry: FutureRegistry = FutureRegistry()
self._state: ApplicationState = ApplicationState(app=self) self._state: ApplicationState = ApplicationState(app=self)
self._task_registry: Dict[str, Task] = {}
self._test_client: Any = None
self._test_manager: Any = None
self.asgi = False
self.auto_reload = False
self.blueprints: Dict[str, Blueprint] = {} self.blueprints: Dict[str, Blueprint] = {}
self.config: Config = config or Config(
load_env=load_env, env_prefix=env_prefix
)
self.configure_logging: bool = configure_logging self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace() self.ctx: Any = ctx or SimpleNamespace()
self.debug = False self.debug = False
@ -250,32 +258,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
# Registration # Registration
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
def add_task(
self,
task: Union[Future[Any], Coroutine[Any, Any, Any], Awaitable[Any]],
) -> None:
"""
Schedule a task to run later, after the loop has started.
Different from asyncio.ensure_future in that it does not
also return a future, and the actual ensure_future call
is delayed until before server start.
`See user guide re: background tasks
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`__
:param task: future, couroutine or awaitable
"""
try:
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:
task_name = f"sanic.delayed_task.{hash(task)}"
if not self._delayed_tasks:
self.after_server_start(partial(self.dispatch_delayed_tasks))
self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)
def register_listener( def register_listener(
self, listener: ListenerType[SanicVar], event: str self, listener: ListenerType[SanicVar], event: str
) -> ListenerType[SanicVar]: ) -> ListenerType[SanicVar]:
@ -1183,6 +1165,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
This kills the Sanic This kills the Sanic
""" """
if not self.is_stopping: if not self.is_stopping:
self.shutdown_tasks(timeout=0)
self.is_stopping = True self.is_stopping = True
get_event_loop().stop() get_event_loop().stop()
@ -1456,7 +1439,29 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
return ".".join(parts) return ".".join(parts)
@classmethod @classmethod
def _prep_task(cls, task, app, loop): def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()
@staticmethod
async def _listener(
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
):
maybe_coro = listener(app, loop)
if maybe_coro and isawaitable(maybe_coro):
await maybe_coro
# -------------------------------------------------------------------- #
# Task management
# -------------------------------------------------------------------- #
@classmethod
def _prep_task(
cls,
task,
app,
loop,
):
if callable(task): if callable(task):
try: try:
task = task(app) task = task(app)
@ -1466,14 +1471,22 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
return task return task
@classmethod @classmethod
def _loop_add_task(cls, task, app, loop): def _loop_add_task(
cls,
task,
app,
loop,
*,
name: Optional[str] = None,
register: bool = True,
) -> Task:
prepped = cls._prep_task(task, app, loop) prepped = cls._prep_task(task, app, loop)
loop.create_task(prepped) task = loop.create_task(prepped, name=name)
@classmethod if name and register:
def _cancel_websocket_tasks(cls, app, loop): app._task_registry[name] = task
for task in app.websocket_tasks:
task.cancel() return task
@staticmethod @staticmethod
async def dispatch_delayed_tasks(app, loop): async def dispatch_delayed_tasks(app, loop):
@ -1486,13 +1499,132 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
prepped = app._prep_task(task, app, loop) prepped = app._prep_task(task, app, loop)
await prepped await prepped
@staticmethod def add_task(
async def _listener( self,
app: Sanic, loop: AbstractEventLoop, listener: ListenerType task: Union[Future[Any], Coroutine[Any, Any, Any], Awaitable[Any]],
*,
name: Optional[str] = None,
register: bool = True,
) -> Optional[Task]:
"""
Schedule a task to run later, after the loop has started.
Different from asyncio.ensure_future in that it does not
also return a future, and the actual ensure_future call
is delayed until before server start.
`See user guide re: background tasks
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`__
:param task: future, couroutine or awaitable
"""
if name and sys.version_info == (3, 7):
name = None
error_logger.warning(
"Cannot set a name for a task when using Python 3.7. Your "
"task will be created without a name."
)
try:
loop = self.loop # Will raise SanicError if loop is not started
return self._loop_add_task(
task, self, loop, name=name, register=register
)
except SanicException:
task_name = f"sanic.delayed_task.{hash(task)}"
if not self._delayed_tasks:
self.after_server_start(partial(self.dispatch_delayed_tasks))
if name:
raise RuntimeError(
"Cannot name task outside of a running application"
)
self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)
return None
def get_task(
self, name: str, *, raise_exception: bool = True
) -> Optional[Task]:
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
try:
return self._task_registry[name]
except KeyError:
if raise_exception:
raise SanicException(
f'Registered task named "{name}" not found.'
)
return None
async def cancel_task(
self,
name: str,
msg: Optional[str] = None,
*,
raise_exception: bool = True,
) -> None:
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
task = self.get_task(name, raise_exception=raise_exception)
if task and not task.cancelled():
args: Tuple[str, ...] = ()
if msg:
if sys.version_info >= (3, 9):
args = (msg,)
else:
raise RuntimeError(
"Cancelling a task with a message is only supported "
"on Python 3.9+."
)
task.cancel(*args)
try:
await task
except CancelledError:
...
def purge_tasks(self):
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
for task in self.tasks:
if task.done() or task.cancelled():
name = task.get_name()
self._task_registry[name] = None
self._task_registry = {
k: v for k, v in self._task_registry.items() if v is not None
}
def shutdown_tasks(
self, timeout: Optional[float] = None, increment: float = 0.1
): ):
maybe_coro = listener(app, loop) if sys.version_info == (3, 7):
if maybe_coro and isawaitable(maybe_coro): raise RuntimeError(
await maybe_coro "This feature is only supported on using Python 3.8+."
)
for task in self.tasks:
task.cancel()
if timeout is None:
timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT
while len(self._task_registry) and timeout:
self.loop.run_until_complete(asyncio.sleep(increment))
self.purge_tasks()
timeout -= increment
@property
def tasks(self):
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
return iter(self._task_registry.values())
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# ASGI # ASGI

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import sys
from ssl import SSLContext from ssl import SSLContext
from typing import TYPE_CHECKING, Dict, Optional, Type, Union from typing import TYPE_CHECKING, Dict, Optional, Type, Union
@ -174,6 +176,9 @@ def serve(
loop.run_until_complete(asyncio.sleep(0.1)) loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1 start_shutdown = start_shutdown + 0.1
if sys.version_info > (3, 7):
app.shutdown_tasks(graceful - start_shutdown)
# Force close non-idle connection after waiting for # Force close non-idle connection after waiting for
# graceful_shutdown_timeout # graceful_shutdown_timeout
for conn in connections: for conn in connections:

View File

@ -1,7 +1,11 @@
import asyncio import asyncio
import sys
from threading import Event from threading import Event
import pytest
from sanic.exceptions import SanicException
from sanic.response import text from sanic.response import text
@ -48,3 +52,41 @@ def test_create_task_with_app_arg(app):
_, response = app.test_client.get("/") _, response = app.test_client.get("/")
assert response.text == "test_create_task_with_app_arg" assert response.text == "test_create_task_with_app_arg"
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
def test_create_named_task(app):
async def dummy():
...
@app.before_server_start
async def setup(app, _):
app.add_task(dummy, name="dummy_task")
@app.after_server_start
async def stop(app, _):
task = app.get_task("dummy_task")
assert app._task_registry
assert isinstance(task, asyncio.Task)
assert task.get_name() == "dummy_task"
app.stop()
app.run()
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
def test_create_named_task_fails_outside_app(app):
async def dummy():
...
message = "Cannot name task outside of a running application"
with pytest.raises(RuntimeError, match=message):
app.add_task(dummy, name="dummy_task")
assert not app._task_registry
message = 'Registered task named "dummy_task" not found.'
with pytest.raises(SanicException, match=message):
app.get_task("dummy_task")

91
tests/test_tasks.py Normal file
View File

@ -0,0 +1,91 @@
import asyncio
import sys
from asyncio.tasks import Task
from unittest.mock import Mock, call
import pytest
from sanic.app import Sanic
from sanic.response import empty
pytestmark = pytest.mark.asyncio
async def dummy(n=0):
for _ in range(n):
await asyncio.sleep(1)
return True
@pytest.fixture(autouse=True)
def mark_app_running(app):
app.is_running = True
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
async def test_add_task_returns_task(app: Sanic):
task = app.add_task(dummy())
assert isinstance(task, Task)
assert len(app._task_registry) == 0
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
async def test_add_task_with_name(app: Sanic):
task = app.add_task(dummy(), name="dummy")
assert isinstance(task, Task)
assert len(app._task_registry) == 1
assert task is app.get_task("dummy")
for task in app.tasks:
assert task in app._task_registry.values()
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
async def test_cancel_task(app: Sanic):
task = app.add_task(dummy(3), name="dummy")
assert task
assert not task.done()
assert not task.cancelled()
await asyncio.sleep(0.1)
assert not task.done()
assert not task.cancelled()
await app.cancel_task("dummy")
assert task.cancelled()
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
async def test_purge_tasks(app: Sanic):
app.add_task(dummy(3), name="dummy")
await app.cancel_task("dummy")
assert len(app._task_registry) == 1
app.purge_tasks()
assert len(app._task_registry) == 0
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
def test_shutdown_tasks_on_app_stop(app: Sanic):
app.shutdown_tasks = Mock()
@app.route("/")
async def handler(_):
return empty()
app.test_client.get("/")
app.shutdown_tasks.call_args == [
call(timeout=0),
call(15.0),
]