Begin middleware revamp

This commit is contained in:
Adam Hopkins 2022-08-07 22:31:26 +03:00
parent 2f6f2bfa76
commit c72cbe4326
No known key found for this signature in database
GPG Key ID: 9F85EE6C807303FB
5 changed files with 208 additions and 98 deletions

View File

@ -701,7 +701,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def handle_exception( async def handle_exception(
self, request: Request, exception: BaseException self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov ): # no cov
""" """
A handler that catches specific exceptions and outputs a response. A handler that catches specific exceptions and outputs a response.
@ -710,6 +713,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
:param exception: The exception that was raised :param exception: The exception that was raised
:raises ServerError: response 500 :raises ServerError: response 500
""" """
response = None
await self.dispatch( await self.dispatch(
"http.lifecycle.exception", "http.lifecycle.exception",
inline=True, inline=True,
@ -750,9 +754,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
response = await self._run_request_middleware( if (
request, request_name=None run_middleware
) and request.route
and request.route.extra.request_middleware
):
response = await self._run_request_middleware(request)
# No middleware results # No middleware results
if not response: if not response:
try: try:
@ -832,7 +839,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# Define `response` var here to remove warnings about # Define `response` var here to remove warnings about
# allocation before assignment below. # allocation before assignment below.
response = None response: Optional[
Union[
BaseHTTPResponse,
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
]
] = None
try: try:
await self.dispatch( await self.dispatch(
@ -877,9 +889,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
response = await self._run_request_middleware( if request.route.extra.request_middleware:
request, request_name=route.name response = await self._run_request_middleware(request)
)
# No middleware results # No middleware results
if not response: if not response:
@ -910,7 +921,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
if request.stream is not None: if request.stream is not None:
response = request.stream.response response = request.stream.response
elif response is not None: elif response is not None:
response = await request.respond(response) response = await request.respond(response) # type: ignore
elif not hasattr(handler, "is_websocket"): elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore response = request.stream.response # type: ignore
@ -928,7 +939,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
... ...
await response.send(end_stream=True) await response.send(end_stream=True)
elif isinstance(response, ResponseStream): elif isinstance(response, ResponseStream):
resp = await response(request) resp = await response(request) # type: ignore
await self.dispatch( await self.dispatch(
"http.lifecycle.response", "http.lifecycle.response",
inline=True, inline=True,
@ -937,7 +948,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"response": resp, "response": resp,
}, },
) )
await response.eof() await response.eof() # type: ignore
else: else:
if not hasattr(handler, "is_websocket"): if not hasattr(handler, "is_websocket"):
raise ServerError( raise ServerError(
@ -949,7 +960,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
raise raise
except Exception as e: except Exception as e:
# Response Generation Failed # Response Generation Failed
await self.handle_exception(request, e) await self.handle_exception(request, e, run_middleware=False)
async def _websocket_handler( async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs self, handler, request, *args, subprotocols=None, **kwargs
@ -1017,87 +1028,69 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# Execution # Execution
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def _run_request_middleware( async def _run_request_middleware(self, request): # no cov
self, request, request_name=None request._request_middleware_started = True
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
# request.request_middleware_started is meant as a stop-gap solution for middleware in request.route.extra.request_middleware:
# until RFC 1630 is adopted await self.dispatch(
if applicable_middleware and not request.request_middleware_started: "http.middleware.before",
request.request_middleware_started = True inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in applicable_middleware: response = middleware(request)
await self.dispatch( if isawaitable(response):
"http.middleware.before", response = await response
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request) await self.dispatch(
if isawaitable(response): "http.middleware.after",
response = await response inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch( if response:
"http.middleware.after", return response
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
return None return None
async def _run_response_middleware( async def _run_response_middleware(self, request, response): # no cov
self, request, response, request_name=None for middleware in request.route.extra.response_middleware:
): # no cov await self.dispatch(
named_middleware = self.named_response_middleware.get( "http.middleware.before",
request_name, deque() inline=True,
) context={
applicable_middleware = self.response_middleware + named_middleware "request": request,
if applicable_middleware: "response": response,
for middleware in applicable_middleware: },
await self.dispatch( condition={"attach_to": "response"},
"http.middleware.before", )
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response) _response = middleware(request, response)
if isawaitable(_response): if isawaitable(_response):
_response = await _response _response = await _response
await self.dispatch( await self.dispatch(
"http.middleware.after", "http.middleware.after",
inline=True, inline=True,
context={ context={
"request": request, "request": request,
"response": _response if _response else response, "response": _response if _response else response,
}, },
condition={"attach_to": "response"}, condition={"attach_to": "response"},
) )
if _response: if _response:
response = _response response = _response
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response) response = request.stream.respond(response)
break break
return response return response
def _build_endpoint_name(self, *parts): def _build_endpoint_name(self, *parts):
@ -1495,6 +1488,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
except FinalizationError as e: except FinalizationError as e:
if not Sanic.test_mode: if not Sanic.test_mode:
raise e raise e
self.finalize_middleware()
def signalize(self, allow_fail_builtin=True): def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin self.signal_router.allow_fail_builtin = allow_fail_builtin

61
sanic/middleware.py Normal file
View File

@ -0,0 +1,61 @@
from __future__ import annotations
from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Optional, Sequence, Union
from sanic.models.handler_types import MiddlewareType
class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()
class Middleware:
counter = count()
def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware.counter)
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)
@property
def order(self):
return (self.priority, -self.definition)
@classmethod
def convert(
cls,
*middleware_collections: Sequence[
Optional[Union[Middleware, MiddlewareType]]
],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)

View File

@ -1,11 +1,17 @@
from collections import deque
from functools import partial from functools import partial
from operator import attrgetter
from typing import List from typing import List
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware from sanic.models.futures import FutureMiddleware
from sanic.router import Router
class MiddlewareMixin(metaclass=SanicMeta): class MiddlewareMixin(metaclass=SanicMeta):
router: Router
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
self._future_middleware: List[FutureMiddleware] = [] self._future_middleware: List[FutureMiddleware] = []
@ -13,7 +19,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
raise NotImplementedError # noqa raise NotImplementedError # noqa
def middleware( def middleware(
self, middleware_or_request, attach_to="request", apply=True self,
middleware_or_request,
attach_to="request",
apply=True,
*,
priority=0
): ):
""" """
Decorate and register middleware to be called before a request Decorate and register middleware to be called before a request
@ -30,6 +41,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
def register_middleware(middleware, attach_to="request"): def register_middleware(middleware, attach_to="request"):
nonlocal apply nonlocal apply
location = (
MiddlewareLocation.REQUEST
if attach_to == "request"
else MiddlewareLocation.RESPONSE
)
middleware = Middleware(middleware, location, priority=priority)
future_middleware = FutureMiddleware(middleware, attach_to) future_middleware = FutureMiddleware(middleware, attach_to)
self._future_middleware.append(future_middleware) self._future_middleware.append(future_middleware)
if apply: if apply:
@ -46,7 +63,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
register_middleware, attach_to=middleware_or_request register_middleware, attach_to=middleware_or_request
) )
def on_request(self, middleware=None): def on_request(self, middleware=None, *, priority=0):
"""Register a middleware to be called before a request is handled. """Register a middleware to be called before a request is handled.
This is the same as *@app.middleware('request')*. This is the same as *@app.middleware('request')*.
@ -54,11 +71,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
:param: middleware: A callable that takes in request. :param: middleware: A callable that takes in request.
""" """
if callable(middleware): if callable(middleware):
return self.middleware(middleware, "request") return self.middleware(middleware, "request", priority=priority)
else: else:
return partial(self.middleware, attach_to="request") return partial(
self.middleware, attach_to="request", priority=priority
)
def on_response(self, middleware=None): def on_response(self, middleware=None, *, priority=0):
"""Register a middleware to be called after a response is created. """Register a middleware to be called after a response is created.
This is the same as *@app.middleware('response')*. This is the same as *@app.middleware('response')*.
@ -67,6 +86,31 @@ class MiddlewareMixin(metaclass=SanicMeta):
A callable that takes in a request and its response. A callable that takes in a request and its response.
""" """
if callable(middleware): if callable(middleware):
return self.middleware(middleware, "response") return self.middleware(middleware, "response", priority=priority)
else: else:
return partial(self.middleware, attach_to="response") return partial(
self.middleware, attach_to="response", priority=priority
)
def finalize_middleware(self):
for route in self.router.routes:
request_middleware = Middleware.convert(
self.request_middleware,
self.named_request_middleware.get(route.name, deque()),
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.extra.request_middleware = sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
route.extra.response_middleware = sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]

View File

@ -51,7 +51,7 @@ from sanic.headers import (
parse_xforwarded, parse_xforwarded,
) )
from sanic.http import Stage from sanic.http import Stage
from sanic.log import error_logger, logger from sanic.log import deprecation, error_logger, logger
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse from sanic.response import BaseHTTPResponse, HTTPResponse
@ -98,6 +98,7 @@ class Request:
"_port", "_port",
"_protocol", "_protocol",
"_remote_addr", "_remote_addr",
"_request_middleware_started",
"_scheme", "_scheme",
"_socket", "_socket",
"_stream_id", "_stream_id",
@ -121,7 +122,6 @@ class Request:
"parsed_token", "parsed_token",
"raw_url", "raw_url",
"responded", "responded",
"request_middleware_started",
"route", "route",
"stream", "stream",
"transport", "transport",
@ -173,7 +173,7 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[ self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]] Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list) ] = defaultdict(list)
self.request_middleware_started = False self._request_middleware_started = False
self.responded: bool = False self.responded: bool = False
self.route: Optional[Route] = None self.route: Optional[Route] = None
self.stream: Optional[Stream] = None self.stream: Optional[Stream] = None
@ -214,6 +214,16 @@ class Request:
def generate_id(*_): def generate_id(*_):
return uuid.uuid4() return uuid.uuid4()
@property
def request_middleware_started(self):
deprecation(
"Request.request_middleware_started has been deprecated and will"
"be removed. You should set a flag on the request context using"
"either middleware or signals if you need this feature.",
22.3,
)
return self._request_middleware_started
@property @property
def stream_id(self): def stream_id(self):
""" """
@ -319,9 +329,10 @@ class Request:
response = await response # type: ignore response = await response # type: ignore
# Run response middleware # Run response middleware
try: try:
response = await self.app._run_response_middleware( if self.route and self.route.extra.response_middleware:
self, response, request_name=self.name response = await self.app._run_response_middleware(
) self, response
)
except CancelledErrors: except CancelledErrors:
raise raise
except Exception: except Exception:

View File

@ -187,7 +187,7 @@ class SignalRouter(BaseRouter):
fail_not_found=fail_not_found and inline, fail_not_found=fail_not_found and inline,
reverse=reverse, reverse=reverse,
) )
logger.debug(f"Dispatching signal: {event}") logger.debug(f"Dispatching signal: {event}", extra={"verbosity": 1})
if inline: if inline:
return await dispatch return await dispatch