Begin middleware revamp

This commit is contained in:
Adam Hopkins 2022-08-07 22:31:26 +03:00
parent 2f6f2bfa76
commit c72cbe4326
No known key found for this signature in database
GPG Key ID: 9F85EE6C807303FB
5 changed files with 208 additions and 98 deletions

View File

@ -701,7 +701,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- #
async def handle_exception(
self, request: Request, exception: BaseException
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
@ -710,6 +713,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
:param exception: The exception that was raised
:raises ServerError: response 500
"""
response = None
await self.dispatch(
"http.lifecycle.exception",
inline=True,
@ -750,9 +754,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
if (
run_middleware
and request.route
and request.route.extra.request_middleware
):
response = await self._run_request_middleware(request)
# No middleware results
if not response:
try:
@ -832,7 +839,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
response: Optional[
Union[
BaseHTTPResponse,
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
]
] = None
try:
await self.dispatch(
@ -877,9 +889,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=route.name
)
if request.route.extra.request_middleware:
response = await self._run_request_middleware(request)
# No middleware results
if not response:
@ -910,7 +921,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response)
response = await request.respond(response) # type: ignore
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore
@ -928,7 +939,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
resp = await response(request) # type: ignore
await self.dispatch(
"http.lifecycle.response",
inline=True,
@ -937,7 +948,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"response": resp,
},
)
await response.eof()
await response.eof() # type: ignore
else:
if not hasattr(handler, "is_websocket"):
raise ServerError(
@ -949,7 +960,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
raise
except Exception as e:
# Response Generation Failed
await self.handle_exception(request, e)
await self.handle_exception(request, e, run_middleware=False)
async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
@ -1017,87 +1028,69 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
# Execution
# -------------------------------------------------------------------- #
async def _run_request_middleware(
self, request, request_name=None
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
async def _run_request_middleware(self, request): # no cov
request._request_middleware_started = True
# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
for middleware in request.route.extra.request_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request)
if isawaitable(response):
response = await response
response = middleware(request)
if isawaitable(response):
response = await response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
if response:
return response
return None
async def _run_response_middleware(
self, request, response, request_name=None
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
async def _run_response_middleware(self, request, response): # no cov
for middleware in request.route.extra.response_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
return response
def _build_endpoint_name(self, *parts):
@ -1495,6 +1488,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
except FinalizationError as e:
if not Sanic.test_mode:
raise e
self.finalize_middleware()
def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin

61
sanic/middleware.py Normal file
View File

@ -0,0 +1,61 @@
from __future__ import annotations
from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Optional, Sequence, Union
from sanic.models.handler_types import MiddlewareType
class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()
class Middleware:
counter = count()
def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware.counter)
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)
@property
def order(self):
return (self.priority, -self.definition)
@classmethod
def convert(
cls,
*middleware_collections: Sequence[
Optional[Union[Middleware, MiddlewareType]]
],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)

View File

@ -1,11 +1,17 @@
from collections import deque
from functools import partial
from operator import attrgetter
from typing import List
from sanic.base.meta import SanicMeta
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware
from sanic.router import Router
class MiddlewareMixin(metaclass=SanicMeta):
router: Router
def __init__(self, *args, **kwargs) -> None:
self._future_middleware: List[FutureMiddleware] = []
@ -13,7 +19,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
raise NotImplementedError # noqa
def middleware(
self, middleware_or_request, attach_to="request", apply=True
self,
middleware_or_request,
attach_to="request",
apply=True,
*,
priority=0
):
"""
Decorate and register middleware to be called before a request
@ -30,6 +41,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
def register_middleware(middleware, attach_to="request"):
nonlocal apply
location = (
MiddlewareLocation.REQUEST
if attach_to == "request"
else MiddlewareLocation.RESPONSE
)
middleware = Middleware(middleware, location, priority=priority)
future_middleware = FutureMiddleware(middleware, attach_to)
self._future_middleware.append(future_middleware)
if apply:
@ -46,7 +63,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
register_middleware, attach_to=middleware_or_request
)
def on_request(self, middleware=None):
def on_request(self, middleware=None, *, priority=0):
"""Register a middleware to be called before a request is handled.
This is the same as *@app.middleware('request')*.
@ -54,11 +71,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
:param: middleware: A callable that takes in request.
"""
if callable(middleware):
return self.middleware(middleware, "request")
return self.middleware(middleware, "request", priority=priority)
else:
return partial(self.middleware, attach_to="request")
return partial(
self.middleware, attach_to="request", priority=priority
)
def on_response(self, middleware=None):
def on_response(self, middleware=None, *, priority=0):
"""Register a middleware to be called after a response is created.
This is the same as *@app.middleware('response')*.
@ -67,6 +86,31 @@ class MiddlewareMixin(metaclass=SanicMeta):
A callable that takes in a request and its response.
"""
if callable(middleware):
return self.middleware(middleware, "response")
return self.middleware(middleware, "response", priority=priority)
else:
return partial(self.middleware, attach_to="response")
return partial(
self.middleware, attach_to="response", priority=priority
)
def finalize_middleware(self):
for route in self.router.routes:
request_middleware = Middleware.convert(
self.request_middleware,
self.named_request_middleware.get(route.name, deque()),
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.extra.request_middleware = sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
route.extra.response_middleware = sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]

View File

@ -51,7 +51,7 @@ from sanic.headers import (
parse_xforwarded,
)
from sanic.http import Stage
from sanic.log import error_logger, logger
from sanic.log import deprecation, error_logger, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse
@ -98,6 +98,7 @@ class Request:
"_port",
"_protocol",
"_remote_addr",
"_request_middleware_started",
"_scheme",
"_socket",
"_stream_id",
@ -121,7 +122,6 @@ class Request:
"parsed_token",
"raw_url",
"responded",
"request_middleware_started",
"route",
"stream",
"transport",
@ -173,7 +173,7 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list)
self.request_middleware_started = False
self._request_middleware_started = False
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
@ -214,6 +214,16 @@ class Request:
def generate_id(*_):
return uuid.uuid4()
@property
def request_middleware_started(self):
deprecation(
"Request.request_middleware_started has been deprecated and will"
"be removed. You should set a flag on the request context using"
"either middleware or signals if you need this feature.",
22.3,
)
return self._request_middleware_started
@property
def stream_id(self):
"""
@ -319,9 +329,10 @@ class Request:
response = await response # type: ignore
# Run response middleware
try:
response = await self.app._run_response_middleware(
self, response, request_name=self.name
)
if self.route and self.route.extra.response_middleware:
response = await self.app._run_response_middleware(
self, response
)
except CancelledErrors:
raise
except Exception:

View File

@ -187,7 +187,7 @@ class SignalRouter(BaseRouter):
fail_not_found=fail_not_found and inline,
reverse=reverse,
)
logger.debug(f"Dispatching signal: {event}")
logger.debug(f"Dispatching signal: {event}", extra={"verbosity": 1})
if inline:
return await dispatch