Begin middleware revamp (#2550)

This commit is contained in:
Adam Hopkins 2022-09-22 00:43:42 +03:00 committed by GitHub
parent 43ba381e7b
commit 6e32270036
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 357 additions and 101 deletions

View File

@ -709,7 +709,10 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- #
async def handle_exception(
self, request: Request, exception: BaseException
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
@ -718,6 +721,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
:param exception: The exception that was raised
:raises ServerError: response 500
"""
response = None
await self.dispatch(
"http.lifecycle.exception",
inline=True,
@ -758,9 +762,11 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
if run_middleware:
middleware = (
request.route and request.route.extra.request_middleware
) or self.request_middleware
response = await self._run_request_middleware(request, middleware)
# No middleware results
if not response:
try:
@ -840,7 +846,13 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
response: Optional[
Union[
BaseHTTPResponse,
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
]
] = None
run_middleware = True
try:
await self.dispatch(
@ -885,9 +897,11 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=route.name
)
run_middleware = False
if request.route.extra.request_middleware:
response = await self._run_request_middleware(
request, request.route.extra.request_middleware
)
# No middleware results
if not response:
@ -928,7 +942,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response)
response = await request.respond(response) # type: ignore
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore
@ -946,7 +960,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
resp = await response(request) # type: ignore
await self.dispatch(
"http.lifecycle.response",
inline=True,
@ -955,7 +969,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
"response": resp,
},
)
await response.eof()
await response.eof() # type: ignore
else:
if not hasattr(handler, "is_websocket"):
raise ServerError(
@ -967,7 +981,9 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
raise
except Exception as e:
# Response Generation Failed
await self.handle_exception(request, e)
await self.handle_exception(
request, e, run_middleware=run_middleware
)
async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
@ -1036,86 +1052,72 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- #
async def _run_request_middleware(
self, request, request_name=None
self, request, middleware_collection
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
request._request_middleware_started = True
# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
for middleware in middleware_collection:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request)
if isawaitable(response):
response = await response
response = middleware(request)
if isawaitable(response):
response = await response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
if response:
return response
return None
async def _run_response_middleware(
self, request, response, request_name=None
self, request, response, middleware_collection
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
for middleware in middleware_collection:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
return response
def _build_endpoint_name(self, *parts):
@ -1528,6 +1530,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
except FinalizationError as e:
if not Sanic.test_mode:
raise e
self.finalize_middleware()
def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin

View File

@ -234,4 +234,7 @@ class ASGIApp:
self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request)
except Exception as e:
await self.sanic_app.handle_exception(self.request, e)
try:
await self.sanic_app.handle_exception(self.request, e)
except Exception as exc:
await self.sanic_app.handle_exception(self.request, exc, False)

View File

@ -428,7 +428,10 @@ class Http(Stream, metaclass=TouchUpMeta):
if self.request is None:
self.create_empty_request()
await app.handle_exception(self.request, exception)
try:
await app.handle_exception(self.request, exception)
except Exception as e:
await app.handle_exception(self.request, e, False)
def create_empty_request(self) -> None:
"""

66
sanic/middleware.py Normal file
View File

@ -0,0 +1,66 @@
from __future__ import annotations
from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Sequence, Union
from sanic.models.handler_types import MiddlewareType
class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()
class Middleware:
_counter = count()
__slots__ = ("func", "priority", "location", "definition")
def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware._counter)
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)
@property
def order(self):
return (self.priority, -self.definition)
@classmethod
def convert(
cls,
*middleware_collections: Sequence[Union[Middleware, MiddlewareType]],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)
@classmethod
def reset_count(cls):
cls._counter = count()
cls.count = next(cls._counter)

View File

