Update app.event() to return context dict (#2088), and only fire for the specific requested event (#2826).

Also added support for passing conditions= and exclusive= arguments to app.event().
This commit is contained in:
Josh Bartlett 2023-09-23 22:59:44 +10:00
parent a5a9658896
commit 4d0a0a3570
5 changed files with 367 additions and 75 deletions

View File

@ -17,6 +17,7 @@ from asyncio import (
from asyncio.futures import Future
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, wraps
from inspect import isawaitable
from os import environ
@ -663,7 +664,12 @@ class Sanic(
)
async def event(
self, event: str, timeout: Optional[Union[int, float]] = None
self,
event: Union[str, Enum],
timeout: Optional[Union[int, float]] = None,
*,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> None:
"""Wait for a specific event to be triggered.
@ -686,13 +692,18 @@ class Sanic(
timeout (Optional[Union[int, float]]): An optional timeout value
in seconds. If provided, the wait will be terminated if the
timeout is reached. Defaults to `None`, meaning no timeout.
condition: If provided, method will only return when the signal
is dispatched with the given condition.
exclusive: When true (default), the signal can only be dispatched
when the condition has been met. When ``False``, the signal can
be dispatched either with or without it.
Raises:
NotFound: If the event is not found and auto-registration of
events is not enabled.
Returns:
None
The context dict of the dispatched signal.
Examples:
```python
@ -708,16 +719,16 @@ class Sanic(
```
"""
signal = self.signal_router.name_index.get(event)
if not signal:
waiter = self.signal_router.get_waiter(event, condition, exclusive)
if not waiter:
if self.config.EVENT_AUTOREGISTER:
self.signal_router.reset()
self.add_signal(None, event)
signal = self.signal_router.name_index[event]
waiter = self.signal_router.get_waiter(event, condition, exclusive)
self.signal_router.finalize()
else:
raise NotFound("Could not find signal %s" % event)
return await wait_for(signal.ctx.event.wait(), timeout=timeout)
return await wait_for(waiter.wait(), timeout=timeout)
def report_exception(
self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]]

View File

@ -512,33 +512,54 @@ class Blueprint(BaseSanic):
condition = kwargs.pop("condition", {})
condition.update({"__blueprint__": self.name})
kwargs["condition"] = condition
await asyncio.gather(
return await asyncio.gather(
*[app.dispatch(*args, **kwargs) for app in self.apps]
)
def event(self, event: str, timeout: Optional[Union[int, float]] = None):
def event(
self,
event: str,
timeout: Optional[Union[int, float]] = None,
*,
condition: Optional[Dict[str, Any]] = None,
):
"""Wait for a signal event to be dispatched.
Args:
event (str): Name of the signal event.
timeout (Optional[Union[int, float]]): Timeout for the event to be
dispatched.
condition: If provided, method will only return when the signal
is dispatched with the given condition.
Returns:
Awaitable: Awaitable for the event to be dispatched.
"""
events = set()
for app in self.apps:
signal = app.signal_router.name_index.get(event)
if not signal:
raise NotFound("Could not find signal %s" % event)
events.add(signal.ctx.event)
if condition is None:
condition = {}
condition.update({"__blueprint__": self.name})
return asyncio.wait(
[asyncio.create_task(event.wait()) for event in events],
waiters = []
for app in self.apps:
waiter = app.signal_router.get_waiter(event, condition, exclusive=False)
if not waiter:
raise NotFound("Could not find signal %s" % event)
waiters.append(waiter)
return self._event(waiters, timeout)
async def _event(self, waiters, timeout):
done, pending = await asyncio.wait(
[asyncio.create_task(waiter.wait()) for waiter in waiters],
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout,
)
for task in pending:
task.cancel()
if not done:
raise TimeoutError()
finished_task, = done
return finished_task.result()
@staticmethod
def _extract_value(*values):

View File

@ -88,7 +88,7 @@ class SignalMixin(metaclass=SanicMeta):
"""
if not handler:
async def noop():
async def noop(**context):
...
handler = noop

View File

@ -2,6 +2,8 @@ from __future__ import annotations
import asyncio
from collections import deque
from dataclasses import dataclass
from enum import Enum
from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@ -76,6 +78,38 @@ class Signal(Route):
"""A `Route` that is used to dispatch signals to handlers"""
@dataclass
class SignalWaiter:
"""A record representing a future waiting for a signal"""
signal: Signal
event_definition: str
trigger: str = ""
requirements: Optional[Dict[str, str]] = None
exclusive: bool = True
future: Optional[asyncio.Future] = None
async def wait(self):
"""Block until the signal is next dispatched.
Return the context of the signal dispatch, if any.
"""
loop = asyncio.get_running_loop()
self.future = loop.create_future()
self.signal.ctx.waiters.append(self)
try:
return await self.future
finally:
self.signal.ctx.waiters.remove(self)
def matches(self, event, condition):
return ((condition is None and not self.exclusive)
or (condition is None and not self.requirements)
or condition == self.requirements
) and (self.trigger or event == self.event_definition)
class SignalGroup(RouteGroup):
"""A `RouteGroup` that is used to dispatch signals to handlers"""
@ -160,18 +194,20 @@ class SignalRouter(BaseRouter):
error_logger.warning(str(e))
return None
events = [signal.ctx.event for signal in group]
for signal_event in events:
signal_event.set()
if context:
params.update(context)
params.pop("__trigger__", None)
signals = group.routes
if not reverse:
signals = signals[::-1]
try:
for signal in signals:
params.pop("__trigger__", None)
for waiter in signal.ctx.waiters:
if waiter.matches(event, condition):
waiter.future.set_result(dict(params))
for signal in signals:
requirements = signal.extra.requirements
if (
(condition is None and signal.ctx.exclusive is False)
@ -197,9 +233,6 @@ class SignalRouter(BaseRouter):
)
setattr(e, "__dispatched__", True)
raise e
finally:
for signal_event in events:
signal_event.clear()
async def dispatch(
self,
@ -244,14 +277,29 @@ class SignalRouter(BaseRouter):
await asyncio.sleep(0)
return task
def add( # type: ignore
def get_waiter(
self,
handler: SignalHandler,
event: str,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Signal:
event_definition = event
event: Union[str, Enum],
condition: Optional[Dict[str, Any]],
exclusive: bool,
):
event_definition = str(event.value) if isinstance(event, Enum) else event
name, trigger, _ = self._get_event_parts(event_definition)
signal = cast(Signal, self.name_index.get(name))
if not signal:
return None
if event_definition.endswith(".*") and not trigger:
trigger = "*"
return SignalWaiter(
signal=signal,
event_definition=event_definition,
trigger=trigger,
requirements=condition,
exclusive=bool(exclusive),
)
def _get_event_parts(self, event):
parts = self._build_event_parts(event)
if parts[2].startswith("<"):
name = ".".join([*parts[:-1], "*"])
@ -263,6 +311,18 @@ class SignalRouter(BaseRouter):
if not trigger:
event = ".".join([*parts[:2], "<__trigger__>"])
return name, trigger, event
def add( # type: ignore
self,
handler: SignalHandler,
event: str,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Signal:
event_definition = event
name, trigger, event = self._get_event_parts(event)
signal = super().add(
event,
handler,
@ -298,7 +358,7 @@ class SignalRouter(BaseRouter):
raise RuntimeError("Cannot finalize signals outside of event loop")
for signal in self.routes:
signal.ctx.event = asyncio.Event()
signal.ctx.waiters = deque()
return super().finalize(do_compile=do_compile, do_optimize=do_optimize)

View File

@ -78,6 +78,49 @@ def test_invalid_signal(app, signal):
...
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_event(app):
@app.signal("foo.bar.baz")
def sync_signal(*args):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.baz"))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_correct_event(app):
# Check for https://github.com/sanic-org/sanic/issues/2826
@app.signal("foo.bar.baz")
def sync_signal(*args):
pass
@app.signal("foo.bar.spam")
def sync_signal(*args):
pass
app.signal_router.finalize()
baz_task = asyncio.create_task(app.event("foo.bar.baz"))
spam_task = asyncio.create_task(app.event("foo.bar.spam"))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert baz_task.done()
assert not spam_task.done()
baz_task.result()
spam_task.cancel()
@pytest.mark.asyncio
async def test_dispatch_signal_with_enum_event(app):
counter = 0
@ -97,6 +140,26 @@ async def test_dispatch_signal_with_enum_event(app):
assert counter == 1
@pytest.mark.asyncio
async def test_dispatch_signal_with_enum_event_to_event(app):
class FooEnum(Enum):
FOO_BAR_BAZ = "foo.bar.baz"
@app.signal(FooEnum.FOO_BAR_BAZ)
def sync_signal(*args):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event(FooEnum.FOO_BAR_BAZ))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_multiple_handlers(app):
counter = 0
@ -121,22 +184,45 @@ async def test_dispatch_signal_triggers_multiple_handlers(app):
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_triggers_event(app):
counter = 0
async def test_dispatch_signal_triggers_multiple_events(app):
@app.signal("foo.bar.baz")
def sync_signal(*args):
nonlocal app
nonlocal counter
group, *_ = app.signal_router.get("foo.bar.baz")
for signal in group:
counter += signal.ctx.event.is_set()
def sync_signal(*_):
pass
app.signal_router.finalize()
await app.dispatch("foo.bar.baz")
event_task1 = asyncio.create_task(app.event("foo.bar.baz"))
event_task2 = asyncio.create_task(app.event("foo.bar.baz"))
assert counter == 1
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert event_task1.done()
assert event_task2.done()
event_task1.result() # Will raise if there was an exception
event_task2.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_with_multiple_handlers_triggers_event_once(app):
@app.signal("foo.bar.baz")
def sync_signal(*_):
pass
@app.signal("foo.bar.baz")
async def async_signal(*_):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.baz"))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
@ -155,6 +241,40 @@ async def test_dispatch_signal_triggers_dynamic_route(app):
assert counter == 9
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_parameterized_dynamic_route_event(app):
@app.signal("foo.bar.<baz:int>")
def sync_signal(baz):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.<baz:int>"))
await app.dispatch("foo.bar.9")
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_starred_dynamic_route_event(app):
@app.signal("foo.bar.<baz:int>")
def sync_signal(baz):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.*"))
await app.dispatch("foo.bar.9")
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_with_requirements(app):
counter = 0
@ -172,6 +292,26 @@ async def test_dispatch_signal_triggers_with_requirements(app):
assert counter == 1
@pytest.mark.asyncio
async def test_dispatch_signal_to_event_with_requirements(app):
@app.signal("foo.bar.baz")
def sync_signal(*_):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert not event_task.done()
await app.dispatch("foo.bar.baz", condition={"one": "two"})
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_with_requirements_exclusive(app):
counter = 0
@ -189,6 +329,28 @@ async def test_dispatch_signal_triggers_with_requirements_exclusive(app):
assert counter == 2
@pytest.mark.asyncio
async def test_dispatch_signal_to_event_with_requirements_exclusive(app):
@app.signal("foo.bar.baz")
def sync_signal(*_):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False))
await app.dispatch("foo.bar.baz", condition={"one": "two"})
await asyncio.sleep(0)
assert event_task.done()
event_task.result() # Will raise if there was an exception
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_with_context(app):
counter = 0
@ -204,6 +366,22 @@ async def test_dispatch_signal_triggers_with_context(app):
assert counter == 9
@pytest.mark.asyncio
async def test_dispatch_signal_to_event_with_context(app):
@app.signal("foo.bar.baz")
def sync_signal(**context):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.baz"))
await app.dispatch("foo.bar.baz", context={"amount": 9})
await asyncio.sleep(0)
assert event_task.done()
assert event_task.result()['amount'] == 9
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_with_context_fail(app):
counter = 0
@ -219,6 +397,22 @@ async def test_dispatch_signal_triggers_with_context_fail(app):
await app.dispatch("foo.bar.baz", {"amount": 9})
@pytest.mark.asyncio
async def test_dispatch_signal_to_dynamic_route_event(app):
@app.signal("foo.bar.<something>")
def sync_signal(**context):
pass
app.signal_router.finalize()
event_task = asyncio.create_task(app.event("foo.bar.<something>"))
await app.dispatch("foo.bar.baz")
await asyncio.sleep(0)
assert event_task.done()
assert event_task.result()['something'] == "baz"
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_on_bp(app):
bp = Blueprint("bp")
@ -267,61 +461,67 @@ async def test_dispatch_signal_triggers_on_bp_alone(app):
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_event(app):
app_counter = 0
async def test_dispatch_signal_triggers_event_on_bp(app):
bp = Blueprint("bp")
@app.signal("foo.bar.baz")
def app_signal():
...
async def do_wait():
nonlocal app_counter
await app.event("foo.bar.baz")
app_counter += 1
@bp.signal("foo.bar.baz")
def bp_signal():
...
app.blueprint(bp)
app.signal_router.finalize()
app_task = asyncio.create_task(app.event("foo.bar.baz"))
bp_task = asyncio.create_task(bp.event("foo.bar.baz"))
await asyncio.sleep(0)
await app.dispatch("foo.bar.baz")
waiter = app.event("foo.bar.baz")
assert isawaitable(waiter)
fut = asyncio.ensure_future(do_wait())
await app.dispatch("foo.bar.baz")
await fut
# Allow a few event loop iterations for tasks to finish
for _ in range(5):
await asyncio.sleep(0)
assert app_counter == 1
assert app_task.done()
assert bp_task.done()
app_task.result()
bp_task.result()
app_task = asyncio.create_task(app.event("foo.bar.baz"))
bp_task = asyncio.create_task(bp.event("foo.bar.baz"))
await asyncio.sleep(0)
await bp.dispatch("foo.bar.baz")
# Allow a few event loop iterations for tasks to finish
for _ in range(5):
await asyncio.sleep(0)
assert bp_task.done()
assert not app_task.done()
bp_task.result()
app_task.cancel()
@pytest.mark.asyncio
async def test_dispatch_signal_triggers_event_on_bp(app):
async def test_dispatch_signal_triggers_event_on_bp_with_context(app):
bp = Blueprint("bp")
bp_counter = 0
@bp.signal("foo.bar.baz")
def bp_signal():
...
async def do_wait():
nonlocal bp_counter
await bp.event("foo.bar.baz")
bp_counter += 1
app.blueprint(bp)
app.signal_router.finalize()
signal_group, *_ = app.signal_router.get(
"foo.bar.baz", condition={"blueprint": "bp"}
)
await bp.dispatch("foo.bar.baz")
waiter = bp.event("foo.bar.baz")
assert isawaitable(waiter)
fut = do_wait()
for signal in signal_group:
signal.ctx.event.set()
await asyncio.gather(fut)
assert bp_counter == 1
event_task = asyncio.create_task(bp.event("foo.bar.baz"))
await asyncio.sleep(0)
await app.dispatch("foo.bar.baz", context={"amount": 9})
for _ in range(5):
await asyncio.sleep(0)
assert event_task.done()
assert event_task.result()['amount'] == 9
def test_bad_finalize(app):