Add signal reservations (#2060)

* Add signal reservations

* Simplify reservations
This commit is contained in:
Adam Hopkins 2021-03-14 15:21:59 +02:00 committed by GitHub
parent d4660d0ca7
commit 2fea954dcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 10 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from inspect import isawaitable from inspect import isawaitable
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
from sanic_routing import BaseRouter, Route # type: ignore from sanic_routing import BaseRouter, Route # type: ignore
from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.exceptions import NotFound # type: ignore
@ -13,6 +13,12 @@ from sanic.exceptions import InvalidSignal
from sanic.models.handler_types import SignalHandler from sanic.models.handler_types import SignalHandler
RESERVED_NAMESPACES = (
"server",
"http",
)
class Signal(Route): class Signal(Route):
def get_handler(self, raw_path, method, _): def get_handler(self, raw_path, method, _):
method = method or self.router.DEFAULT_METHOD method = method or self.router.DEFAULT_METHOD
@ -97,15 +103,7 @@ class SignalRouter(BaseRouter):
event: str, event: str,
condition: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, Any]] = None,
) -> Signal: ) -> Signal:
parts = path_to_parts(event, self.delimiter) parts = self._build_event_parts(event)
if (
len(parts) != 3
or parts[0].startswith("<")
or parts[1].startswith("<")
):
raise InvalidSignal(f"Invalid signal event: {event}")
if parts[2].startswith("<"): if parts[2].startswith("<"):
name = ".".join([*parts[:-1], "*"]) name = ".".join([*parts[:-1], "*"])
else: else:
@ -131,3 +129,18 @@ class SignalRouter(BaseRouter):
signal.ctx.event = asyncio.Event() signal.ctx.event = asyncio.Event()
return super().finalize(do_compile=do_compile) return super().finalize(do_compile=do_compile)
def _build_event_parts(self, event: str) -> Tuple[str, str, str]:
parts = path_to_parts(event, self.delimiter)
if (
len(parts) != 3
or parts[0].startswith("<")
or parts[1].startswith("<")
):
raise InvalidSignal("Invalid signal event: %s" % event)
if parts[0] in RESERVED_NAMESPACES:
raise InvalidSignal(
"Cannot declare reserved signal event: %s" % event
)
return parts

View File

@ -280,3 +280,23 @@ def test_event_on_bp_not_registered():
match="<Blueprint bp> has not yet been registered to an app", match="<Blueprint bp> has not yet been registered to an app",
): ):
bp.event("foo.bar.baz") bp.event("foo.bar.baz")
@pytest.mark.parametrize(
"event,expected",
(
("foo.bar.baz", True),
("server.init.before", False),
("http.request.start", False),
("sanic.notice.anything", True),
),
)
def test_signal_reservation(app, event, expected):
if not expected:
with pytest.raises(
InvalidSignal,
match=f"Cannot declare reserved signal event: {event}",
):
app.signal(event)(lambda: ...)
else:
app.signal(event)(lambda: ...)