From e3744095675df170398a374cc20832470200ea40 Mon Sep 17 00:00:00 2001 From: Zhiwei Date: Fri, 7 Jul 2023 07:56:42 -0400 Subject: [PATCH] Adding allow route overwrite option in blueprint (#2716) * Adding allow route overwrite option * Add test case for route overwriting after bp copy * Fix test * Fix * Add test case `test_bp_allow_override` * Remove conflicted future routes when overwriting is allowed * Improved test test_bp_copy_with_route_overwriting * Fix type * Fix type 2 * Add `test_bp_copy_without_route_overwriting` case * make `allow_route_overwrite` flag to be internal * Remove unwanted test case --------- Co-authored-by: Adam Hopkins --- sanic/app.py | 5 ++- sanic/blueprints.py | 10 ++++- sanic/mixins/routes.py | 8 +++- sanic/router.py | 2 + tests/test_blueprint_copy.py | 79 +++++++++++++++++++++++++++++++++++- 5 files changed, 99 insertions(+), 5 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 05ed7d54..527650cc 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -417,8 +417,11 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): def _apply_listener(self, listener: FutureListener): return self.register_listener(listener.listener, listener.event) - def _apply_route(self, route: FutureRoute) -> List[Route]: + def _apply_route( + self, route: FutureRoute, overwrite: bool = False + ) -> List[Route]: params = route._asdict() + params["overwrite"] = overwrite websocket = params.pop("websocket", False) subprotocols = params.pop("subprotocols", None) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index ef9c6756..a52495b9 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -93,6 +93,7 @@ class Blueprint(BaseSanic): "_future_listeners", "_future_exceptions", "_future_signals", + "_allow_route_overwrite", "copied_from", "ctx", "exceptions", @@ -119,6 +120,7 @@ class Blueprint(BaseSanic): ): super().__init__(name=name) self.reset() + self._allow_route_overwrite = False self.copied_from = "" self.ctx = SimpleNamespace() self.host = host @@ -169,6 +171,7 @@ class Blueprint(BaseSanic): def reset(self): self._apps: Set[Sanic] = set() + self._allow_route_overwrite = False self.exceptions: List[RouteHandler] = [] self.listeners: Dict[str, List[ListenerType[Any]]] = {} self.middlewares: List[MiddlewareType] = [] @@ -182,6 +185,7 @@ class Blueprint(BaseSanic): url_prefix: Optional[Union[str, Default]] = _default, version: Optional[Union[int, str, float, Default]] = _default, version_prefix: Union[str, Default] = _default, + allow_route_overwrite: Union[bool, Default] = _default, strict_slashes: Optional[Union[bool, Default]] = _default, with_registration: bool = True, with_ctx: bool = False, @@ -225,6 +229,8 @@ class Blueprint(BaseSanic): new_bp.strict_slashes = strict_slashes if not isinstance(version_prefix, Default): new_bp.version_prefix = version_prefix + if not isinstance(allow_route_overwrite, Default): + new_bp._allow_route_overwrite = allow_route_overwrite for key, value in attrs_backup.items(): setattr(self, key, value) @@ -360,7 +366,9 @@ class Blueprint(BaseSanic): continue registered.add(apply_route) - route = app._apply_route(apply_route) + route = app._apply_route( + apply_route, overwrite=self._allow_route_overwrite + ) # If it is a copied BP, then make sure all of the names of routes # matchup with the new BP name diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 49540279..438df171 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -159,7 +159,11 @@ class RouteMixin(BaseMixin, metaclass=SanicMeta): error_format, route_context, ) - + overwrite = getattr(self, "_allow_route_overwrite", False) + if overwrite: + self._future_routes = set( + filter(lambda x: x.uri != uri, self._future_routes) + ) self._future_routes.add(route) args = list(signature(handler).parameters.keys()) @@ -182,7 +186,7 @@ class RouteMixin(BaseMixin, metaclass=SanicMeta): handler.is_stream = stream if apply: - self._apply_route(route) + self._apply_route(route, overwrite=overwrite) if static: return route, handler diff --git a/sanic/router.py b/sanic/router.py index a469fd21..17339cb2 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -80,6 +80,7 @@ class Router(BaseRouter): unquote: bool = False, static: bool = False, version_prefix: str = "/v", + overwrite: bool = False, error_format: Optional[str] = None, ) -> Union[Route, List[Route]]: """ @@ -122,6 +123,7 @@ class Router(BaseRouter): name=name, strict=strict_slashes, unquote=unquote, + overwrite=overwrite, ) if isinstance(host, str): diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py index 387cb8d2..5dd742ea 100644 --- a/tests/test_blueprint_copy.py +++ b/tests/test_blueprint_copy.py @@ -1,4 +1,8 @@ -from sanic import Blueprint, Sanic +import pytest + +from sanic_routing.exceptions import RouteExists + +from sanic import Blueprint, Request, Sanic from sanic.response import text @@ -74,3 +78,76 @@ def test_bp_copy(app: Sanic): assert "test_bp_copy.test_bp4.handle_request" in route_names assert "test_bp_copy.test_bp5.handle_request" in route_names assert "test_bp_copy.test_bp6.handle_request" in route_names + + +def test_bp_copy_without_route_overwriting(app: Sanic): + bpv1 = Blueprint("bp_v1", version=1, url_prefix="my_api") + + @bpv1.route("/") + async def handler(request: Request): + return text("v1") + + app.blueprint(bpv1) + + bpv2 = bpv1.copy("bp_v2", version=2, allow_route_overwrite=False) + bpv3 = bpv1.copy( + "bp_v3", + version=3, + allow_route_overwrite=False, + with_registration=False, + ) + + with pytest.raises(RouteExists, match="Route already registered*"): + + @bpv2.route("/") + async def handler(request: Request): + return text("v2") + + app.blueprint(bpv2) + + with pytest.raises(RouteExists, match="Route already registered*"): + + @bpv3.route("/") + async def handler(request: Request): + return text("v3") + + app.blueprint(bpv3) + + +def test_bp_copy_with_route_overwriting(app: Sanic): + bpv1 = Blueprint("bp_v1", version=1, url_prefix="my_api") + + @bpv1.route("/") + async def handler(request: Request): + return text("v1") + + app.blueprint(bpv1) + + bpv2 = bpv1.copy("bp_v2", version=2, allow_route_overwrite=True) + bpv3 = bpv1.copy( + "bp_v3", version=3, allow_route_overwrite=True, with_registration=False + ) + + @bpv2.route("/") + async def handler(request: Request): + return text("v2") + + app.blueprint(bpv2) + + @bpv3.route("/") + async def handler(request: Request): + return text("v3") + + app.blueprint(bpv3) + + _, response = app.test_client.get("/v1/my_api") + assert response.status == 200 + assert response.text == "v1" + + _, response = app.test_client.get("/v2/my_api") + assert response.status == 200 + assert response.text == "v2" + + _, response = app.test_client.get("/v3/my_api") + assert response.status == 200 + assert response.text == "v3"