Copy Blueprints Implementation (#2184)

This commit is contained in:
Zhiwei 2021-08-09 17:07:04 -05:00 committed by GitHub
parent e1cfbf0fd9
commit e2eefaac55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 163 additions and 9 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from copy import deepcopy
from types import SimpleNamespace from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union
@ -12,6 +13,7 @@ from sanic_routing.route import Route # type: ignore
from sanic.base import BaseSanic from sanic.base import BaseSanic
from sanic.blueprint_group import BlueprintGroup from sanic.blueprint_group import BlueprintGroup
from sanic.exceptions import SanicException from sanic.exceptions import SanicException
from sanic.helpers import Default, _default
from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.futures import FutureRoute, FutureStatic
from sanic.models.handler_types import ( from sanic.models.handler_types import (
ListenerType, ListenerType,
@ -40,7 +42,7 @@ class Blueprint(BaseSanic):
:param host: IP Address of FQDN for the sanic server to use. :param host: IP Address of FQDN for the sanic server to use.
:param version: Blueprint Version :param version: Blueprint Version
:param strict_slashes: Enforce the API urls are requested with a :param strict_slashes: Enforce the API urls are requested with a
training */* trailing */*
""" """
__fake_slots__ = ( __fake_slots__ = (
@ -76,15 +78,9 @@ class Blueprint(BaseSanic):
version_prefix: str = "/v", version_prefix: str = "/v",
): ):
super().__init__(name=name) super().__init__(name=name)
self.reset()
self._apps: Set[Sanic] = set()
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.exceptions: List[RouteHandler] = []
self.host = host self.host = host
self.listeners: Dict[str, List[ListenerType]] = {}
self.middlewares: List[MiddlewareType] = []
self.routes: List[Route] = []
self.statics: List[RouteHandler] = []
self.strict_slashes = strict_slashes self.strict_slashes = strict_slashes
self.url_prefix = ( self.url_prefix = (
url_prefix[:-1] url_prefix[:-1]
@ -93,7 +89,6 @@ class Blueprint(BaseSanic):
) )
self.version = version self.version = version
self.version_prefix = version_prefix self.version_prefix = version_prefix
self.websocket_routes: List[Route] = []
def __repr__(self) -> str: def __repr__(self) -> str:
args = ", ".join( args = ", ".join(
@ -144,6 +139,81 @@ class Blueprint(BaseSanic):
kwargs["apply"] = False kwargs["apply"] = False
return super().signal(event, *args, **kwargs) return super().signal(event, *args, **kwargs)
def reset(self):
self._apps: Set[Sanic] = set()
self.exceptions: List[RouteHandler] = []
self.listeners: Dict[str, List[ListenerType]] = {}
self.middlewares: List[MiddlewareType] = []
self.routes: List[Route] = []
self.statics: List[RouteHandler] = []
self.websocket_routes: List[Route] = []
def copy(
self,
name: str,
url_prefix: Optional[Union[str, Default]] = _default,
version: Optional[Union[int, str, float, Default]] = _default,
version_prefix: Union[str, Default] = _default,
strict_slashes: Optional[Union[bool, Default]] = _default,
with_registration: bool = True,
with_ctx: bool = False,
):
"""
Copy a blueprint instance with some optional parameters to
override the values of attributes in the old instance.
:param name: unique name of the blueprint
:param url_prefix: URL to be prefixed before all route URLs
:param version: Blueprint Version
:param version_prefix: the prefix of the version number shown in the
URL.
:param strict_slashes: Enforce the API urls are requested with a
trailing */*
:param with_registration: whether register new blueprint instance with
sanic apps that were registered with the old instance or not.
:param with_ctx: whether ``ctx`` will be copied or not.
"""
attrs_backup = {
"_apps": self._apps,
"routes": self.routes,
"websocket_routes": self.websocket_routes,
"middlewares": self.middlewares,
"exceptions": self.exceptions,
"listeners": self.listeners,
"statics": self.statics,
}
self.reset()
new_bp = deepcopy(self)
new_bp.name = name
if not isinstance(url_prefix, Default):
new_bp.url_prefix = url_prefix
if not isinstance(version, Default):
new_bp.version = version
if not isinstance(strict_slashes, Default):
new_bp.strict_slashes = strict_slashes
if not isinstance(version_prefix, Default):
new_bp.version_prefix = version_prefix
for key, value in attrs_backup.items():
setattr(self, key, value)
if with_registration and self._apps:
if new_bp._future_statics:
raise SanicException(
"Static routes registered with the old blueprint instance,"
" cannot be registered again."
)
for app in self._apps:
app.blueprint(new_bp)
if not with_ctx:
new_bp.ctx = SimpleNamespace()
return new_bp
@staticmethod @staticmethod
def group( def group(
*blueprints: Union[Blueprint, BlueprintGroup], *blueprints: Union[Blueprint, BlueprintGroup],

View File

@ -155,3 +155,17 @@ def import_string(module_name, package=None):
if ismodule(obj): if ismodule(obj):
return obj return obj
return obj() return obj()
class Default:
"""
It is used to replace `None` or `object()` as a sentinel
that represents a default value. Sometimes we want to set
a value to `None` so we cannot use `None` to represent the
default value, and `object()` is hard to be typed.
"""
pass
_default = Default()

View File

@ -0,0 +1,70 @@
from copy import deepcopy
from sanic import Blueprint, Sanic, blueprints, response
from sanic.response import text
def test_bp_copy(app: Sanic):
bp1 = Blueprint("test_bp1", version=1)
bp1.ctx.test = 1
assert hasattr(bp1.ctx, "test")
@bp1.route("/page")
def handle_request(request):
return text("Hello world!")
bp2 = bp1.copy(name="test_bp2", version=2)
assert id(bp1) != id(bp2)
assert bp1._apps == bp2._apps == set()
assert not hasattr(bp2.ctx, "test")
assert len(bp2._future_exceptions) == len(bp1._future_exceptions)
assert len(bp2._future_listeners) == len(bp1._future_listeners)
assert len(bp2._future_middleware) == len(bp1._future_middleware)
assert len(bp2._future_routes) == len(bp1._future_routes)
assert len(bp2._future_signals) == len(bp1._future_signals)
app.blueprint(bp1)
app.blueprint(bp2)
bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True)
assert id(bp1) != id(bp3)
assert bp1._apps == bp3._apps and bp3._apps
assert not hasattr(bp3.ctx, "test")
bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True)
assert id(bp1) != id(bp4)
assert bp4.ctx.test == 1
bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False)
assert id(bp1) != id(bp5)
assert not bp5._apps
assert bp1._apps != set()
app.blueprint(bp5)
bp6 = bp1.copy(
name="test_bp6",
version=6,
with_registration=True,
version_prefix="/version",
)
assert bp6._apps
assert bp6.version_prefix == "/version"
_, response = app.test_client.get("/v1/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v2/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v3/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v4/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v5/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/version6/page")
assert "Hello world!" in response.text