From 85e7b712b90a82bbf7f771732495515181272c62 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 17 Nov 2021 17:29:41 +0200 Subject: [PATCH] Allow early Blueprint registrations to still apply later added objects (#2260) --- sanic/app.py | 4 ++ sanic/blueprints.py | 126 +++++++++++++++++++++++++---------- sanic/models/futures.py | 4 ++ tests/test_blueprint_copy.py | 4 +- tests/test_blueprints.py | 28 ++++++++ 5 files changed, 129 insertions(+), 37 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 30af974a..d78c67de 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -72,6 +72,7 @@ from sanic.models.futures import ( FutureException, FutureListener, FutureMiddleware, + FutureRegistry, FutureRoute, FutureSignal, FutureStatic, @@ -115,6 +116,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "_future_exceptions", "_future_listeners", "_future_middleware", + "_future_registry", "_future_routes", "_future_signals", "_future_statics", @@ -187,6 +189,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self._test_manager: Any = None self._blueprint_order: List[Blueprint] = [] self._delayed_tasks: List[str] = [] + self._future_registry: FutureRegistry = FutureRegistry() self._state: ApplicationState = ApplicationState(app=self) self.blueprints: Dict[str, Blueprint] = {} self.config: Config = config or Config( @@ -1625,6 +1628,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): raise e async def _startup(self): + self._future_registry.clear() self.signalize() self.finalize() ErrorHandler.finalize(self.error_handler) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index e13cafcd..290773fa 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -4,7 +4,9 @@ import asyncio from collections import defaultdict from copy import deepcopy -from enum import Enum +from functools import wraps +from inspect import isfunction +from itertools import chain from types import SimpleNamespace from typing import ( TYPE_CHECKING, @@ -13,7 +15,9 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, + Tuple, Union, ) @@ -36,6 +40,32 @@ if TYPE_CHECKING: from sanic import Sanic # noqa +def lazy(func, as_decorator=True): + @wraps(func) + def decorator(bp, *args, **kwargs): + nonlocal as_decorator + kwargs["apply"] = False + pass_handler = None + + if args and isfunction(args[0]): + as_decorator = False + + def wrapper(handler): + future = func(bp, *args, **kwargs) + if as_decorator: + future = future(handler) + + if bp.registered: + for app in bp.apps: + bp.register(app, {}) + + return future + + return wrapper if as_decorator else wrapper(pass_handler) + + return decorator + + class Blueprint(BaseSanic): """ In *Sanic* terminology, a **Blueprint** is a logical collection of @@ -125,29 +155,16 @@ class Blueprint(BaseSanic): ) return self._apps - def route(self, *args, **kwargs): - kwargs["apply"] = False - return super().route(*args, **kwargs) + @property + def registered(self) -> bool: + return bool(self._apps) - def static(self, *args, **kwargs): - kwargs["apply"] = False - return super().static(*args, **kwargs) - - def middleware(self, *args, **kwargs): - kwargs["apply"] = False - return super().middleware(*args, **kwargs) - - def listener(self, *args, **kwargs): - kwargs["apply"] = False - return super().listener(*args, **kwargs) - - def exception(self, *args, **kwargs): - kwargs["apply"] = False - return super().exception(*args, **kwargs) - - def signal(self, event: Union[str, Enum], *args, **kwargs): - kwargs["apply"] = False - return super().signal(event, *args, **kwargs) + exception = lazy(BaseSanic.exception) + listener = lazy(BaseSanic.listener) + middleware = lazy(BaseSanic.middleware) + route = lazy(BaseSanic.route) + signal = lazy(BaseSanic.signal) + static = lazy(BaseSanic.static, as_decorator=False) def reset(self): self._apps: Set[Sanic] = set() @@ -284,6 +301,7 @@ class Blueprint(BaseSanic): middleware = [] exception_handlers = [] listeners = defaultdict(list) + registered = set() # Routes for future in self._future_routes: @@ -310,12 +328,15 @@ class Blueprint(BaseSanic): ) name = app._generate_name(future.name) + host = future.host or self.host + if isinstance(host, list): + host = tuple(host) apply_route = FutureRoute( future.handler, uri[1:] if uri.startswith("//") else uri, future.methods, - future.host or self.host, + host, strict_slashes, future.stream, version, @@ -329,6 +350,10 @@ class Blueprint(BaseSanic): error_format, ) + if (self, apply_route) in app._future_registry: + continue + + registered.add(apply_route) route = app._apply_route(apply_route) operation = ( routes.extend if isinstance(route, list) else routes.append @@ -340,6 +365,11 @@ class Blueprint(BaseSanic): # Prepend the blueprint URI prefix if available uri = url_prefix + future.uri if url_prefix else future.uri apply_route = FutureStatic(uri, *future[1:]) + + if (self, apply_route) in app._future_registry: + continue + + registered.add(apply_route) route = app._apply_static(apply_route) routes.append(route) @@ -348,30 +378,51 @@ class Blueprint(BaseSanic): if route_names: # Middleware for future in self._future_middleware: + if (self, future) in app._future_registry: + continue middleware.append(app._apply_middleware(future, route_names)) # Exceptions for future in self._future_exceptions: + if (self, future) in app._future_registry: + continue exception_handlers.append( app._apply_exception_handler(future, route_names) ) # Event listeners - for listener in self._future_listeners: - listeners[listener.event].append(app._apply_listener(listener)) + for future in self._future_listeners: + if (self, future) in app._future_registry: + continue + listeners[future.event].append(app._apply_listener(future)) # Signals - for signal in self._future_signals: - signal.condition.update({"blueprint": self.name}) - app._apply_signal(signal) + for future in self._future_signals: + if (self, future) in app._future_registry: + continue + future.condition.update({"blueprint": self.name}) + app._apply_signal(future) - self.routes = [route for route in routes if isinstance(route, Route)] - self.websocket_routes = [ + self.routes += [route for route in routes if isinstance(route, Route)] + self.websocket_routes += [ route for route in self.routes if route.ctx.websocket ] - self.middlewares = middleware - self.exceptions = exception_handlers - self.listeners = dict(listeners) + self.middlewares += middleware + self.exceptions += exception_handlers + self.listeners.update(dict(listeners)) + + if self.registered: + self.register_futures( + self.apps, + self, + chain( + registered, + self._future_middleware, + self._future_exceptions, + self._future_listeners, + self._future_signals, + ), + ) async def dispatch(self, *args, **kwargs): condition = kwargs.pop("condition", {}) @@ -403,3 +454,10 @@ class Blueprint(BaseSanic): value = v break return value + + @staticmethod + def register_futures( + apps: Set[Sanic], bp: Blueprint, futures: Sequence[Tuple[Any, ...]] + ): + for app in apps: + app._future_registry.update(set((bp, item) for item in futures)) diff --git a/sanic/models/futures.py b/sanic/models/futures.py index fe7d77eb..74ee92b9 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -60,3 +60,7 @@ class FutureSignal(NamedTuple): handler: SignalHandler event: str condition: Optional[Dict[str, str]] + + +class FutureRegistry(set): + ... diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py index 033e2e20..ca8cd67e 100644 --- a/tests/test_blueprint_copy.py +++ b/tests/test_blueprint_copy.py @@ -1,6 +1,4 @@ -from copy import deepcopy - -from sanic import Blueprint, Sanic, blueprints, response +from sanic import Blueprint, Sanic from sanic.response import text diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index b6a23151..3aa4487a 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -1088,3 +1088,31 @@ def test_bp_set_attribute_warning(): "and will be removed in version 21.12. You should change your " "Blueprint instance to use instance.ctx.foo instead." ) + + +def test_early_registration(app): + assert len(app.router.routes) == 0 + + bp = Blueprint("bp") + + @bp.get("/one") + async def one(_): + return text("one") + + app.blueprint(bp) + + assert len(app.router.routes) == 1 + + @bp.get("/two") + async def two(_): + return text("two") + + @bp.get("/three") + async def three(_): + return text("three") + + assert len(app.router.routes) == 3 + + for path in ("one", "two", "three"): + _, response = app.test_client.get(f"/{path}") + assert response.text == path