GIT-37: fix blueprint middleware application (#1690)
* GIT-37: fix blueprint middleware application 1. If you register a middleware via `@blueprint.middleware` then it will apply only to the routes defined by the blueprint. 2. If you register a middleware via `@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group. 3. If you define a middleware via `@app.middleware` then it will be applied on all available routes Fixes #37 Signed-off-by: Harsha Narayana <harsha2k4@gmail.com> * GIT-37: add changelog Signed-off-by: Harsha Narayana <harsha2k4@gmail.com>
This commit is contained in:
parent
179a07942e
commit
a6077a1790
11
changelogs/37.bugfix.rst
Normal file
11
changelogs/37.bugfix.rst
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
Fix blueprint middleware application
|
||||||
|
|
||||||
|
Currently, any blueprint middleware registered, irrespective of which blueprint was used to do so, was
|
||||||
|
being applied to all of the routes created by the :code:`@app` and :code:`@blueprint` alike.
|
||||||
|
|
||||||
|
As part of this change, the blueprint based middleware application is enforced based on where they are
|
||||||
|
registered.
|
||||||
|
|
||||||
|
- If you register a middleware via :code:`@blueprint.middleware` then it will apply only to the routes defined by the blueprint.
|
||||||
|
- If you register a middleware via :code:`@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group.
|
||||||
|
- If you define a middleware via :code:`@app.middleware` then it will be applied on all available routes
|
67
sanic/app.py
67
sanic/app.py
|
@ -85,7 +85,8 @@ class Sanic:
|
||||||
self.is_request_stream = False
|
self.is_request_stream = False
|
||||||
self.websocket_enabled = False
|
self.websocket_enabled = False
|
||||||
self.websocket_tasks = set()
|
self.websocket_tasks = set()
|
||||||
|
self.named_request_middleware = {}
|
||||||
|
self.named_response_middleware = {}
|
||||||
# Register alternative method names
|
# Register alternative method names
|
||||||
self.go_fast = self.run
|
self.go_fast = self.run
|
||||||
|
|
||||||
|
@ -178,7 +179,7 @@ class Sanic:
|
||||||
:param stream:
|
:param stream:
|
||||||
:param version:
|
:param version:
|
||||||
:param name: user defined route name for url_for
|
:param name: user defined route name for url_for
|
||||||
:return: decorated function
|
:return: tuple of routes, decorated function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Fix case where the user did not prefix the URL with a /
|
# Fix case where the user did not prefix the URL with a /
|
||||||
|
@ -204,7 +205,7 @@ class Sanic:
|
||||||
if stream:
|
if stream:
|
||||||
handler.is_stream = stream
|
handler.is_stream = stream
|
||||||
|
|
||||||
self.router.add(
|
routes = self.router.add(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
methods=methods,
|
methods=methods,
|
||||||
handler=handler,
|
handler=handler,
|
||||||
|
@ -213,7 +214,7 @@ class Sanic:
|
||||||
version=version,
|
version=version,
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
return handler
|
return routes, handler
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -462,7 +463,7 @@ class Sanic:
|
||||||
:param subprotocols: optional list of str with supported subprotocols
|
:param subprotocols: optional list of str with supported subprotocols
|
||||||
:param name: A unique name assigned to the URL so that it can
|
:param name: A unique name assigned to the URL so that it can
|
||||||
be used with :func:`url_for`
|
be used with :func:`url_for`
|
||||||
:return: decorated function
|
:return: tuple of routes, decorated function
|
||||||
"""
|
"""
|
||||||
self.enable_websocket()
|
self.enable_websocket()
|
||||||
|
|
||||||
|
@ -515,7 +516,7 @@ class Sanic:
|
||||||
self.websocket_tasks.remove(fut)
|
self.websocket_tasks.remove(fut)
|
||||||
await ws.close()
|
await ws.close()
|
||||||
|
|
||||||
self.router.add(
|
routes = self.router.add(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
handler=websocket_handler,
|
handler=websocket_handler,
|
||||||
methods=frozenset({"GET"}),
|
methods=frozenset({"GET"}),
|
||||||
|
@ -523,7 +524,7 @@ class Sanic:
|
||||||
strict_slashes=strict_slashes,
|
strict_slashes=strict_slashes,
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
return handler
|
return routes, handler
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -544,6 +545,7 @@ class Sanic:
|
||||||
:param host: Host IP or FQDN details
|
:param host: Host IP or FQDN details
|
||||||
:param uri: URL path that will be mapped to the websocket
|
:param uri: URL path that will be mapped to the websocket
|
||||||
handler
|
handler
|
||||||
|
handler
|
||||||
:param strict_slashes: If the API endpoint needs to terminate
|
:param strict_slashes: If the API endpoint needs to terminate
|
||||||
with a "/" or not
|
with a "/" or not
|
||||||
:param subprotocols: Subprotocols to be used with websocket
|
:param subprotocols: Subprotocols to be used with websocket
|
||||||
|
@ -645,6 +647,22 @@ class Sanic:
|
||||||
self.response_middleware.appendleft(middleware)
|
self.response_middleware.appendleft(middleware)
|
||||||
return middleware
|
return middleware
|
||||||
|
|
||||||
|
def register_named_middleware(
|
||||||
|
self, middleware, route_names, attach_to="request"
|
||||||
|
):
|
||||||
|
if attach_to == "request":
|
||||||
|
for _rn in route_names:
|
||||||
|
if _rn not in self.named_request_middleware:
|
||||||
|
self.named_request_middleware[_rn] = deque()
|
||||||
|
if middleware not in self.named_request_middleware[_rn]:
|
||||||
|
self.named_request_middleware[_rn].append(middleware)
|
||||||
|
if attach_to == "response":
|
||||||
|
for _rn in route_names:
|
||||||
|
if _rn not in self.named_response_middleware:
|
||||||
|
self.named_response_middleware[_rn] = deque()
|
||||||
|
if middleware not in self.named_response_middleware[_rn]:
|
||||||
|
self.named_response_middleware[_rn].append(middleware)
|
||||||
|
|
||||||
# Decorator
|
# Decorator
|
||||||
def middleware(self, middleware_or_request):
|
def middleware(self, middleware_or_request):
|
||||||
"""
|
"""
|
||||||
|
@ -916,20 +934,23 @@ class Sanic:
|
||||||
# allocation before assignment below.
|
# allocation before assignment below.
|
||||||
response = None
|
response = None
|
||||||
cancelled = False
|
cancelled = False
|
||||||
|
name = None
|
||||||
try:
|
try:
|
||||||
|
# Fetch handler from router
|
||||||
|
handler, args, kwargs, uri, name = self.router.get(request)
|
||||||
|
|
||||||
# -------------------------------------------- #
|
# -------------------------------------------- #
|
||||||
# Request Middleware
|
# Request Middleware
|
||||||
# -------------------------------------------- #
|
# -------------------------------------------- #
|
||||||
response = await self._run_request_middleware(request)
|
response = await self._run_request_middleware(
|
||||||
|
request, request_name=name
|
||||||
|
)
|
||||||
# No middleware results
|
# No middleware results
|
||||||
if not response:
|
if not response:
|
||||||
# -------------------------------------------- #
|
# -------------------------------------------- #
|
||||||
# Execute Handler
|
# Execute Handler
|
||||||
# -------------------------------------------- #
|
# -------------------------------------------- #
|
||||||
|
|
||||||
# Fetch handler from router
|
|
||||||
handler, args, kwargs, uri = self.router.get(request)
|
|
||||||
|
|
||||||
request.uri_template = uri
|
request.uri_template = uri
|
||||||
if handler is None:
|
if handler is None:
|
||||||
raise ServerError(
|
raise ServerError(
|
||||||
|
@ -993,7 +1014,7 @@ class Sanic:
|
||||||
if response is not None:
|
if response is not None:
|
||||||
try:
|
try:
|
||||||
response = await self._run_response_middleware(
|
response = await self._run_response_middleware(
|
||||||
request, response
|
request, response, request_name=name
|
||||||
)
|
)
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
# Response middleware can timeout too, as above.
|
# Response middleware can timeout too, as above.
|
||||||
|
@ -1265,10 +1286,14 @@ class Sanic:
|
||||||
if isawaitable(result):
|
if isawaitable(result):
|
||||||
await result
|
await result
|
||||||
|
|
||||||
async def _run_request_middleware(self, request):
|
async def _run_request_middleware(self, request, request_name=None):
|
||||||
# The if improves speed. I don't know why
|
# The if improves speed. I don't know why
|
||||||
if self.request_middleware:
|
named_middleware = self.named_request_middleware.get(
|
||||||
for middleware in self.request_middleware:
|
request_name, deque()
|
||||||
|
)
|
||||||
|
applicable_middleware = self.request_middleware + named_middleware
|
||||||
|
if applicable_middleware:
|
||||||
|
for middleware in applicable_middleware:
|
||||||
response = middleware(request)
|
response = middleware(request)
|
||||||
if isawaitable(response):
|
if isawaitable(response):
|
||||||
response = await response
|
response = await response
|
||||||
|
@ -1276,9 +1301,15 @@ class Sanic:
|
||||||
return response
|
return response
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _run_response_middleware(self, request, response):
|
async def _run_response_middleware(
|
||||||
if self.response_middleware:
|
self, request, response, request_name=None
|
||||||
for middleware in self.response_middleware:
|
):
|
||||||
|
named_middleware = self.named_response_middleware.get(
|
||||||
|
request_name, deque()
|
||||||
|
)
|
||||||
|
applicable_middleware = self.response_middleware + named_middleware
|
||||||
|
if applicable_middleware:
|
||||||
|
for middleware in applicable_middleware:
|
||||||
_response = middleware(request, response)
|
_response = middleware(request, response)
|
||||||
if isawaitable(_response):
|
if isawaitable(_response):
|
||||||
_response = await _response
|
_response = await _response
|
||||||
|
|
|
@ -104,6 +104,8 @@ class Blueprint:
|
||||||
|
|
||||||
url_prefix = options.get("url_prefix", self.url_prefix)
|
url_prefix = options.get("url_prefix", self.url_prefix)
|
||||||
|
|
||||||
|
routes = []
|
||||||
|
|
||||||
# Routes
|
# Routes
|
||||||
for future in self.routes:
|
for future in self.routes:
|
||||||
# attach the blueprint name to the handler so that it can be
|
# attach the blueprint name to the handler so that it can be
|
||||||
|
@ -114,7 +116,7 @@ class Blueprint:
|
||||||
|
|
||||||
version = future.version or self.version
|
version = future.version or self.version
|
||||||
|
|
||||||
app.route(
|
_routes, _ = app.route(
|
||||||
uri=uri[1:] if uri.startswith("//") else uri,
|
uri=uri[1:] if uri.startswith("//") else uri,
|
||||||
methods=future.methods,
|
methods=future.methods,
|
||||||
host=future.host or self.host,
|
host=future.host or self.host,
|
||||||
|
@ -123,6 +125,8 @@ class Blueprint:
|
||||||
version=version,
|
version=version,
|
||||||
name=future.name,
|
name=future.name,
|
||||||
)(future.handler)
|
)(future.handler)
|
||||||
|
if _routes:
|
||||||
|
routes += _routes
|
||||||
|
|
||||||
for future in self.websocket_routes:
|
for future in self.websocket_routes:
|
||||||
# attach the blueprint name to the handler so that it can be
|
# attach the blueprint name to the handler so that it can be
|
||||||
|
@ -130,21 +134,27 @@ class Blueprint:
|
||||||
future.handler.__blueprintname__ = self.name
|
future.handler.__blueprintname__ = self.name
|
||||||
# Prepend the blueprint URI prefix if available
|
# Prepend the blueprint URI prefix if available
|
||||||
uri = url_prefix + future.uri if url_prefix else future.uri
|
uri = url_prefix + future.uri if url_prefix else future.uri
|
||||||
app.websocket(
|
_routes, _ = app.websocket(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
host=future.host or self.host,
|
host=future.host or self.host,
|
||||||
strict_slashes=future.strict_slashes,
|
strict_slashes=future.strict_slashes,
|
||||||
name=future.name,
|
name=future.name,
|
||||||
)(future.handler)
|
)(future.handler)
|
||||||
|
if _routes:
|
||||||
|
routes += _routes
|
||||||
|
|
||||||
|
route_names = [route.name for route in routes]
|
||||||
# Middleware
|
# Middleware
|
||||||
for future in self.middlewares:
|
for future in self.middlewares:
|
||||||
if future.args or future.kwargs:
|
if future.args or future.kwargs:
|
||||||
app.register_middleware(
|
app.register_named_middleware(
|
||||||
future.middleware, *future.args, **future.kwargs
|
future.middleware,
|
||||||
|
route_names,
|
||||||
|
*future.args,
|
||||||
|
**future.kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
app.register_middleware(future.middleware)
|
app.register_named_middleware(future.middleware, route_names)
|
||||||
|
|
||||||
# Exceptions
|
# Exceptions
|
||||||
for future in self.exceptions:
|
for future in self.exceptions:
|
||||||
|
|
|
@ -140,21 +140,22 @@ class Router:
|
||||||
docs for further details.
|
docs for further details.
|
||||||
:return: Nothing
|
:return: Nothing
|
||||||
"""
|
"""
|
||||||
|
routes = []
|
||||||
if version is not None:
|
if version is not None:
|
||||||
version = re.escape(str(version).strip("/").lstrip("v"))
|
version = re.escape(str(version).strip("/").lstrip("v"))
|
||||||
uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
|
uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
|
||||||
# add regular version
|
# add regular version
|
||||||
self._add(uri, methods, handler, host, name)
|
routes.append(self._add(uri, methods, handler, host, name))
|
||||||
|
|
||||||
if strict_slashes:
|
if strict_slashes:
|
||||||
return
|
return routes
|
||||||
|
|
||||||
if not isinstance(host, str) and host is not None:
|
if not isinstance(host, str) and host is not None:
|
||||||
# we have gotten back to the top of the recursion tree where the
|
# we have gotten back to the top of the recursion tree where the
|
||||||
# host was originally a list. By now, we've processed the strict
|
# host was originally a list. By now, we've processed the strict
|
||||||
# slashes logic on the leaf nodes (the individual host strings in
|
# slashes logic on the leaf nodes (the individual host strings in
|
||||||
# the list of host)
|
# the list of host)
|
||||||
return
|
return routes
|
||||||
|
|
||||||
# Add versions with and without trailing /
|
# Add versions with and without trailing /
|
||||||
slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
|
slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
|
||||||
|
@ -176,10 +177,12 @@ class Router:
|
||||||
)
|
)
|
||||||
# add version with trailing slash
|
# add version with trailing slash
|
||||||
if slash_is_missing:
|
if slash_is_missing:
|
||||||
self._add(uri + "/", methods, handler, host, name)
|
routes.append(self._add(uri + "/", methods, handler, host, name))
|
||||||
# add version without trailing slash
|
# add version without trailing slash
|
||||||
elif without_slash_is_missing:
|
elif without_slash_is_missing:
|
||||||
self._add(uri[:-1], methods, handler, host, name)
|
routes.append(self._add(uri[:-1], methods, handler, host, name))
|
||||||
|
|
||||||
|
return routes
|
||||||
|
|
||||||
def _add(self, uri, methods, handler, host=None, name=None):
|
def _add(self, uri, methods, handler, host=None, name=None):
|
||||||
"""Add a handler to the route list
|
"""Add a handler to the route list
|
||||||
|
@ -328,6 +331,7 @@ class Router:
|
||||||
self.routes_dynamic[url_hash(uri)].append(route)
|
self.routes_dynamic[url_hash(uri)].append(route)
|
||||||
else:
|
else:
|
||||||
self.routes_static[uri] = route
|
self.routes_static[uri] = route
|
||||||
|
return route
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_dynamic_route_exists(pattern, routes_to_check, parameters):
|
def check_dynamic_route_exists(pattern, routes_to_check, parameters):
|
||||||
|
@ -442,6 +446,7 @@ class Router:
|
||||||
method=method,
|
method=method,
|
||||||
allowed_methods=self.get_supported_methods(url),
|
allowed_methods=self.get_supported_methods(url),
|
||||||
)
|
)
|
||||||
|
|
||||||
if route:
|
if route:
|
||||||
if route.methods and method not in route.methods:
|
if route.methods and method not in route.methods:
|
||||||
raise method_not_supported
|
raise method_not_supported
|
||||||
|
@ -476,7 +481,7 @@ class Router:
|
||||||
route_handler = route.handler
|
route_handler = route.handler
|
||||||
if hasattr(route_handler, "handlers"):
|
if hasattr(route_handler, "handlers"):
|
||||||
route_handler = route_handler.handlers[method]
|
route_handler = route_handler.handlers[method]
|
||||||
return route_handler, [], kwargs, route.uri
|
return route_handler, [], kwargs, route.uri, route.name
|
||||||
|
|
||||||
def is_stream_handler(self, request):
|
def is_stream_handler(self, request):
|
||||||
""" Handler for request is stream or not.
|
""" Handler for request is stream or not.
|
||||||
|
|
|
@ -94,7 +94,7 @@ def test_app_route_raise_value_error(app):
|
||||||
|
|
||||||
def test_app_handle_request_handler_is_none(app, monkeypatch):
|
def test_app_handle_request_handler_is_none(app, monkeypatch):
|
||||||
def mockreturn(*args, **kwargs):
|
def mockreturn(*args, **kwargs):
|
||||||
return None, [], {}, ""
|
return None, [], {}, "", ""
|
||||||
|
|
||||||
# Not sure how to make app.router.get() return None, so use mock here.
|
# Not sure how to make app.router.get() return None, so use mock here.
|
||||||
monkeypatch.setattr(app.router, "get", mockreturn)
|
monkeypatch.setattr(app.router, "get", mockreturn)
|
||||||
|
|
|
@ -83,7 +83,7 @@ def test_bp_group_with_additional_route_params(app: Sanic):
|
||||||
_, response = app.test_client.patch("/api/bp2/route/bp2", headers=header)
|
_, response = app.test_client.patch("/api/bp2/route/bp2", headers=header)
|
||||||
assert response.text == "PATCH_bp2"
|
assert response.text == "PATCH_bp2"
|
||||||
|
|
||||||
_, response = app.test_client.get("/v2/api/bp1/request_path")
|
_, response = app.test_client.put("/v2/api/bp1/request_path")
|
||||||
assert response.status == 401
|
assert response.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,8 +141,8 @@ def test_bp_group(app: Sanic):
|
||||||
_, response = app.test_client.get("/api/bp3")
|
_, response = app.test_client.get("/api/bp3")
|
||||||
assert response.text == "BP3_OK"
|
assert response.text == "BP3_OK"
|
||||||
|
|
||||||
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 4
|
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
|
||||||
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
|
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 2
|
||||||
|
|
||||||
|
|
||||||
def test_bp_group_list_operations(app: Sanic):
|
def test_bp_group_list_operations(app: Sanic):
|
||||||
|
|
|
@ -268,7 +268,7 @@ def test_bp_middleware(app):
|
||||||
request, response = app.test_client.get("/")
|
request, response = app.test_client.get("/")
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
assert response.text == "OK"
|
assert response.text == "FAIL"
|
||||||
|
|
||||||
|
|
||||||
def test_bp_exception_handler(app):
|
def test_bp_exception_handler(app):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user