From 2fea954dcff2e0545277bb492192e3ee2516a095 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 14 Mar 2021 15:21:59 +0200 Subject: [PATCH] Add signal reservations (#2060) * Add signal reservations * Simplify reservations --- sanic/signals.py | 33 +++++++++++++++++++++++---------- tests/test_signals.py | 20 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/sanic/signals.py b/sanic/signals.py index 27f6bfa4..0e6f73f1 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from inspect import isawaitable -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from sanic_routing import BaseRouter, Route # type: ignore from sanic_routing.exceptions import NotFound # type: ignore @@ -13,6 +13,12 @@ from sanic.exceptions import InvalidSignal from sanic.models.handler_types import SignalHandler +RESERVED_NAMESPACES = ( + "server", + "http", +) + + class Signal(Route): def get_handler(self, raw_path, method, _): method = method or self.router.DEFAULT_METHOD @@ -97,15 +103,7 @@ class SignalRouter(BaseRouter): event: str, condition: Optional[Dict[str, Any]] = None, ) -> Signal: - parts = path_to_parts(event, self.delimiter) - - if ( - len(parts) != 3 - or parts[0].startswith("<") - or parts[1].startswith("<") - ): - raise InvalidSignal(f"Invalid signal event: {event}") - + parts = self._build_event_parts(event) if parts[2].startswith("<"): name = ".".join([*parts[:-1], "*"]) else: @@ -131,3 +129,18 @@ class SignalRouter(BaseRouter): signal.ctx.event = asyncio.Event() return super().finalize(do_compile=do_compile) + + def _build_event_parts(self, event: str) -> Tuple[str, str, str]: + parts = path_to_parts(event, self.delimiter) + if ( + len(parts) != 3 + or parts[0].startswith("<") + or parts[1].startswith("<") + ): + raise InvalidSignal("Invalid signal event: %s" % event) + + if parts[0] in RESERVED_NAMESPACES: + raise InvalidSignal( + "Cannot declare reserved signal event: %s" % event + ) + return parts diff --git a/tests/test_signals.py b/tests/test_signals.py index d06508f9..915d92d6 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -280,3 +280,23 @@ def test_event_on_bp_not_registered(): match=" has not yet been registered to an app", ): bp.event("foo.bar.baz") + + +@pytest.mark.parametrize( + "event,expected", + ( + ("foo.bar.baz", True), + ("server.init.before", False), + ("http.request.start", False), + ("sanic.notice.anything", True), + ), +) +def test_signal_reservation(app, event, expected): + if not expected: + with pytest.raises( + InvalidSignal, + match=f"Cannot declare reserved signal event: {event}", + ): + app.signal(event)(lambda: ...) + else: + app.signal(event)(lambda: ...)