@ -1,11 +1,17 @@
from collections import deque
from functools import partial
from operator import attrgetter
from typing import List
from sanic.base.meta import SanicMeta
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware
from sanic.router import Router
class MiddlewareMixin(metaclass=SanicMeta):
router: Router
def __init__(self, *args, **kwargs) -> None:
self._future_middleware: List[FutureMiddleware] = []
@ -13,7 +19,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
raise NotImplementedError # noqa
def middleware(
self, middleware_or_request, attach_to="request", apply=True
self,
middleware_or_request,
attach_to="request",
apply=True,
*,
priority=0
):
"""
Decorate and register middleware to be called before a request
@ -30,6 +41,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
def register_middleware(middleware, attach_to="request"):
nonlocal apply
location = (
MiddlewareLocation.REQUEST
if attach_to == "request"
else MiddlewareLocation.RESPONSE
)
middleware = Middleware(middleware, location, priority=priority)
future_middleware = FutureMiddleware(middleware, attach_to)
self._future_middleware.append(future_middleware)
if apply:
@ -46,7 +63,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
register_middleware, attach_to=middleware_or_request
)
def on_request(self, middleware=None):
def on_request(self, middleware=None, *, priority=0):
"""Register a middleware to be called before a request is handled.
This is the same as *@app.middleware('request')*.
@ -54,11 +71,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
:param: middleware: A callable that takes in request.
"""
if callable(middleware):
return self.middleware(middleware, "request")
return self.middleware(middleware, "request", priority=priority)
else:
return partial(self.middleware, attach_to="request")
return partial(
self.middleware, attach_to="request", priority=priority
)
def on_response(self, middleware=None):
def on_response(self, middleware=None, *, priority=0):
"""Register a middleware to be called after a response is created.
This is the same as *@app.middleware('response')*.
@ -67,6 +86,57 @@ class MiddlewareMixin(metaclass=SanicMeta):
A callable that takes in a request and its response.
"""
if callable(middleware):
return self.middleware(middleware, "response")
return self.middleware(middleware, "response", priority=priority)
else:
return partial(self.middleware, attach_to="response")
return partial(
self.middleware, attach_to="response", priority=priority
)
def finalize_middleware(self):
for route in self.router.routes:
request_middleware = Middleware.convert(
self.request_middleware,
self.named_request_middleware.get(route.name, deque()),
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.extra.request_middleware = deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
route.extra.response_middleware = deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
)
request_middleware = Middleware.convert(
self.request_middleware,
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
location=MiddlewareLocation.RESPONSE,
)
self.request_middleware = deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
self.response_middleware = deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
)

View File

@ -56,7 +56,7 @@ from sanic.headers import (
parse_xforwarded,
)
from sanic.http import Stage
from sanic.log import error_logger, logger
from sanic.log import deprecation, error_logger, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse
@ -103,6 +103,7 @@ class Request:
"_port",
"_protocol",
"_remote_addr",
"_request_middleware_started",
"_scheme",
"_socket",
"_stream_id",
@ -126,7 +127,6 @@ class Request:
"parsed_token",
"raw_url",
"responded",
"request_middleware_started",
"route",
"stream",
"transport",
@ -178,7 +178,7 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list)
self.request_middleware_started = False
self._request_middleware_started = False
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
@ -219,6 +219,16 @@ class Request:
def generate_id(*_):
return uuid.uuid4()
@property
def request_middleware_started(self):
deprecation(
"Request.request_middleware_started has been deprecated and will"
"be removed. You should set a flag on the request context using"
"either middleware or signals if you need this feature.",
23.3,
)
return self._request_middleware_started
@property
def stream_id(self):
"""
@ -324,9 +334,13 @@ class Request:
response = await response # type: ignore
# Run response middleware
try:
response = await self.app._run_response_middleware(
self, response, request_name=self.name
)
middleware = (
self.route and self.route.extra.response_middleware
) or self.app.response_middleware
if middleware:
response = await self.app._run_response_middleware(
self, response, middleware
)
except CancelledErrors:
raise
except Exception:

View File

@ -73,7 +73,7 @@ class Inspector:
def state_to_json(self):
output = {"info": self.app_info}
output["workers"] = self._make_safe(dict(self.worker_state))
output["workers"] = self.make_safe(dict(self.worker_state))
return output
def reload(self):
@ -84,10 +84,11 @@ class Inspector:
message = "__TERMINATE__"
self._publisher.send(message)
def _make_safe(self, obj: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def make_safe(obj: Dict[str, Any]) -> Dict[str, Any]:
for key, value in obj.items():
if isinstance(value, dict):
obj[key] = self._make_safe(value)
obj[key] = Inspector.make_safe(value)
elif isinstance(value, datetime):
obj[key] = value.isoformat()
return obj

View File

@ -84,7 +84,7 @@ ujson = "ujson>=1.35" + env_dependency
uvloop = "uvloop>=0.5.3" + env_dependency
types_ujson = "types-ujson" + env_dependency
requirements = [
"sanic-routing>=22.3.0,<22.6.0",
"sanic-routing>=22.8.0",
"httptools>=0.0.10",
uvloop,
ujson,
@ -94,7 +94,7 @@ requirements = [
]
tests_require = [
"sanic-testing>=22.9.0b1",
"sanic-testing>=22.9.0b2",
"pytest",
"coverage",
"beautifulsoup4",

View File

@ -18,6 +18,7 @@ from sanic.exceptions import SanicException
from sanic.helpers import _default
from sanic.log import LOGGING_CONFIG_DEFAULTS
from sanic.response import text
from sanic.router import Route
@pytest.fixture(autouse=True)
@ -152,8 +153,13 @@ def test_app_route_raise_value_error(app: Sanic):
def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch):
app.config.TOUCHUP = False
route = Mock(spec=Route)
route.extra.request_middleware = []
route.extra.response_middleware = []
def mockreturn(*args, **kwargs):
return Mock(), None, {}
return route, None, {}
monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -0,0 +1,90 @@
from functools import partial
import pytest
from sanic import Sanic
from sanic.middleware import Middleware
from sanic.response import json
PRIORITY_TEST_CASES = (
([0, 1, 2], [1, 1, 1]),
([0, 1, 2], [1, 1, None]),
([0, 1, 2], [1, None, None]),
([0, 1, 2], [2, 1, None]),
([0, 1, 2], [2, 2, None]),
([0, 1, 2], [3, 2, 1]),
([0, 1, 2], [None, None, None]),
([0, 2, 1], [1, None, 1]),
([0, 2, 1], [2, None, 1]),
([0, 2, 1], [2, None, 2]),
([0, 2, 1], [3, 1, 2]),
([1, 0, 2], [1, 2, None]),
([1, 0, 2], [2, 3, 1]),
([1, 0, 2], [None, 1, None]),
([1, 2, 0], [1, 3, 2]),
([1, 2, 0], [None, 1, 1]),
([1, 2, 0], [None, 2, 1]),
([1, 2, 0], [None, 2, 2]),
([2, 0, 1], [1, None, 2]),
([2, 0, 1], [2, 1, 3]),
([2, 0, 1], [None, None, 1]),
([2, 1, 0], [1, 2, 3]),
([2, 1, 0], [None, 1, 2]),
)
@pytest.fixture(autouse=True)
def reset_middleware():
yield
Middleware.reset_count()
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_request_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_request(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order == expected
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_response_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, response, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_response(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order[::-1] == expected