From 4d0a0a357063950e2ec0b8cb34b7c846136f7dc9 Mon Sep 17 00:00:00 2001 From: Josh Bartlett Date: Sat, 23 Sep 2023 22:59:44 +1000 Subject: [PATCH] Update app.event() to return context dict (#2088), and only fire for the specific requested event (#2826). Also added support for passing conditions= and exclusive= arguments to app.event(). --- sanic/app.py | 23 +++- sanic/blueprints.py | 41 ++++-- sanic/mixins/signals.py | 2 +- sanic/signals.py | 92 ++++++++++--- tests/test_signals.py | 284 ++++++++++++++++++++++++++++++++++------ 5 files changed, 367 insertions(+), 75 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 452954ca..367aa962 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -17,6 +17,7 @@ from asyncio import ( from asyncio.futures import Future from collections import defaultdict, deque from contextlib import contextmanager, suppress +from enum import Enum from functools import partial, wraps from inspect import isawaitable from os import environ @@ -663,7 +664,12 @@ class Sanic( ) async def event( - self, event: str, timeout: Optional[Union[int, float]] = None + self, + event: Union[str, Enum], + timeout: Optional[Union[int, float]] = None, + *, + condition: Optional[Dict[str, Any]] = None, + exclusive: bool = True, ) -> None: """Wait for a specific event to be triggered. @@ -686,13 +692,18 @@ class Sanic( timeout (Optional[Union[int, float]]): An optional timeout value in seconds. If provided, the wait will be terminated if the timeout is reached. Defaults to `None`, meaning no timeout. + condition: If provided, method will only return when the signal + is dispatched with the given condition. + exclusive: When true (default), the signal can only be dispatched + when the condition has been met. When ``False``, the signal can + be dispatched either with or without it. Raises: NotFound: If the event is not found and auto-registration of events is not enabled. Returns: - None + The context dict of the dispatched signal. Examples: ```python @@ -708,16 +719,16 @@ class Sanic( ``` """ - signal = self.signal_router.name_index.get(event) - if not signal: + waiter = self.signal_router.get_waiter(event, condition, exclusive) + if not waiter: if self.config.EVENT_AUTOREGISTER: self.signal_router.reset() self.add_signal(None, event) - signal = self.signal_router.name_index[event] + waiter = self.signal_router.get_waiter(event, condition, exclusive) self.signal_router.finalize() else: raise NotFound("Could not find signal %s" % event) - return await wait_for(signal.ctx.event.wait(), timeout=timeout) + return await wait_for(waiter.wait(), timeout=timeout) def report_exception( self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]] diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 28199e9e..4aabd2db 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -512,33 +512,54 @@ class Blueprint(BaseSanic): condition = kwargs.pop("condition", {}) condition.update({"__blueprint__": self.name}) kwargs["condition"] = condition - await asyncio.gather( + return await asyncio.gather( *[app.dispatch(*args, **kwargs) for app in self.apps] ) - def event(self, event: str, timeout: Optional[Union[int, float]] = None): + def event( + self, + event: str, + timeout: Optional[Union[int, float]] = None, + *, + condition: Optional[Dict[str, Any]] = None, + ): """Wait for a signal event to be dispatched. Args: event (str): Name of the signal event. timeout (Optional[Union[int, float]]): Timeout for the event to be dispatched. + condition: If provided, method will only return when the signal + is dispatched with the given condition. Returns: Awaitable: Awaitable for the event to be dispatched. """ - events = set() - for app in self.apps: - signal = app.signal_router.name_index.get(event) - if not signal: - raise NotFound("Could not find signal %s" % event) - events.add(signal.ctx.event) + if condition is None: + condition = {} + condition.update({"__blueprint__": self.name}) - return asyncio.wait( - [asyncio.create_task(event.wait()) for event in events], + waiters = [] + for app in self.apps: + waiter = app.signal_router.get_waiter(event, condition, exclusive=False) + if not waiter: + raise NotFound("Could not find signal %s" % event) + waiters.append(waiter) + + return self._event(waiters, timeout) + + async def _event(self, waiters, timeout): + done, pending = await asyncio.wait( + [asyncio.create_task(waiter.wait()) for waiter in waiters], return_when=asyncio.FIRST_COMPLETED, timeout=timeout, ) + for task in pending: + task.cancel() + if not done: + raise TimeoutError() + finished_task, = done + return finished_task.result() @staticmethod def _extract_value(*values): diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index f1419dc4..7b23d73a 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -88,7 +88,7 @@ class SignalMixin(metaclass=SanicMeta): """ if not handler: - async def noop(): + async def noop(**context): ... handler = noop diff --git a/sanic/signals.py b/sanic/signals.py index fe252c12..513c43e0 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -2,6 +2,8 @@ from __future__ import annotations import asyncio +from collections import deque +from dataclasses import dataclass from enum import Enum from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union, cast @@ -76,6 +78,38 @@ class Signal(Route): """A `Route` that is used to dispatch signals to handlers""" +@dataclass +class SignalWaiter: + """A record representing a future waiting for a signal""" + + signal: Signal + event_definition: str + trigger: str = "" + requirements: Optional[Dict[str, str]] = None + exclusive: bool = True + + future: Optional[asyncio.Future] = None + + async def wait(self): + """Block until the signal is next dispatched. + + Return the context of the signal dispatch, if any. + """ + loop = asyncio.get_running_loop() + self.future = loop.create_future() + self.signal.ctx.waiters.append(self) + try: + return await self.future + finally: + self.signal.ctx.waiters.remove(self) + + def matches(self, event, condition): + return ((condition is None and not self.exclusive) + or (condition is None and not self.requirements) + or condition == self.requirements + ) and (self.trigger or event == self.event_definition) + + class SignalGroup(RouteGroup): """A `RouteGroup` that is used to dispatch signals to handlers""" @@ -160,18 +194,20 @@ class SignalRouter(BaseRouter): error_logger.warning(str(e)) return None - events = [signal.ctx.event for signal in group] - for signal_event in events: - signal_event.set() if context: params.update(context) + params.pop("__trigger__", None) signals = group.routes if not reverse: signals = signals[::-1] try: for signal in signals: - params.pop("__trigger__", None) + for waiter in signal.ctx.waiters: + if waiter.matches(event, condition): + waiter.future.set_result(dict(params)) + + for signal in signals: requirements = signal.extra.requirements if ( (condition is None and signal.ctx.exclusive is False) @@ -197,9 +233,6 @@ class SignalRouter(BaseRouter): ) setattr(e, "__dispatched__", True) raise e - finally: - for signal_event in events: - signal_event.clear() async def dispatch( self, @@ -244,14 +277,29 @@ class SignalRouter(BaseRouter): await asyncio.sleep(0) return task - def add( # type: ignore - self, - handler: SignalHandler, - event: str, - condition: Optional[Dict[str, Any]] = None, - exclusive: bool = True, - ) -> Signal: - event_definition = event + def get_waiter( + self, + event: Union[str, Enum], + condition: Optional[Dict[str, Any]], + exclusive: bool, + ): + event_definition = str(event.value) if isinstance(event, Enum) else event + name, trigger, _ = self._get_event_parts(event_definition) + signal = cast(Signal, self.name_index.get(name)) + if not signal: + return None + + if event_definition.endswith(".*") and not trigger: + trigger = "*" + return SignalWaiter( + signal=signal, + event_definition=event_definition, + trigger=trigger, + requirements=condition, + exclusive=bool(exclusive), + ) + + def _get_event_parts(self, event): parts = self._build_event_parts(event) if parts[2].startswith("<"): name = ".".join([*parts[:-1], "*"]) @@ -263,6 +311,18 @@ class SignalRouter(BaseRouter): if not trigger: event = ".".join([*parts[:2], "<__trigger__>"]) + return name, trigger, event + + def add( # type: ignore + self, + handler: SignalHandler, + event: str, + condition: Optional[Dict[str, Any]] = None, + exclusive: bool = True, + ) -> Signal: + event_definition = event + name, trigger, event = self._get_event_parts(event) + signal = super().add( event, handler, @@ -298,7 +358,7 @@ class SignalRouter(BaseRouter): raise RuntimeError("Cannot finalize signals outside of event loop") for signal in self.routes: - signal.ctx.event = asyncio.Event() + signal.ctx.waiters = deque() return super().finalize(do_compile=do_compile, do_optimize=do_optimize) diff --git a/tests/test_signals.py b/tests/test_signals.py index 8a564a98..b451b9ce 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -78,6 +78,49 @@ def test_invalid_signal(app, signal): ... +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_event(app): + + @app.signal("foo.bar.baz") + def sync_signal(*args): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_correct_event(app): + # Check for https://github.com/sanic-org/sanic/issues/2826 + + @app.signal("foo.bar.baz") + def sync_signal(*args): + pass + + @app.signal("foo.bar.spam") + def sync_signal(*args): + pass + + app.signal_router.finalize() + + baz_task = asyncio.create_task(app.event("foo.bar.baz")) + spam_task = asyncio.create_task(app.event("foo.bar.spam")) + + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert baz_task.done() + assert not spam_task.done() + baz_task.result() + spam_task.cancel() + + @pytest.mark.asyncio async def test_dispatch_signal_with_enum_event(app): counter = 0 @@ -97,6 +140,26 @@ async def test_dispatch_signal_with_enum_event(app): assert counter == 1 +@pytest.mark.asyncio +async def test_dispatch_signal_with_enum_event_to_event(app): + + class FooEnum(Enum): + FOO_BAR_BAZ = "foo.bar.baz" + + @app.signal(FooEnum.FOO_BAR_BAZ) + def sync_signal(*args): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event(FooEnum.FOO_BAR_BAZ)) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_multiple_handlers(app): counter = 0 @@ -121,22 +184,45 @@ async def test_dispatch_signal_triggers_multiple_handlers(app): @pytest.mark.asyncio -async def test_dispatch_signal_triggers_triggers_event(app): - counter = 0 +async def test_dispatch_signal_triggers_multiple_events(app): @app.signal("foo.bar.baz") - def sync_signal(*args): - nonlocal app - nonlocal counter - group, *_ = app.signal_router.get("foo.bar.baz") - for signal in group: - counter += signal.ctx.event.is_set() + def sync_signal(*_): + pass app.signal_router.finalize() - await app.dispatch("foo.bar.baz") + event_task1 = asyncio.create_task(app.event("foo.bar.baz")) + event_task2 = asyncio.create_task(app.event("foo.bar.baz")) - assert counter == 1 + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task1.done() + assert event_task2.done() + event_task1.result() # Will raise if there was an exception + event_task2.result() # Will raise if there was an exception + + +@pytest.mark.asyncio +async def test_dispatch_signal_with_multiple_handlers_triggers_event_once(app): + + @app.signal("foo.bar.baz") + def sync_signal(*_): + pass + + @app.signal("foo.bar.baz") + async def async_signal(*_): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -155,6 +241,40 @@ async def test_dispatch_signal_triggers_dynamic_route(app): assert counter == 9 +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_parameterized_dynamic_route_event(app): + + @app.signal("foo.bar.") + def sync_signal(baz): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.")) + await app.dispatch("foo.bar.9") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_starred_dynamic_route_event(app): + + @app.signal("foo.bar.") + def sync_signal(baz): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.*")) + await app.dispatch("foo.bar.9") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_requirements(app): counter = 0 @@ -172,6 +292,26 @@ async def test_dispatch_signal_triggers_with_requirements(app): assert counter == 1 +@pytest.mark.asyncio +async def test_dispatch_signal_to_event_with_requirements(app): + + @app.signal("foo.bar.baz") + def sync_signal(*_): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"})) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + assert not event_task.done() + + await app.dispatch("foo.bar.baz", condition={"one": "two"}) + await asyncio.sleep(0) + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_requirements_exclusive(app): counter = 0 @@ -189,6 +329,28 @@ async def test_dispatch_signal_triggers_with_requirements_exclusive(app): assert counter == 2 +@pytest.mark.asyncio +async def test_dispatch_signal_to_event_with_requirements_exclusive(app): + + @app.signal("foo.bar.baz") + def sync_signal(*_): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False)) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + assert event_task.done() + event_task.result() # Will raise if there was an exception + + event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False)) + await app.dispatch("foo.bar.baz", condition={"one": "two"}) + await asyncio.sleep(0) + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_context(app): counter = 0 @@ -204,6 +366,22 @@ async def test_dispatch_signal_triggers_with_context(app): assert counter == 9 +@pytest.mark.asyncio +async def test_dispatch_signal_to_event_with_context(app): + + @app.signal("foo.bar.baz") + def sync_signal(**context): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz", context={"amount": 9}) + await asyncio.sleep(0) + assert event_task.done() + assert event_task.result()['amount'] == 9 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_context_fail(app): counter = 0 @@ -219,6 +397,22 @@ async def test_dispatch_signal_triggers_with_context_fail(app): await app.dispatch("foo.bar.baz", {"amount": 9}) +@pytest.mark.asyncio +async def test_dispatch_signal_to_dynamic_route_event(app): + + @app.signal("foo.bar.") + def sync_signal(**context): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + assert event_task.done() + assert event_task.result()['something'] == "baz" + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_on_bp(app): bp = Blueprint("bp") @@ -267,61 +461,67 @@ async def test_dispatch_signal_triggers_on_bp_alone(app): @pytest.mark.asyncio -async def test_dispatch_signal_triggers_event(app): - app_counter = 0 +async def test_dispatch_signal_triggers_event_on_bp(app): + bp = Blueprint("bp") @app.signal("foo.bar.baz") def app_signal(): ... - async def do_wait(): - nonlocal app_counter - await app.event("foo.bar.baz") - app_counter += 1 + @bp.signal("foo.bar.baz") + def bp_signal(): + ... + app.blueprint(bp) app.signal_router.finalize() + app_task = asyncio.create_task(app.event("foo.bar.baz")) + bp_task = asyncio.create_task(bp.event("foo.bar.baz")) + await asyncio.sleep(0) await app.dispatch("foo.bar.baz") - waiter = app.event("foo.bar.baz") - assert isawaitable(waiter) - fut = asyncio.ensure_future(do_wait()) - await app.dispatch("foo.bar.baz") - await fut + # Allow a few event loop iterations for tasks to finish + for _ in range(5): + await asyncio.sleep(0) - assert app_counter == 1 + assert app_task.done() + assert bp_task.done() + app_task.result() + bp_task.result() + + app_task = asyncio.create_task(app.event("foo.bar.baz")) + bp_task = asyncio.create_task(bp.event("foo.bar.baz")) + await asyncio.sleep(0) + await bp.dispatch("foo.bar.baz") + + # Allow a few event loop iterations for tasks to finish + for _ in range(5): + await asyncio.sleep(0) + + assert bp_task.done() + assert not app_task.done() + bp_task.result() + app_task.cancel() @pytest.mark.asyncio -async def test_dispatch_signal_triggers_event_on_bp(app): +async def test_dispatch_signal_triggers_event_on_bp_with_context(app): bp = Blueprint("bp") - bp_counter = 0 @bp.signal("foo.bar.baz") def bp_signal(): ... - async def do_wait(): - nonlocal bp_counter - await bp.event("foo.bar.baz") - bp_counter += 1 - app.blueprint(bp) app.signal_router.finalize() - signal_group, *_ = app.signal_router.get( - "foo.bar.baz", condition={"blueprint": "bp"} - ) - await bp.dispatch("foo.bar.baz") - waiter = bp.event("foo.bar.baz") - assert isawaitable(waiter) - - fut = do_wait() - for signal in signal_group: - signal.ctx.event.set() - await asyncio.gather(fut) - - assert bp_counter == 1 + event_task = asyncio.create_task(bp.event("foo.bar.baz")) + await asyncio.sleep(0) + await app.dispatch("foo.bar.baz", context={"amount": 9}) + for _ in range(5): + await asyncio.sleep(0) + assert event_task.done() + assert event_task.result()['amount'] == 9 def test_bad_finalize(app):