Add priority to register_middleware method (#2636)

This commit is contained in:
Adam Hopkins 2022-12-19 19:14:46 +02:00 committed by GitHub
parent 911485d52e
commit 2abe66b670
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 10 deletions

View File

@ -61,7 +61,7 @@ from sanic.exceptions import (
URLBuildError, URLBuildError,
) )
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.helpers import Default from sanic.helpers import Default, _default
from sanic.http import Stage from sanic.http import Stage
from sanic.log import ( from sanic.log import (
LOGGING_CONFIG_DEFAULTS, LOGGING_CONFIG_DEFAULTS,
@ -69,6 +69,7 @@ from sanic.log import (
error_logger, error_logger,
logger, logger,
) )
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.mixins.listeners import ListenerEvent from sanic.mixins.listeners import ListenerEvent
from sanic.mixins.startup import StartupMixin from sanic.mixins.startup import StartupMixin
from sanic.models.futures import ( from sanic.models.futures import (
@ -294,8 +295,12 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
return listener return listener
def register_middleware( def register_middleware(
self, middleware: MiddlewareType, attach_to: str = "request" self,
) -> MiddlewareType: middleware: Union[MiddlewareType, Middleware],
attach_to: str = "request",
*,
priority: Union[Default, int] = _default,
) -> Union[MiddlewareType, Middleware]:
""" """
Register an application level middleware that will be attached Register an application level middleware that will be attached
to all the API URLs registered under this application. to all the API URLs registered under this application.
@ -311,19 +316,37 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
**response** - Invoke before the response is returned back **response** - Invoke before the response is returned back
:return: decorated method :return: decorated method
""" """
if attach_to == "request": retval = middleware
location = MiddlewareLocation[attach_to.upper()]
if not isinstance(middleware, Middleware):
middleware = Middleware(
middleware,
location=location,
priority=priority if isinstance(priority, int) else 0,
)
elif middleware.priority != priority and isinstance(priority, int):
middleware = Middleware(
middleware.func,
location=middleware.location,
priority=priority,
)
if location is MiddlewareLocation.REQUEST:
if middleware not in self.request_middleware: if middleware not in self.request_middleware:
self.request_middleware.append(middleware) self.request_middleware.append(middleware)
if attach_to == "response": if location is MiddlewareLocation.RESPONSE:
if middleware not in self.response_middleware: if middleware not in self.response_middleware:
self.response_middleware.appendleft(middleware) self.response_middleware.appendleft(middleware)
return middleware return retval
def register_named_middleware( def register_named_middleware(
self, self,
middleware: MiddlewareType, middleware: MiddlewareType,
route_names: Iterable[str], route_names: Iterable[str],
attach_to: str = "request", attach_to: str = "request",
*,
priority: Union[Default, int] = _default,
): ):
""" """
Method for attaching middleware to specific routes. This is mainly an Method for attaching middleware to specific routes. This is mainly an
@ -337,19 +360,35 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
defaults to "request" defaults to "request"
:type attach_to: str, optional :type attach_to: str, optional
""" """
if attach_to == "request": retval = middleware
location = MiddlewareLocation[attach_to.upper()]
if not isinstance(middleware, Middleware):
middleware = Middleware(
middleware,
location=location,
priority=priority if isinstance(priority, int) else 0,
)
elif middleware.priority != priority and isinstance(priority, int):
middleware = Middleware(
middleware.func,
location=middleware.location,
priority=priority,
)
if location is MiddlewareLocation.REQUEST:
for _rn in route_names: for _rn in route_names:
if _rn not in self.named_request_middleware: if _rn not in self.named_request_middleware:
self.named_request_middleware[_rn] = deque() self.named_request_middleware[_rn] = deque()
if middleware not in self.named_request_middleware[_rn]: if middleware not in self.named_request_middleware[_rn]:
self.named_request_middleware[_rn].append(middleware) self.named_request_middleware[_rn].append(middleware)
if attach_to == "response": if location is MiddlewareLocation.RESPONSE:
for _rn in route_names: for _rn in route_names:
if _rn not in self.named_response_middleware: if _rn not in self.named_response_middleware:
self.named_response_middleware[_rn] = deque() self.named_response_middleware[_rn] = deque()
if middleware not in self.named_response_middleware[_rn]: if middleware not in self.named_response_middleware[_rn]:
self.named_response_middleware[_rn].appendleft(middleware) self.named_response_middleware[_rn].appendleft(middleware)
return middleware return retval
def _apply_exception_handler( def _apply_exception_handler(
self, self,

View File

@ -32,6 +32,9 @@ class Middleware:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
def __hash__(self) -> int:
return hash(self.func)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"{self.__class__.__name__}(" f"{self.__class__.__name__}("

View File

@ -3,7 +3,7 @@ from functools import partial
import pytest import pytest
from sanic import Sanic from sanic import Sanic
from sanic.middleware import Middleware from sanic.middleware import Middleware, MiddlewareLocation
from sanic.response import json from sanic.response import json
@ -40,6 +40,86 @@ def reset_middleware():
Middleware.reset_count() Middleware.reset_count()
def test_add_register_priority(app: Sanic):
def foo(*_):
...
app.register_middleware(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.register_middleware(foo, attach_to="response", priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore
def test_add_register_named_priority(app: Sanic):
def foo(*_):
...
app.register_named_middleware(foo, route_names=["foo"], priority=999)
assert len(app.named_request_middleware) == 1
assert len(app.named_response_middleware) == 0
assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore
app.register_named_middleware(
foo, attach_to="response", route_names=["foo"], priority=999
)
assert len(app.named_request_middleware) == 1
assert len(app.named_response_middleware) == 1
assert app.named_response_middleware["foo"][0].priority == 999 # type: ignore
def test_add_decorator_priority(app: Sanic):
def foo(*_):
...
app.middleware(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.middleware(foo, attach_to="response", priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore
def test_add_convenience_priority(app: Sanic):
def foo(*_):
...
app.on_request(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.on_response(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore
def test_add_conflicting_priority(app: Sanic):
def foo(*_):
...
middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998)
app.register_middleware(middleware=middleware, priority=999)
assert app.request_middleware[0].priority == 999 # type: ignore
middleware.priority == 998
def test_add_conflicting_priority_named(app: Sanic):
def foo(*_):
...
middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998)
app.register_named_middleware(
middleware=middleware, route_names=["foo"], priority=999
)
assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore
middleware.priority == 998
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expected,priorities", "expected,priorities",
PRIORITY_TEST_CASES, PRIORITY_TEST_CASES,