Begin middleware revamp
This commit is contained in:
parent
2f6f2bfa76
commit
c72cbe4326
162
sanic/app.py
162
sanic/app.py
|
@ -701,7 +701,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
# -------------------------------------------------------------------- #
|
||||
|
||||
async def handle_exception(
|
||||
self, request: Request, exception: BaseException
|
||||
self,
|
||||
request: Request,
|
||||
exception: BaseException,
|
||||
run_middleware: bool = True,
|
||||
): # no cov
|
||||
"""
|
||||
A handler that catches specific exceptions and outputs a response.
|
||||
|
@ -710,6 +713,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
:param exception: The exception that was raised
|
||||
:raises ServerError: response 500
|
||||
"""
|
||||
response = None
|
||||
await self.dispatch(
|
||||
"http.lifecycle.exception",
|
||||
inline=True,
|
||||
|
@ -750,9 +754,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
response = await self._run_request_middleware(
|
||||
request, request_name=None
|
||||
)
|
||||
if (
|
||||
run_middleware
|
||||
and request.route
|
||||
and request.route.extra.request_middleware
|
||||
):
|
||||
response = await self._run_request_middleware(request)
|
||||
# No middleware results
|
||||
if not response:
|
||||
try:
|
||||
|
@ -832,7 +839,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
|
||||
# Define `response` var here to remove warnings about
|
||||
# allocation before assignment below.
|
||||
response = None
|
||||
response: Optional[
|
||||
Union[
|
||||
BaseHTTPResponse,
|
||||
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
|
||||
]
|
||||
] = None
|
||||
try:
|
||||
|
||||
await self.dispatch(
|
||||
|
@ -877,9 +889,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
response = await self._run_request_middleware(
|
||||
request, request_name=route.name
|
||||
)
|
||||
if request.route.extra.request_middleware:
|
||||
response = await self._run_request_middleware(request)
|
||||
|
||||
# No middleware results
|
||||
if not response:
|
||||
|
@ -910,7 +921,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
if request.stream is not None:
|
||||
response = request.stream.response
|
||||
elif response is not None:
|
||||
response = await request.respond(response)
|
||||
response = await request.respond(response) # type: ignore
|
||||
elif not hasattr(handler, "is_websocket"):
|
||||
response = request.stream.response # type: ignore
|
||||
|
||||
|
@ -928,7 +939,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
...
|
||||
await response.send(end_stream=True)
|
||||
elif isinstance(response, ResponseStream):
|
||||
resp = await response(request)
|
||||
resp = await response(request) # type: ignore
|
||||
await self.dispatch(
|
||||
"http.lifecycle.response",
|
||||
inline=True,
|
||||
|
@ -937,7 +948,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
"response": resp,
|
||||
},
|
||||
)
|
||||
await response.eof()
|
||||
await response.eof() # type: ignore
|
||||
else:
|
||||
if not hasattr(handler, "is_websocket"):
|
||||
raise ServerError(
|
||||
|
@ -949,7 +960,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
raise
|
||||
except Exception as e:
|
||||
# Response Generation Failed
|
||||
await self.handle_exception(request, e)
|
||||
await self.handle_exception(request, e, run_middleware=False)
|
||||
|
||||
async def _websocket_handler(
|
||||
self, handler, request, *args, subprotocols=None, **kwargs
|
||||
|
@ -1017,87 +1028,69 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
# Execution
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
async def _run_request_middleware(
|
||||
self, request, request_name=None
|
||||
): # no cov
|
||||
# The if improves speed. I don't know why
|
||||
named_middleware = self.named_request_middleware.get(
|
||||
request_name, deque()
|
||||
)
|
||||
applicable_middleware = self.request_middleware + named_middleware
|
||||
async def _run_request_middleware(self, request): # no cov
|
||||
request._request_middleware_started = True
|
||||
|
||||
# request.request_middleware_started is meant as a stop-gap solution
|
||||
# until RFC 1630 is adopted
|
||||
if applicable_middleware and not request.request_middleware_started:
|
||||
request.request_middleware_started = True
|
||||
for middleware in request.route.extra.request_middleware:
|
||||
await self.dispatch(
|
||||
"http.middleware.before",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": None,
|
||||
},
|
||||
condition={"attach_to": "request"},
|
||||
)
|
||||
|
||||
for middleware in applicable_middleware:
|
||||
await self.dispatch(
|
||||
"http.middleware.before",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": None,
|
||||
},
|
||||
condition={"attach_to": "request"},
|
||||
)
|
||||
response = middleware(request)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
|
||||
response = middleware(request)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
await self.dispatch(
|
||||
"http.middleware.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": None,
|
||||
},
|
||||
condition={"attach_to": "request"},
|
||||
)
|
||||
|
||||
await self.dispatch(
|
||||
"http.middleware.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": None,
|
||||
},
|
||||
condition={"attach_to": "request"},
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
if response:
|
||||
return response
|
||||
return None
|
||||
|
||||
async def _run_response_middleware(
|
||||
self, request, response, request_name=None
|
||||
): # no cov
|
||||
named_middleware = self.named_response_middleware.get(
|
||||
request_name, deque()
|
||||
)
|
||||
applicable_middleware = self.response_middleware + named_middleware
|
||||
if applicable_middleware:
|
||||
for middleware in applicable_middleware:
|
||||
await self.dispatch(
|
||||
"http.middleware.before",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": response,
|
||||
},
|
||||
condition={"attach_to": "response"},
|
||||
)
|
||||
async def _run_response_middleware(self, request, response): # no cov
|
||||
for middleware in request.route.extra.response_middleware:
|
||||
await self.dispatch(
|
||||
"http.middleware.before",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": response,
|
||||
},
|
||||
condition={"attach_to": "response"},
|
||||
)
|
||||
|
||||
_response = middleware(request, response)
|
||||
if isawaitable(_response):
|
||||
_response = await _response
|
||||
_response = middleware(request, response)
|
||||
if isawaitable(_response):
|
||||
_response = await _response
|
||||
|
||||
await self.dispatch(
|
||||
"http.middleware.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": _response if _response else response,
|
||||
},
|
||||
condition={"attach_to": "response"},
|
||||
)
|
||||
await self.dispatch(
|
||||
"http.middleware.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": _response if _response else response,
|
||||
},
|
||||
condition={"attach_to": "response"},
|
||||
)
|
||||
|
||||
if _response:
|
||||
response = _response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
response = request.stream.respond(response)
|
||||
break
|
||||
if _response:
|
||||
response = _response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
response = request.stream.respond(response)
|
||||
break
|
||||
return response
|
||||
|
||||
def _build_endpoint_name(self, *parts):
|
||||
|
@ -1495,6 +1488,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
|
|||
except FinalizationError as e:
|
||||
if not Sanic.test_mode:
|
||||
raise e
|
||||
self.finalize_middleware()
|
||||
|
||||
def signalize(self, allow_fail_builtin=True):
|
||||
self.signal_router.allow_fail_builtin = allow_fail_builtin
|
||||
|
|
61
sanic/middleware.py
Normal file
61
sanic/middleware.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from enum import IntEnum, auto
|
||||
from itertools import count
|
||||
from typing import Deque, Optional, Sequence, Union
|
||||
|
||||
from sanic.models.handler_types import MiddlewareType
|
||||
|
||||
|
||||
class MiddlewareLocation(IntEnum):
|
||||
REQUEST = auto()
|
||||
RESPONSE = auto()
|
||||
|
||||
|
||||
class Middleware:
|
||||
counter = count()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: MiddlewareType,
|
||||
location: MiddlewareLocation,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.priority = priority
|
||||
self.location = location
|
||||
self.definition = next(Middleware.counter)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"func=<function {self.func.__name__}>, "
|
||||
f"priority={self.priority}, "
|
||||
f"location={self.location.name})"
|
||||
)
|
||||
|
||||
@property
|
||||
def order(self):
|
||||
return (self.priority, -self.definition)
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls,
|
||||
*middleware_collections: Sequence[
|
||||
Optional[Union[Middleware, MiddlewareType]]
|
||||
],
|
||||
location: MiddlewareLocation,
|
||||
) -> Deque[Middleware]:
|
||||
return deque(
|
||||
[
|
||||
middleware
|
||||
if isinstance(middleware, Middleware)
|
||||
else Middleware(middleware, location)
|
||||
for collection in middleware_collections
|
||||
for middleware in collection
|
||||
]
|
||||
)
|
|
@ -1,11 +1,17 @@
|
|||
from collections import deque
|
||||
from functools import partial
|
||||
from operator import attrgetter
|
||||
from typing import List
|
||||
|
||||
from sanic.base.meta import SanicMeta
|
||||
from sanic.middleware import Middleware, MiddlewareLocation
|
||||
from sanic.models.futures import FutureMiddleware
|
||||
from sanic.router import Router
|
||||
|
||||
|
||||
class MiddlewareMixin(metaclass=SanicMeta):
|
||||
router: Router
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._future_middleware: List[FutureMiddleware] = []
|
||||
|
||||
|
@ -13,7 +19,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
|
|||
raise NotImplementedError # noqa
|
||||
|
||||
def middleware(
|
||||
self, middleware_or_request, attach_to="request", apply=True
|
||||
self,
|
||||
middleware_or_request,
|
||||
attach_to="request",
|
||||
apply=True,
|
||||
*,
|
||||
priority=0
|
||||
):
|
||||
"""
|
||||
Decorate and register middleware to be called before a request
|
||||
|
@ -30,6 +41,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
|
|||
def register_middleware(middleware, attach_to="request"):
|
||||
nonlocal apply
|
||||
|
||||
location = (
|
||||
MiddlewareLocation.REQUEST
|
||||
if attach_to == "request"
|
||||
else MiddlewareLocation.RESPONSE
|
||||
)
|
||||
middleware = Middleware(middleware, location, priority=priority)
|
||||
future_middleware = FutureMiddleware(middleware, attach_to)
|
||||
self._future_middleware.append(future_middleware)
|
||||
if apply:
|
||||
|
@ -46,7 +63,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
|
|||
register_middleware, attach_to=middleware_or_request
|
||||
)
|
||||
|
||||
def on_request(self, middleware=None):
|
||||
def on_request(self, middleware=None, *, priority=0):
|
||||
"""Register a middleware to be called before a request is handled.
|
||||
|
||||
This is the same as *@app.middleware('request')*.
|
||||
|
@ -54,11 +71,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
|
|||
:param: middleware: A callable that takes in request.
|
||||
"""
|
||||
if callable(middleware):
|
||||
return self.middleware(middleware, "request")
|
||||
return self.middleware(middleware, "request", priority=priority)
|
||||
else:
|
||||
return partial(self.middleware, attach_to="request")
|
||||
return partial(
|
||||
self.middleware, attach_to="request", priority=priority
|
||||
)
|
||||
|
||||
def on_response(self, middleware=None):
|
||||
def on_response(self, middleware=None, *, priority=0):
|
||||
"""Register a middleware to be called after a response is created.
|
||||
|
||||
This is the same as *@app.middleware('response')*.
|
||||
|
@ -67,6 +86,31 @@ class MiddlewareMixin(metaclass=SanicMeta):
|
|||
A callable that takes in a request and its response.
|
||||
"""
|
||||
if callable(middleware):
|
||||
return self.middleware(middleware, "response")
|
||||
return self.middleware(middleware, "response", priority=priority)
|
||||
else:
|
||||
return partial(self.middleware, attach_to="response")
|
||||
return partial(
|
||||
self.middleware, attach_to="response", priority=priority
|
||||
)
|
||||
|
||||
def finalize_middleware(self):
|
||||
for route in self.router.routes:
|
||||
request_middleware = Middleware.convert(
|
||||
self.request_middleware,
|
||||
self.named_request_middleware.get(route.name, deque()),
|
||||
location=MiddlewareLocation.REQUEST,
|
||||
)
|
||||
response_middleware = Middleware.convert(
|
||||
self.response_middleware,
|
||||
self.named_response_middleware.get(route.name, deque()),
|
||||
location=MiddlewareLocation.RESPONSE,
|
||||
)
|
||||
route.extra.request_middleware = sorted(
|
||||
request_middleware,
|
||||
key=attrgetter("order"),
|
||||
reverse=True,
|
||||
)
|
||||
route.extra.response_middleware = sorted(
|
||||
response_middleware,
|
||||
key=attrgetter("order"),
|
||||
reverse=True,
|
||||
)[::-1]
|
||||
|
|
|
@ -51,7 +51,7 @@ from sanic.headers import (
|
|||
parse_xforwarded,
|
||||
)
|
||||
from sanic.http import Stage
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.log import deprecation, error_logger, logger
|
||||
from sanic.models.protocol_types import TransportProtocol
|
||||
from sanic.response import BaseHTTPResponse, HTTPResponse
|
||||
|
||||
|
@ -98,6 +98,7 @@ class Request:
|
|||
"_port",
|
||||
"_protocol",
|
||||
"_remote_addr",
|
||||
"_request_middleware_started",
|
||||
"_scheme",
|
||||
"_socket",
|
||||
"_stream_id",
|
||||
|
@ -121,7 +122,6 @@ class Request:
|
|||
"parsed_token",
|
||||
"raw_url",
|
||||
"responded",
|
||||
"request_middleware_started",
|
||||
"route",
|
||||
"stream",
|
||||
"transport",
|
||||
|
@ -173,7 +173,7 @@ class Request:
|
|||
self.parsed_not_grouped_args: DefaultDict[
|
||||
Tuple[bool, bool, str, str], List[Tuple[str, str]]
|
||||
] = defaultdict(list)
|
||||
self.request_middleware_started = False
|
||||
self._request_middleware_started = False
|
||||
self.responded: bool = False
|
||||
self.route: Optional[Route] = None
|
||||
self.stream: Optional[Stream] = None
|
||||
|
@ -214,6 +214,16 @@ class Request:
|
|||
def generate_id(*_):
|
||||
return uuid.uuid4()
|
||||
|
||||
@property
|
||||
def request_middleware_started(self):
|
||||
deprecation(
|
||||
"Request.request_middleware_started has been deprecated and will"
|
||||
"be removed. You should set a flag on the request context using"
|
||||
"either middleware or signals if you need this feature.",
|
||||
22.3,
|
||||
)
|
||||
return self._request_middleware_started
|
||||
|
||||
@property
|
||||
def stream_id(self):
|
||||
"""
|
||||
|
@ -319,9 +329,10 @@ class Request:
|
|||
response = await response # type: ignore
|
||||
# Run response middleware
|
||||
try:
|
||||
response = await self.app._run_response_middleware(
|
||||
self, response, request_name=self.name
|
||||
)
|
||||
if self.route and self.route.extra.response_middleware:
|
||||
response = await self.app._run_response_middleware(
|
||||
self, response
|
||||
)
|
||||
except CancelledErrors:
|
||||
raise
|
||||
except Exception:
|
||||
|
|
|
@ -187,7 +187,7 @@ class SignalRouter(BaseRouter):
|
|||
fail_not_found=fail_not_found and inline,
|
||||
reverse=reverse,
|
||||
)
|
||||
logger.debug(f"Dispatching signal: {event}")
|
||||
logger.debug(f"Dispatching signal: {event}", extra={"verbosity": 1})
|
||||
|
||||
if inline:
|
||||
return await dispatch
|
||||
|
|
Loading…
Reference in New Issue
Block a user