Begin middleware revamp (#2550)

This commit is contained in:
Adam Hopkins 2022-09-22 00:43:42 +03:00 committed by GitHub
parent 43ba381e7b
commit 6e32270036
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 357 additions and 101 deletions

View File

@ -709,7 +709,10 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def handle_exception( async def handle_exception(
self, request: Request, exception: BaseException self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov ): # no cov
""" """
A handler that catches specific exceptions and outputs a response. A handler that catches specific exceptions and outputs a response.
@ -718,6 +721,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
:param exception: The exception that was raised :param exception: The exception that was raised
:raises ServerError: response 500 :raises ServerError: response 500
""" """
response = None
await self.dispatch( await self.dispatch(
"http.lifecycle.exception", "http.lifecycle.exception",
inline=True, inline=True,
@ -758,9 +762,11 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
response = await self._run_request_middleware( if run_middleware:
request, request_name=None middleware = (
) request.route and request.route.extra.request_middleware
) or self.request_middleware
response = await self._run_request_middleware(request, middleware)
# No middleware results # No middleware results
if not response: if not response:
try: try:
@ -840,7 +846,13 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# Define `response` var here to remove warnings about # Define `response` var here to remove warnings about
# allocation before assignment below. # allocation before assignment below.
response = None response: Optional[
Union[
BaseHTTPResponse,
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
]
] = None
run_middleware = True
try: try:
await self.dispatch( await self.dispatch(
@ -885,9 +897,11 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
response = await self._run_request_middleware( run_middleware = False
request, request_name=route.name if request.route.extra.request_middleware:
) response = await self._run_request_middleware(
request, request.route.extra.request_middleware
)
# No middleware results # No middleware results
if not response: if not response:
@ -928,7 +942,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
if request.stream is not None: if request.stream is not None:
response = request.stream.response response = request.stream.response
elif response is not None: elif response is not None:
response = await request.respond(response) response = await request.respond(response) # type: ignore
elif not hasattr(handler, "is_websocket"): elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore response = request.stream.response # type: ignore
@ -946,7 +960,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
... ...
await response.send(end_stream=True) await response.send(end_stream=True)
elif isinstance(response, ResponseStream): elif isinstance(response, ResponseStream):
resp = await response(request) resp = await response(request) # type: ignore
await self.dispatch( await self.dispatch(
"http.lifecycle.response", "http.lifecycle.response",
inline=True, inline=True,
@ -955,7 +969,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
"response": resp, "response": resp,
}, },
) )
await response.eof() await response.eof() # type: ignore
else: else:
if not hasattr(handler, "is_websocket"): if not hasattr(handler, "is_websocket"):
raise ServerError( raise ServerError(
@ -967,7 +981,9 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
raise raise
except Exception as e: except Exception as e:
# Response Generation Failed # Response Generation Failed
await self.handle_exception(request, e) await self.handle_exception(
request, e, run_middleware=run_middleware
)
async def _websocket_handler( async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs self, handler, request, *args, subprotocols=None, **kwargs
@ -1036,86 +1052,72 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def _run_request_middleware( async def _run_request_middleware(
self, request, request_name=None self, request, middleware_collection
): # no cov ): # no cov
# The if improves speed. I don't know why request._request_middleware_started = True
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
# request.request_middleware_started is meant as a stop-gap solution for middleware in middleware_collection:
# until RFC 1630 is adopted await self.dispatch(
if applicable_middleware and not request.request_middleware_started: "http.middleware.before",
request.request_middleware_started = True inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in applicable_middleware: response = middleware(request)
await self.dispatch( if isawaitable(response):
"http.middleware.before", response = await response
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request) await self.dispatch(
if isawaitable(response): "http.middleware.after",
response = await response inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch( if response:
"http.middleware.after", return response
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
return None return None
async def _run_response_middleware( async def _run_response_middleware(
self, request, response, request_name=None self, request, response, middleware_collection
): # no cov ): # no cov
named_middleware = self.named_response_middleware.get( for middleware in middleware_collection:
request_name, deque() await self.dispatch(
) "http.middleware.before",
applicable_middleware = self.response_middleware + named_middleware inline=True,
if applicable_middleware: context={
for middleware in applicable_middleware: "request": request,
await self.dispatch( "response": response,
"http.middleware.before", },
inline=True, condition={"attach_to": "response"},
context={ )
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response) _response = middleware(request, response)
if isawaitable(_response): if isawaitable(_response):
_response = await _response _response = await _response
await self.dispatch( await self.dispatch(
"http.middleware.after", "http.middleware.after",
inline=True, inline=True,
context={ context={
"request": request, "request": request,
"response": _response if _response else response, "response": _response if _response else response,
}, },
condition={"attach_to": "response"}, condition={"attach_to": "response"},
) )
if _response: if _response:
response = _response response = _response
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response) response = request.stream.respond(response)
break break
return response return response
def _build_endpoint_name(self, *parts): def _build_endpoint_name(self, *parts):
@ -1528,6 +1530,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
except FinalizationError as e: except FinalizationError as e:
if not Sanic.test_mode: if not Sanic.test_mode:
raise e raise e
self.finalize_middleware()
def signalize(self, allow_fail_builtin=True): def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin self.signal_router.allow_fail_builtin = allow_fail_builtin

View File

@ -234,4 +234,7 @@ class ASGIApp:
self.stage = Stage.HANDLER self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request) await self.sanic_app.handle_request(self.request)
except Exception as e: except Exception as e:
await self.sanic_app.handle_exception(self.request, e) try:
await self.sanic_app.handle_exception(self.request, e)
except Exception as exc:
await self.sanic_app.handle_exception(self.request, exc, False)

View File

@ -428,7 +428,10 @@ class Http(Stream, metaclass=TouchUpMeta):
if self.request is None: if self.request is None:
self.create_empty_request() self.create_empty_request()
await app.handle_exception(self.request, exception) try:
await app.handle_exception(self.request, exception)
except Exception as e:
await app.handle_exception(self.request, e, False)
def create_empty_request(self) -> None: def create_empty_request(self) -> None:
""" """

66
sanic/middleware.py Normal file
View File

@ -0,0 +1,66 @@
from __future__ import annotations
from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Sequence, Union
from sanic.models.handler_types import MiddlewareType
class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()
class Middleware:
_counter = count()
__slots__ = ("func", "priority", "location", "definition")
def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware._counter)
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)
@property
def order(self):
return (self.priority, -self.definition)
@classmethod
def convert(
cls,
*middleware_collections: Sequence[Union[Middleware, MiddlewareType]],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)
@classmethod
def reset_count(cls):
cls._counter = count()
cls.count = next(cls._counter)

View File

@ -1,11 +1,17 @@
from collections import deque
from functools import partial from functools import partial
from operator import attrgetter
from typing import List from typing import List
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware from sanic.models.futures import FutureMiddleware
from sanic.router import Router
class MiddlewareMixin(metaclass=SanicMeta): class MiddlewareMixin(metaclass=SanicMeta):
router: Router
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
self._future_middleware: List[FutureMiddleware] = [] self._future_middleware: List[FutureMiddleware] = []
@ -13,7 +19,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
raise NotImplementedError # noqa raise NotImplementedError # noqa
def middleware( def middleware(
self, middleware_or_request, attach_to="request", apply=True self,
middleware_or_request,
attach_to="request",
apply=True,
*,
priority=0
): ):
""" """
Decorate and register middleware to be called before a request Decorate and register middleware to be called before a request
@ -30,6 +41,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
def register_middleware(middleware, attach_to="request"): def register_middleware(middleware, attach_to="request"):
nonlocal apply nonlocal apply
location = (
MiddlewareLocation.REQUEST
if attach_to == "request"
else MiddlewareLocation.RESPONSE
)
middleware = Middleware(middleware, location, priority=priority)
future_middleware = FutureMiddleware(middleware, attach_to) future_middleware = FutureMiddleware(middleware, attach_to)
self._future_middleware.append(future_middleware) self._future_middleware.append(future_middleware)
if apply: if apply:
@ -46,7 +63,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
register_middleware, attach_to=middleware_or_request register_middleware, attach_to=middleware_or_request
) )
def on_request(self, middleware=None): def on_request(self, middleware=None, *, priority=0):
"""Register a middleware to be called before a request is handled. """Register a middleware to be called before a request is handled.
This is the same as *@app.middleware('request')*. This is the same as *@app.middleware('request')*.
@ -54,11 +71,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
:param: middleware: A callable that takes in request. :param: middleware: A callable that takes in request.
""" """
if callable(middleware): if callable(middleware):
return self.middleware(middleware, "request") return self.middleware(middleware, "request", priority=priority)
else: else:
return partial(self.middleware, attach_to="request") return partial(
self.middleware, attach_to="request", priority=priority
)
def on_response(self, middleware=None): def on_response(self, middleware=None, *, priority=0):
"""Register a middleware to be called after a response is created. """Register a middleware to be called after a response is created.
This is the same as *@app.middleware('response')*. This is the same as *@app.middleware('response')*.
@ -67,6 +86,57 @@ class MiddlewareMixin(metaclass=SanicMeta):
A callable that takes in a request and its response. A callable that takes in a request and its response.
""" """
if callable(middleware): if callable(middleware):
return self.middleware(middleware, "response") return self.middleware(middleware, "response", priority=priority)
else: else:
return partial(self.middleware, attach_to="response") return partial(
self.middleware, attach_to="response", priority=priority
)
def finalize_middleware(self):
for route in self.router.routes:
request_middleware = Middleware.convert(
self.request_middleware,
self.named_request_middleware.get(route.name, deque()),
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.extra.request_middleware = deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
route.extra.response_middleware = deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
)
request_middleware = Middleware.convert(
self.request_middleware,
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
location=MiddlewareLocation.RESPONSE,
)
self.request_middleware = deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
self.response_middleware = deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
)

View File

@ -56,7 +56,7 @@ from sanic.headers import (
parse_xforwarded, parse_xforwarded,
) )
from sanic.http import Stage from sanic.http import Stage
from sanic.log import error_logger, logger from sanic.log import deprecation, error_logger, logger
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse from sanic.response import BaseHTTPResponse, HTTPResponse
@ -103,6 +103,7 @@ class Request:
"_port", "_port",
"_protocol", "_protocol",
"_remote_addr", "_remote_addr",
"_request_middleware_started",
"_scheme", "_scheme",
"_socket", "_socket",
"_stream_id", "_stream_id",
@ -126,7 +127,6 @@ class Request:
"parsed_token", "parsed_token",
"raw_url", "raw_url",
"responded", "responded",
"request_middleware_started",
"route", "route",
"stream", "stream",
"transport", "transport",
@ -178,7 +178,7 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[ self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]] Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list) ] = defaultdict(list)
self.request_middleware_started = False self._request_middleware_started = False
self.responded: bool = False self.responded: bool = False
self.route: Optional[Route] = None self.route: Optional[Route] = None
self.stream: Optional[Stream] = None self.stream: Optional[Stream] = None
@ -219,6 +219,16 @@ class Request:
def generate_id(*_): def generate_id(*_):
return uuid.uuid4() return uuid.uuid4()
@property
def request_middleware_started(self):
deprecation(
"Request.request_middleware_started has been deprecated and will"
"be removed. You should set a flag on the request context using"
"either middleware or signals if you need this feature.",
23.3,
)
return self._request_middleware_started
@property @property
def stream_id(self): def stream_id(self):
""" """
@ -324,9 +334,13 @@ class Request:
response = await response # type: ignore response = await response # type: ignore
# Run response middleware # Run response middleware
try: try:
response = await self.app._run_response_middleware( middleware = (
self, response, request_name=self.name self.route and self.route.extra.response_middleware
) ) or self.app.response_middleware
if middleware:
response = await self.app._run_response_middleware(
self, response, middleware
)
except CancelledErrors: except CancelledErrors:
raise raise
except Exception: except Exception:

View File

@ -73,7 +73,7 @@ class Inspector:
def state_to_json(self): def state_to_json(self):
output = {"info": self.app_info} output = {"info": self.app_info}
output["workers"] = self._make_safe(dict(self.worker_state)) output["workers"] = self.make_safe(dict(self.worker_state))
return output return output
def reload(self): def reload(self):
@ -84,10 +84,11 @@ class Inspector:
message = "__TERMINATE__" message = "__TERMINATE__"
self._publisher.send(message) self._publisher.send(message)
def _make_safe(self, obj: Dict[str, Any]) -> Dict[str, Any]: @staticmethod
def make_safe(obj: Dict[str, Any]) -> Dict[str, Any]:
for key, value in obj.items(): for key, value in obj.items():
if isinstance(value, dict): if isinstance(value, dict):
obj[key] = self._make_safe(value) obj[key] = Inspector.make_safe(value)
elif isinstance(value, datetime): elif isinstance(value, datetime):
obj[key] = value.isoformat() obj[key] = value.isoformat()
return obj return obj

View File

@ -84,7 +84,7 @@ ujson = "ujson>=1.35" + env_dependency
uvloop = "uvloop>=0.5.3" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency
types_ujson = "types-ujson" + env_dependency types_ujson = "types-ujson" + env_dependency
requirements = [ requirements = [
"sanic-routing>=22.3.0,<22.6.0", "sanic-routing>=22.8.0",
"httptools>=0.0.10", "httptools>=0.0.10",
uvloop, uvloop,
ujson, ujson,
@ -94,7 +94,7 @@ requirements = [
] ]
tests_require = [ tests_require = [
"sanic-testing>=22.9.0b1", "sanic-testing>=22.9.0b2",
"pytest", "pytest",
"coverage", "coverage",
"beautifulsoup4", "beautifulsoup4",

View File

@ -18,6 +18,7 @@ from sanic.exceptions import SanicException
from sanic.helpers import _default from sanic.helpers import _default
from sanic.log import LOGGING_CONFIG_DEFAULTS from sanic.log import LOGGING_CONFIG_DEFAULTS
from sanic.response import text from sanic.response import text
from sanic.router import Route
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -152,8 +153,13 @@ def test_app_route_raise_value_error(app: Sanic):
def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch): def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch):
app.config.TOUCHUP = False
route = Mock(spec=Route)
route.extra.request_middleware = []
route.extra.response_middleware = []
def mockreturn(*args, **kwargs): def mockreturn(*args, **kwargs):
return Mock(), None, {} return route, None, {}
monkeypatch.setattr(app.router, "get", mockreturn) monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -0,0 +1,90 @@
from functools import partial
import pytest
from sanic import Sanic
from sanic.middleware import Middleware
from sanic.response import json
PRIORITY_TEST_CASES = (
([0, 1, 2], [1, 1, 1]),
([0, 1, 2], [1, 1, None]),
([0, 1, 2], [1, None, None]),
([0, 1, 2], [2, 1, None]),
([0, 1, 2], [2, 2, None]),
([0, 1, 2], [3, 2, 1]),
([0, 1, 2], [None, None, None]),
([0, 2, 1], [1, None, 1]),
([0, 2, 1], [2, None, 1]),
([0, 2, 1], [2, None, 2]),
([0, 2, 1], [3, 1, 2]),
([1, 0, 2], [1, 2, None]),
([1, 0, 2], [2, 3, 1]),
([1, 0, 2], [None, 1, None]),
([1, 2, 0], [1, 3, 2]),
([1, 2, 0], [None, 1, 1]),
([1, 2, 0], [None, 2, 1]),
([1, 2, 0], [None, 2, 2]),
([2, 0, 1], [1, None, 2]),
([2, 0, 1], [2, 1, 3]),
([2, 0, 1], [None, None, 1]),
([2, 1, 0], [1, 2, 3]),
([2, 1, 0], [None, 1, 2]),
)
@pytest.fixture(autouse=True)
def reset_middleware():
yield
Middleware.reset_count()
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_request_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_request(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order == expected
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_response_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, response, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_response(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order[::-1] == expected