GIT-37: fix blueprint middleware application (#1690)

* GIT-37: fix blueprint middleware application

1. If you register a middleware via `@blueprint.middleware` then it will apply only to the routes defined by the blueprint.
2. If you register a middleware via `@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group.
3. If you define a middleware via `@app.middleware` then it will be applied on all available routes

Fixes #37

Signed-off-by: Harsha Narayana <harsha2k4@gmail.com>

* GIT-37: add changelog

Signed-off-by: Harsha Narayana <harsha2k4@gmail.com>
This commit is contained in:
Harsha Narayana 2019-12-20 21:31:05 +05:30 committed by Stephen Sadowski
parent 179a07942e
commit a6077a1790
7 changed files with 91 additions and 34 deletions

11
changelogs/37.bugfix.rst Normal file
View File

@ -0,0 +1,11 @@
Fix blueprint middleware application
Currently, any blueprint middleware registered, irrespective of which blueprint was used to do so, was
being applied to all of the routes created by the :code:`@app` and :code:`@blueprint` alike.
As part of this change, the blueprint based middleware application is enforced based on where they are
registered.
- If you register a middleware via :code:`@blueprint.middleware` then it will apply only to the routes defined by the blueprint.
- If you register a middleware via :code:`@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group.
- If you define a middleware via :code:`@app.middleware` then it will be applied on all available routes

View File

@ -85,7 +85,8 @@ class Sanic:
self.is_request_stream = False
self.websocket_enabled = False
self.websocket_tasks = set()
self.named_request_middleware = {}
self.named_response_middleware = {}
# Register alternative method names
self.go_fast = self.run
@ -178,7 +179,7 @@ class Sanic:
:param stream:
:param version:
:param name: user defined route name for url_for
:return: decorated function
:return: tuple of routes, decorated function
"""
# Fix case where the user did not prefix the URL with a /
@ -204,7 +205,7 @@ class Sanic:
if stream:
handler.is_stream = stream
self.router.add(
routes = self.router.add(
uri=uri,
methods=methods,
handler=handler,
@ -213,7 +214,7 @@ class Sanic:
version=version,
name=name,
)
return handler
return routes, handler
return response
@ -462,7 +463,7 @@ class Sanic:
:param subprotocols: optional list of str with supported subprotocols
:param name: A unique name assigned to the URL so that it can
be used with :func:`url_for`
:return: decorated function
:return: tuple of routes, decorated function
"""
self.enable_websocket()
@ -515,7 +516,7 @@ class Sanic:
self.websocket_tasks.remove(fut)
await ws.close()
self.router.add(
routes = self.router.add(
uri=uri,
handler=websocket_handler,
methods=frozenset({"GET"}),
@ -523,7 +524,7 @@ class Sanic:
strict_slashes=strict_slashes,
name=name,
)
return handler
return routes, handler
return response
@ -544,6 +545,7 @@ class Sanic:
:param host: Host IP or FQDN details
:param uri: URL path that will be mapped to the websocket
handler
handler
:param strict_slashes: If the API endpoint needs to terminate
with a "/" or not
:param subprotocols: Subprotocols to be used with websocket
@ -645,6 +647,22 @@ class Sanic:
self.response_middleware.appendleft(middleware)
return middleware
def register_named_middleware(
self, middleware, route_names, attach_to="request"
):
if attach_to == "request":
for _rn in route_names:
if _rn not in self.named_request_middleware:
self.named_request_middleware[_rn] = deque()
if middleware not in self.named_request_middleware[_rn]:
self.named_request_middleware[_rn].append(middleware)
if attach_to == "response":
for _rn in route_names:
if _rn not in self.named_response_middleware:
self.named_response_middleware[_rn] = deque()
if middleware not in self.named_response_middleware[_rn]:
self.named_response_middleware[_rn].append(middleware)
# Decorator
def middleware(self, middleware_or_request):
"""
@ -916,20 +934,23 @@ class Sanic:
# allocation before assignment below.
response = None
cancelled = False
name = None
try:
# Fetch handler from router
handler, args, kwargs, uri, name = self.router.get(request)
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(request)
response = await self._run_request_middleware(
request, request_name=name
)
# No middleware results
if not response:
# -------------------------------------------- #
# Execute Handler
# -------------------------------------------- #
# Fetch handler from router
handler, args, kwargs, uri = self.router.get(request)
request.uri_template = uri
if handler is None:
raise ServerError(
@ -993,7 +1014,7 @@ class Sanic:
if response is not None:
try:
response = await self._run_response_middleware(
request, response
request, response, request_name=name
)
except CancelledError:
# Response middleware can timeout too, as above.
@ -1265,10 +1286,14 @@ class Sanic:
if isawaitable(result):
await result
async def _run_request_middleware(self, request):
async def _run_request_middleware(self, request, request_name=None):
# The if improves speed. I don't know why
if self.request_middleware:
for middleware in self.request_middleware:
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
response = middleware(request)
if isawaitable(response):
response = await response
@ -1276,9 +1301,15 @@ class Sanic:
return response
return None
async def _run_response_middleware(self, request, response):
if self.response_middleware:
for middleware in self.response_middleware:
async def _run_response_middleware(
self, request, response, request_name=None
):
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response

View File

@ -104,6 +104,8 @@ class Blueprint:
url_prefix = options.get("url_prefix", self.url_prefix)
routes = []
# Routes
for future in self.routes:
# attach the blueprint name to the handler so that it can be
@ -114,7 +116,7 @@ class Blueprint:
version = future.version or self.version
app.route(
_routes, _ = app.route(
uri=uri[1:] if uri.startswith("//") else uri,
methods=future.methods,
host=future.host or self.host,
@ -123,6 +125,8 @@ class Blueprint:
version=version,
name=future.name,
)(future.handler)
if _routes:
routes += _routes
for future in self.websocket_routes:
# attach the blueprint name to the handler so that it can be
@ -130,21 +134,27 @@ class Blueprint:
future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
app.websocket(
_routes, _ = app.websocket(
uri=uri,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
name=future.name,
)(future.handler)
if _routes:
routes += _routes
route_names = [route.name for route in routes]
# Middleware
for future in self.middlewares:
if future.args or future.kwargs:
app.register_middleware(
future.middleware, *future.args, **future.kwargs
app.register_named_middleware(
future.middleware,
route_names,
*future.args,
**future.kwargs
)
else:
app.register_middleware(future.middleware)
app.register_named_middleware(future.middleware, route_names)
# Exceptions
for future in self.exceptions:

View File

@ -140,21 +140,22 @@ class Router:
docs for further details.
:return: Nothing
"""
routes = []
if version is not None:
version = re.escape(str(version).strip("/").lstrip("v"))
uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
# add regular version
self._add(uri, methods, handler, host, name)
routes.append(self._add(uri, methods, handler, host, name))
if strict_slashes:
return
return routes
if not isinstance(host, str) and host is not None:
# we have gotten back to the top of the recursion tree where the
# host was originally a list. By now, we've processed the strict
# slashes logic on the leaf nodes (the individual host strings in
# the list of host)
return
return routes
# Add versions with and without trailing /
slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
@ -176,10 +177,12 @@ class Router:
)
# add version with trailing slash
if slash_is_missing:
self._add(uri + "/", methods, handler, host, name)
routes.append(self._add(uri + "/", methods, handler, host, name))
# add version without trailing slash
elif without_slash_is_missing:
self._add(uri[:-1], methods, handler, host, name)
routes.append(self._add(uri[:-1], methods, handler, host, name))
return routes
def _add(self, uri, methods, handler, host=None, name=None):
"""Add a handler to the route list
@ -328,6 +331,7 @@ class Router:
self.routes_dynamic[url_hash(uri)].append(route)
else:
self.routes_static[uri] = route
return route
@staticmethod
def check_dynamic_route_exists(pattern, routes_to_check, parameters):
@ -442,6 +446,7 @@ class Router:
method=method,
allowed_methods=self.get_supported_methods(url),
)
if route:
if route.methods and method not in route.methods:
raise method_not_supported
@ -476,7 +481,7 @@ class Router:
route_handler = route.handler
if hasattr(route_handler, "handlers"):
route_handler = route_handler.handlers[method]
return route_handler, [], kwargs, route.uri
return route_handler, [], kwargs, route.uri, route.name
def is_stream_handler(self, request):
""" Handler for request is stream or not.

View File

@ -94,7 +94,7 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs):
return None, [], {}, ""
return None, [], {}, "", ""
# Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -83,7 +83,7 @@ def test_bp_group_with_additional_route_params(app: Sanic):
_, response = app.test_client.patch("/api/bp2/route/bp2", headers=header)
assert response.text == "PATCH_bp2"
_, response = app.test_client.get("/v2/api/bp1/request_path")
_, response = app.test_client.put("/v2/api/bp1/request_path")
assert response.status == 401
@ -141,8 +141,8 @@ def test_bp_group(app: Sanic):
_, response = app.test_client.get("/api/bp3")
assert response.text == "BP3_OK"
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 4
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 2
def test_bp_group_list_operations(app: Sanic):

View File

@ -268,7 +268,7 @@ def test_bp_middleware(app):
request, response = app.test_client.get("/")
assert response.status == 200
assert response.text == "OK"
assert response.text == "FAIL"
def test_bp_exception_handler(app):