GIT-37: fix blueprint middleware application (#1690)
* GIT-37: fix blueprint middleware application 1. If you register a middleware via `@blueprint.middleware` then it will apply only to the routes defined by the blueprint. 2. If you register a middleware via `@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group. 3. If you define a middleware via `@app.middleware` then it will be applied on all available routes Fixes #37 Signed-off-by: Harsha Narayana <harsha2k4@gmail.com> * GIT-37: add changelog Signed-off-by: Harsha Narayana <harsha2k4@gmail.com>
This commit is contained in:
parent
179a07942e
commit
a6077a1790
11
changelogs/37.bugfix.rst
Normal file
11
changelogs/37.bugfix.rst
Normal file
|
@ -0,0 +1,11 @@
|
|||
Fix blueprint middleware application
|
||||
|
||||
Currently, any blueprint middleware registered, irrespective of which blueprint was used to do so, was
|
||||
being applied to all of the routes created by the :code:`@app` and :code:`@blueprint` alike.
|
||||
|
||||
As part of this change, the blueprint based middleware application is enforced based on where they are
|
||||
registered.
|
||||
|
||||
- If you register a middleware via :code:`@blueprint.middleware` then it will apply only to the routes defined by the blueprint.
|
||||
- If you register a middleware via :code:`@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group.
|
||||
- If you define a middleware via :code:`@app.middleware` then it will be applied on all available routes
|
67
sanic/app.py
67
sanic/app.py
|
@ -85,7 +85,8 @@ class Sanic:
|
|||
self.is_request_stream = False
|
||||
self.websocket_enabled = False
|
||||
self.websocket_tasks = set()
|
||||
|
||||
self.named_request_middleware = {}
|
||||
self.named_response_middleware = {}
|
||||
# Register alternative method names
|
||||
self.go_fast = self.run
|
||||
|
||||
|
@ -178,7 +179,7 @@ class Sanic:
|
|||
:param stream:
|
||||
:param version:
|
||||
:param name: user defined route name for url_for
|
||||
:return: decorated function
|
||||
:return: tuple of routes, decorated function
|
||||
"""
|
||||
|
||||
# Fix case where the user did not prefix the URL with a /
|
||||
|
@ -204,7 +205,7 @@ class Sanic:
|
|||
if stream:
|
||||
handler.is_stream = stream
|
||||
|
||||
self.router.add(
|
||||
routes = self.router.add(
|
||||
uri=uri,
|
||||
methods=methods,
|
||||
handler=handler,
|
||||
|
@ -213,7 +214,7 @@ class Sanic:
|
|||
version=version,
|
||||
name=name,
|
||||
)
|
||||
return handler
|
||||
return routes, handler
|
||||
|
||||
return response
|
||||
|
||||
|
@ -462,7 +463,7 @@ class Sanic:
|
|||
:param subprotocols: optional list of str with supported subprotocols
|
||||
:param name: A unique name assigned to the URL so that it can
|
||||
be used with :func:`url_for`
|
||||
:return: decorated function
|
||||
:return: tuple of routes, decorated function
|
||||
"""
|
||||
self.enable_websocket()
|
||||
|
||||
|
@ -515,7 +516,7 @@ class Sanic:
|
|||
self.websocket_tasks.remove(fut)
|
||||
await ws.close()
|
||||
|
||||
self.router.add(
|
||||
routes = self.router.add(
|
||||
uri=uri,
|
||||
handler=websocket_handler,
|
||||
methods=frozenset({"GET"}),
|
||||
|
@ -523,7 +524,7 @@ class Sanic:
|
|||
strict_slashes=strict_slashes,
|
||||
name=name,
|
||||
)
|
||||
return handler
|
||||
return routes, handler
|
||||
|
||||
return response
|
||||
|
||||
|
@ -544,6 +545,7 @@ class Sanic:
|
|||
:param host: Host IP or FQDN details
|
||||
:param uri: URL path that will be mapped to the websocket
|
||||
handler
|
||||
handler
|
||||
:param strict_slashes: If the API endpoint needs to terminate
|
||||
with a "/" or not
|
||||
:param subprotocols: Subprotocols to be used with websocket
|
||||
|
@ -645,6 +647,22 @@ class Sanic:
|
|||
self.response_middleware.appendleft(middleware)
|
||||
return middleware
|
||||
|
||||
def register_named_middleware(
|
||||
self, middleware, route_names, attach_to="request"
|
||||
):
|
||||
if attach_to == "request":
|
||||
for _rn in route_names:
|
||||
if _rn not in self.named_request_middleware:
|
||||
self.named_request_middleware[_rn] = deque()
|
||||
if middleware not in self.named_request_middleware[_rn]:
|
||||
self.named_request_middleware[_rn].append(middleware)
|
||||
if attach_to == "response":
|
||||
for _rn in route_names:
|
||||
if _rn not in self.named_response_middleware:
|
||||
self.named_response_middleware[_rn] = deque()
|
||||
if middleware not in self.named_response_middleware[_rn]:
|
||||
self.named_response_middleware[_rn].append(middleware)
|
||||
|
||||
# Decorator
|
||||
def middleware(self, middleware_or_request):
|
||||
"""
|
||||
|
@ -916,20 +934,23 @@ class Sanic:
|
|||
# allocation before assignment below.
|
||||
response = None
|
||||
cancelled = False
|
||||
name = None
|
||||
try:
|
||||
# Fetch handler from router
|
||||
handler, args, kwargs, uri, name = self.router.get(request)
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
response = await self._run_request_middleware(request)
|
||||
response = await self._run_request_middleware(
|
||||
request, request_name=name
|
||||
)
|
||||
# No middleware results
|
||||
if not response:
|
||||
# -------------------------------------------- #
|
||||
# Execute Handler
|
||||
# -------------------------------------------- #
|
||||
|
||||
# Fetch handler from router
|
||||
handler, args, kwargs, uri = self.router.get(request)
|
||||
|
||||
request.uri_template = uri
|
||||
if handler is None:
|
||||
raise ServerError(
|
||||
|
@ -993,7 +1014,7 @@ class Sanic:
|
|||
if response is not None:
|
||||
try:
|
||||
response = await self._run_response_middleware(
|
||||
request, response
|
||||
request, response, request_name=name
|
||||
)
|
||||
except CancelledError:
|
||||
# Response middleware can timeout too, as above.
|
||||
|
@ -1265,10 +1286,14 @@ class Sanic:
|
|||
if isawaitable(result):
|
||||
await result
|
||||
|
||||
async def _run_request_middleware(self, request):
|
||||
async def _run_request_middleware(self, request, request_name=None):
|
||||
# The if improves speed. I don't know why
|
||||
if self.request_middleware:
|
||||
for middleware in self.request_middleware:
|
||||
named_middleware = self.named_request_middleware.get(
|
||||
request_name, deque()
|
||||
)
|
||||
applicable_middleware = self.request_middleware + named_middleware
|
||||
if applicable_middleware:
|
||||
for middleware in applicable_middleware:
|
||||
response = middleware(request)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
|
@ -1276,9 +1301,15 @@ class Sanic:
|
|||
return response
|
||||
return None
|
||||
|
||||
async def _run_response_middleware(self, request, response):
|
||||
if self.response_middleware:
|
||||
for middleware in self.response_middleware:
|
||||
async def _run_response_middleware(
|
||||
self, request, response, request_name=None
|
||||
):
|
||||
named_middleware = self.named_response_middleware.get(
|
||||
request_name, deque()
|
||||
)
|
||||
applicable_middleware = self.response_middleware + named_middleware
|
||||
if applicable_middleware:
|
||||
for middleware in applicable_middleware:
|
||||
_response = middleware(request, response)
|
||||
if isawaitable(_response):
|
||||
_response = await _response
|
||||
|
|
|
@ -104,6 +104,8 @@ class Blueprint:
|
|||
|
||||
url_prefix = options.get("url_prefix", self.url_prefix)
|
||||
|
||||
routes = []
|
||||
|
||||
# Routes
|
||||
for future in self.routes:
|
||||
# attach the blueprint name to the handler so that it can be
|
||||
|
@ -114,7 +116,7 @@ class Blueprint:
|
|||
|
||||
version = future.version or self.version
|
||||
|
||||
app.route(
|
||||
_routes, _ = app.route(
|
||||
uri=uri[1:] if uri.startswith("//") else uri,
|
||||
methods=future.methods,
|
||||
host=future.host or self.host,
|
||||
|
@ -123,6 +125,8 @@ class Blueprint:
|
|||
version=version,
|
||||
name=future.name,
|
||||
)(future.handler)
|
||||
if _routes:
|
||||
routes += _routes
|
||||
|
||||
for future in self.websocket_routes:
|
||||
# attach the blueprint name to the handler so that it can be
|
||||
|
@ -130,21 +134,27 @@ class Blueprint:
|
|||
future.handler.__blueprintname__ = self.name
|
||||
# Prepend the blueprint URI prefix if available
|
||||
uri = url_prefix + future.uri if url_prefix else future.uri
|
||||
app.websocket(
|
||||
_routes, _ = app.websocket(
|
||||
uri=uri,
|
||||
host=future.host or self.host,
|
||||
strict_slashes=future.strict_slashes,
|
||||
name=future.name,
|
||||
)(future.handler)
|
||||
if _routes:
|
||||
routes += _routes
|
||||
|
||||
route_names = [route.name for route in routes]
|
||||
# Middleware
|
||||
for future in self.middlewares:
|
||||
if future.args or future.kwargs:
|
||||
app.register_middleware(
|
||||
future.middleware, *future.args, **future.kwargs
|
||||
app.register_named_middleware(
|
||||
future.middleware,
|
||||
route_names,
|
||||
*future.args,
|
||||
**future.kwargs
|
||||
)
|
||||
else:
|
||||
app.register_middleware(future.middleware)
|
||||
app.register_named_middleware(future.middleware, route_names)
|
||||
|
||||
# Exceptions
|
||||
for future in self.exceptions:
|
||||
|
|
|
@ -140,21 +140,22 @@ class Router:
|
|||
docs for further details.
|
||||
:return: Nothing
|
||||
"""
|
||||
routes = []
|
||||
if version is not None:
|
||||
version = re.escape(str(version).strip("/").lstrip("v"))
|
||||
uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
|
||||
# add regular version
|
||||
self._add(uri, methods, handler, host, name)
|
||||
routes.append(self._add(uri, methods, handler, host, name))
|
||||
|
||||
if strict_slashes:
|
||||
return
|
||||
return routes
|
||||
|
||||
if not isinstance(host, str) and host is not None:
|
||||
# we have gotten back to the top of the recursion tree where the
|
||||
# host was originally a list. By now, we've processed the strict
|
||||
# slashes logic on the leaf nodes (the individual host strings in
|
||||
# the list of host)
|
||||
return
|
||||
return routes
|
||||
|
||||
# Add versions with and without trailing /
|
||||
slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
|
||||
|
@ -176,10 +177,12 @@ class Router:
|
|||
)
|
||||
# add version with trailing slash
|
||||
if slash_is_missing:
|
||||
self._add(uri + "/", methods, handler, host, name)
|
||||
routes.append(self._add(uri + "/", methods, handler, host, name))
|
||||
# add version without trailing slash
|
||||
elif without_slash_is_missing:
|
||||
self._add(uri[:-1], methods, handler, host, name)
|
||||
routes.append(self._add(uri[:-1], methods, handler, host, name))
|
||||
|
||||
return routes
|
||||
|
||||
def _add(self, uri, methods, handler, host=None, name=None):
|
||||
"""Add a handler to the route list
|
||||
|
@ -328,6 +331,7 @@ class Router:
|
|||
self.routes_dynamic[url_hash(uri)].append(route)
|
||||
else:
|
||||
self.routes_static[uri] = route
|
||||
return route
|
||||
|
||||
@staticmethod
|
||||
def check_dynamic_route_exists(pattern, routes_to_check, parameters):
|
||||
|
@ -442,6 +446,7 @@ class Router:
|
|||
method=method,
|
||||
allowed_methods=self.get_supported_methods(url),
|
||||
)
|
||||
|
||||
if route:
|
||||
if route.methods and method not in route.methods:
|
||||
raise method_not_supported
|
||||
|
@ -476,7 +481,7 @@ class Router:
|
|||
route_handler = route.handler
|
||||
if hasattr(route_handler, "handlers"):
|
||||
route_handler = route_handler.handlers[method]
|
||||
return route_handler, [], kwargs, route.uri
|
||||
return route_handler, [], kwargs, route.uri, route.name
|
||||
|
||||
def is_stream_handler(self, request):
|
||||
""" Handler for request is stream or not.
|
||||
|
|
|
@ -94,7 +94,7 @@ def test_app_route_raise_value_error(app):
|
|||
|
||||
def test_app_handle_request_handler_is_none(app, monkeypatch):
|
||||
def mockreturn(*args, **kwargs):
|
||||
return None, [], {}, ""
|
||||
return None, [], {}, "", ""
|
||||
|
||||
# Not sure how to make app.router.get() return None, so use mock here.
|
||||
monkeypatch.setattr(app.router, "get", mockreturn)
|
||||
|
|
|
@ -83,7 +83,7 @@ def test_bp_group_with_additional_route_params(app: Sanic):
|
|||
_, response = app.test_client.patch("/api/bp2/route/bp2", headers=header)
|
||||
assert response.text == "PATCH_bp2"
|
||||
|
||||
_, response = app.test_client.get("/v2/api/bp1/request_path")
|
||||
_, response = app.test_client.put("/v2/api/bp1/request_path")
|
||||
assert response.status == 401
|
||||
|
||||
|
||||
|
@ -141,8 +141,8 @@ def test_bp_group(app: Sanic):
|
|||
_, response = app.test_client.get("/api/bp3")
|
||||
assert response.text == "BP3_OK"
|
||||
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 4
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 2
|
||||
|
||||
|
||||
def test_bp_group_list_operations(app: Sanic):
|
||||
|
|
|
@ -268,7 +268,7 @@ def test_bp_middleware(app):
|
|||
request, response = app.test_client.get("/")
|
||||
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
assert response.text == "FAIL"
|
||||
|
||||
|
||||
def test_bp_exception_handler(app):
|
||||
|
|
Loading…
Reference in New Issue
Block a user