Copy Blueprints Implementation (#2184)

This commit is contained in:
Zhiwei 2021-08-09 17:07:04 -05:00 committed by GitHub
parent e1cfbf0fd9
commit e2eefaac55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 163 additions and 9 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from copy import deepcopy
from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union
@ -12,6 +13,7 @@ from sanic_routing.route import Route # type: ignore
from sanic.base import BaseSanic
from sanic.blueprint_group import BlueprintGroup
from sanic.exceptions import SanicException
from sanic.helpers import Default, _default
from sanic.models.futures import FutureRoute, FutureStatic
from sanic.models.handler_types import (
ListenerType,
@ -40,7 +42,7 @@ class Blueprint(BaseSanic):
:param host: IP Address of FQDN for the sanic server to use.
:param version: Blueprint Version
:param strict_slashes: Enforce the API urls are requested with a
training */*
trailing */*
"""
__fake_slots__ = (
@ -76,15 +78,9 @@ class Blueprint(BaseSanic):
version_prefix: str = "/v",
):
super().__init__(name=name)
self._apps: Set[Sanic] = set()
self.reset()
self.ctx = SimpleNamespace()
self.exceptions: List[RouteHandler] = []
self.host = host
self.listeners: Dict[str, List[ListenerType]] = {}
self.middlewares: List[MiddlewareType] = []
self.routes: List[Route] = []
self.statics: List[RouteHandler] = []
self.strict_slashes = strict_slashes
self.url_prefix = (
url_prefix[:-1]
@ -93,7 +89,6 @@ class Blueprint(BaseSanic):
)
self.version = version
self.version_prefix = version_prefix
self.websocket_routes: List[Route] = []
def __repr__(self) -> str:
args = ", ".join(
@ -144,6 +139,81 @@ class Blueprint(BaseSanic):
kwargs["apply"] = False
return super().signal(event, *args, **kwargs)
def reset(self):
self._apps: Set[Sanic] = set()
self.exceptions: List[RouteHandler] = []
self.listeners: Dict[str, List[ListenerType]] = {}
self.middlewares: List[MiddlewareType] = []
self.routes: List[Route] = []
self.statics: List[RouteHandler] = []
self.websocket_routes: List[Route] = []
def copy(
self,
name: str,
url_prefix: Optional[Union[str, Default]] = _default,
version: Optional[Union[int, str, float, Default]] = _default,
version_prefix: Union[str, Default] = _default,
strict_slashes: Optional[Union[bool, Default]] = _default,
with_registration: bool = True,
with_ctx: bool = False,
):
"""
Copy a blueprint instance with some optional parameters to
override the values of attributes in the old instance.
:param name: unique name of the blueprint
:param url_prefix: URL to be prefixed before all route URLs
:param version: Blueprint Version
:param version_prefix: the prefix of the version number shown in the
URL.
:param strict_slashes: Enforce the API urls are requested with a
trailing */*
:param with_registration: whether register new blueprint instance with
sanic apps that were registered with the old instance or not.
:param with_ctx: whether ``ctx`` will be copied or not.
"""
attrs_backup = {
"_apps": self._apps,
"routes": self.routes,
"websocket_routes": self.websocket_routes,
"middlewares": self.middlewares,
"exceptions": self.exceptions,
"listeners": self.listeners,
"statics": self.statics,
}
self.reset()
new_bp = deepcopy(self)
new_bp.name = name
if not isinstance(url_prefix, Default):
new_bp.url_prefix = url_prefix
if not isinstance(version, Default):
new_bp.version = version
if not isinstance(strict_slashes, Default):
new_bp.strict_slashes = strict_slashes
if not isinstance(version_prefix, Default):
new_bp.version_prefix = version_prefix
for key, value in attrs_backup.items():
setattr(self, key, value)
if with_registration and self._apps:
if new_bp._future_statics:
raise SanicException(
"Static routes registered with the old blueprint instance,"
" cannot be registered again."
)
for app in self._apps:
app.blueprint(new_bp)
if not with_ctx:
new_bp.ctx = SimpleNamespace()
return new_bp
@staticmethod
def group(
*blueprints: Union[Blueprint, BlueprintGroup],

View File

@ -155,3 +155,17 @@ def import_string(module_name, package=None):
if ismodule(obj):
return obj
return obj()
class Default:
"""
It is used to replace `None` or `object()` as a sentinel
that represents a default value. Sometimes we want to set
a value to `None` so we cannot use `None` to represent the
default value, and `object()` is hard to be typed.
"""
pass
_default = Default()

View File

@ -0,0 +1,70 @@
from copy import deepcopy
from sanic import Blueprint, Sanic, blueprints, response
from sanic.response import text
def test_bp_copy(app: Sanic):
bp1 = Blueprint("test_bp1", version=1)
bp1.ctx.test = 1
assert hasattr(bp1.ctx, "test")
@bp1.route("/page")
def handle_request(request):
return text("Hello world!")
bp2 = bp1.copy(name="test_bp2", version=2)
assert id(bp1) != id(bp2)
assert bp1._apps == bp2._apps == set()
assert not hasattr(bp2.ctx, "test")
assert len(bp2._future_exceptions) == len(bp1._future_exceptions)
assert len(bp2._future_listeners) == len(bp1._future_listeners)
assert len(bp2._future_middleware) == len(bp1._future_middleware)
assert len(bp2._future_routes) == len(bp1._future_routes)
assert len(bp2._future_signals) == len(bp1._future_signals)
app.blueprint(bp1)
app.blueprint(bp2)
bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True)
assert id(bp1) != id(bp3)
assert bp1._apps == bp3._apps and bp3._apps
assert not hasattr(bp3.ctx, "test")
bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True)
assert id(bp1) != id(bp4)
assert bp4.ctx.test == 1
bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False)
assert id(bp1) != id(bp5)
assert not bp5._apps
assert bp1._apps != set()
app.blueprint(bp5)
bp6 = bp1.copy(
name="test_bp6",
version=6,
with_registration=True,
version_prefix="/version",
)
assert bp6._apps
assert bp6.version_prefix == "/version"
_, response = app.test_client.get("/v1/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v2/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v3/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v4/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/v5/page")
assert "Hello world!" in response.text
_, response = app.test_client.get("/version6/page")
assert "Hello world!" in response.text