From c72cbe432661bcefa6b3359aeb2a5583b857db53 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 7 Aug 2022 22:31:26 +0300 Subject: [PATCH] Begin middleware revamp --- sanic/app.py | 162 ++++++++++++++++++------------------- sanic/middleware.py | 61 ++++++++++++++ sanic/mixins/middleware.py | 58 +++++++++++-- sanic/request.py | 23 ++++-- sanic/signals.py | 2 +- 5 files changed, 208 insertions(+), 98 deletions(-) create mode 100644 sanic/middleware.py diff --git a/sanic/app.py b/sanic/app.py index 59056eb2..4926edd4 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -701,7 +701,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): # -------------------------------------------------------------------- # async def handle_exception( - self, request: Request, exception: BaseException + self, + request: Request, + exception: BaseException, + run_middleware: bool = True, ): # no cov """ A handler that catches specific exceptions and outputs a response. @@ -710,6 +713,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): :param exception: The exception that was raised :raises ServerError: response 500 """ + response = None await self.dispatch( "http.lifecycle.exception", inline=True, @@ -750,9 +754,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): # -------------------------------------------- # # Request Middleware # -------------------------------------------- # - response = await self._run_request_middleware( - request, request_name=None - ) + if ( + run_middleware + and request.route + and request.route.extra.request_middleware + ): + response = await self._run_request_middleware(request) # No middleware results if not response: try: @@ -832,7 +839,12 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): # Define `response` var here to remove warnings about # allocation before assignment below. - response = None + response: Optional[ + Union[ + BaseHTTPResponse, + Coroutine[Any, Any, Optional[BaseHTTPResponse]], + ] + ] = None try: await self.dispatch( @@ -877,9 +889,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): # -------------------------------------------- # # Request Middleware # -------------------------------------------- # - response = await self._run_request_middleware( - request, request_name=route.name - ) + if request.route.extra.request_middleware: + response = await self._run_request_middleware(request) # No middleware results if not response: @@ -910,7 +921,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): if request.stream is not None: response = request.stream.response elif response is not None: - response = await request.respond(response) + response = await request.respond(response) # type: ignore elif not hasattr(handler, "is_websocket"): response = request.stream.response # type: ignore @@ -928,7 +939,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): ... await response.send(end_stream=True) elif isinstance(response, ResponseStream): - resp = await response(request) + resp = await response(request) # type: ignore await self.dispatch( "http.lifecycle.response", inline=True, @@ -937,7 +948,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): "response": resp, }, ) - await response.eof() + await response.eof() # type: ignore else: if not hasattr(handler, "is_websocket"): raise ServerError( @@ -949,7 +960,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): raise except Exception as e: # Response Generation Failed - await self.handle_exception(request, e) + await self.handle_exception(request, e, run_middleware=False) async def _websocket_handler( self, handler, request, *args, subprotocols=None, **kwargs @@ -1017,87 +1028,69 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): # Execution # -------------------------------------------------------------------- # - async def _run_request_middleware( - self, request, request_name=None - ): # no cov - # The if improves speed. I don't know why - named_middleware = self.named_request_middleware.get( - request_name, deque() - ) - applicable_middleware = self.request_middleware + named_middleware + async def _run_request_middleware(self, request): # no cov + request._request_middleware_started = True - # request.request_middleware_started is meant as a stop-gap solution - # until RFC 1630 is adopted - if applicable_middleware and not request.request_middleware_started: - request.request_middleware_started = True + for middleware in request.route.extra.request_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) - for middleware in applicable_middleware: - await self.dispatch( - "http.middleware.before", - inline=True, - context={ - "request": request, - "response": None, - }, - condition={"attach_to": "request"}, - ) + response = middleware(request) + if isawaitable(response): + response = await response - response = middleware(request) - if isawaitable(response): - response = await response + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) - await self.dispatch( - "http.middleware.after", - inline=True, - context={ - "request": request, - "response": None, - }, - condition={"attach_to": "request"}, - ) - - if response: - return response + if response: + return response return None - async def _run_response_middleware( - self, request, response, request_name=None - ): # no cov - named_middleware = self.named_response_middleware.get( - request_name, deque() - ) - applicable_middleware = self.response_middleware + named_middleware - if applicable_middleware: - for middleware in applicable_middleware: - await self.dispatch( - "http.middleware.before", - inline=True, - context={ - "request": request, - "response": response, - }, - condition={"attach_to": "response"}, - ) + async def _run_response_middleware(self, request, response): # no cov + for middleware in request.route.extra.response_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": response, + }, + condition={"attach_to": "response"}, + ) - _response = middleware(request, response) - if isawaitable(_response): - _response = await _response + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response - await self.dispatch( - "http.middleware.after", - inline=True, - context={ - "request": request, - "response": _response if _response else response, - }, - condition={"attach_to": "response"}, - ) + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": _response if _response else response, + }, + condition={"attach_to": "response"}, + ) - if _response: - response = _response - if isinstance(response, BaseHTTPResponse): - response = request.stream.respond(response) - break + if _response: + response = _response + if isinstance(response, BaseHTTPResponse): + response = request.stream.respond(response) + break return response def _build_endpoint_name(self, *parts): @@ -1495,6 +1488,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): except FinalizationError as e: if not Sanic.test_mode: raise e + self.finalize_middleware() def signalize(self, allow_fail_builtin=True): self.signal_router.allow_fail_builtin = allow_fail_builtin diff --git a/sanic/middleware.py b/sanic/middleware.py new file mode 100644 index 00000000..e3214c05 --- /dev/null +++ b/sanic/middleware.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from collections import deque +from enum import IntEnum, auto +from itertools import count +from typing import Deque, Optional, Sequence, Union + +from sanic.models.handler_types import MiddlewareType + + +class MiddlewareLocation(IntEnum): + REQUEST = auto() + RESPONSE = auto() + + +class Middleware: + counter = count() + + def __init__( + self, + func: MiddlewareType, + location: MiddlewareLocation, + priority: int = 0, + ) -> None: + self.func = func + self.priority = priority + self.location = location + self.definition = next(Middleware.counter) + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"func=, " + f"priority={self.priority}, " + f"location={self.location.name})" + ) + + @property + def order(self): + return (self.priority, -self.definition) + + @classmethod + def convert( + cls, + *middleware_collections: Sequence[ + Optional[Union[Middleware, MiddlewareType]] + ], + location: MiddlewareLocation, + ) -> Deque[Middleware]: + return deque( + [ + middleware + if isinstance(middleware, Middleware) + else Middleware(middleware, location) + for collection in middleware_collections + for middleware in collection + ] + ) diff --git a/sanic/mixins/middleware.py b/sanic/mixins/middleware.py index 5ef9dc77..7615477c 100644 --- a/sanic/mixins/middleware.py +++ b/sanic/mixins/middleware.py @@ -1,11 +1,17 @@ +from collections import deque from functools import partial +from operator import attrgetter from typing import List from sanic.base.meta import SanicMeta +from sanic.middleware import Middleware, MiddlewareLocation from sanic.models.futures import FutureMiddleware +from sanic.router import Router class MiddlewareMixin(metaclass=SanicMeta): + router: Router + def __init__(self, *args, **kwargs) -> None: self._future_middleware: List[FutureMiddleware] = [] @@ -13,7 +19,12 @@ class MiddlewareMixin(metaclass=SanicMeta): raise NotImplementedError # noqa def middleware( - self, middleware_or_request, attach_to="request", apply=True + self, + middleware_or_request, + attach_to="request", + apply=True, + *, + priority=0 ): """ Decorate and register middleware to be called before a request @@ -30,6 +41,12 @@ class MiddlewareMixin(metaclass=SanicMeta): def register_middleware(middleware, attach_to="request"): nonlocal apply + location = ( + MiddlewareLocation.REQUEST + if attach_to == "request" + else MiddlewareLocation.RESPONSE + ) + middleware = Middleware(middleware, location, priority=priority) future_middleware = FutureMiddleware(middleware, attach_to) self._future_middleware.append(future_middleware) if apply: @@ -46,7 +63,7 @@ class MiddlewareMixin(metaclass=SanicMeta): register_middleware, attach_to=middleware_or_request ) - def on_request(self, middleware=None): + def on_request(self, middleware=None, *, priority=0): """Register a middleware to be called before a request is handled. This is the same as *@app.middleware('request')*. @@ -54,11 +71,13 @@ class MiddlewareMixin(metaclass=SanicMeta): :param: middleware: A callable that takes in request. """ if callable(middleware): - return self.middleware(middleware, "request") + return self.middleware(middleware, "request", priority=priority) else: - return partial(self.middleware, attach_to="request") + return partial( + self.middleware, attach_to="request", priority=priority + ) - def on_response(self, middleware=None): + def on_response(self, middleware=None, *, priority=0): """Register a middleware to be called after a response is created. This is the same as *@app.middleware('response')*. @@ -67,6 +86,31 @@ class MiddlewareMixin(metaclass=SanicMeta): A callable that takes in a request and its response. """ if callable(middleware): - return self.middleware(middleware, "response") + return self.middleware(middleware, "response", priority=priority) else: - return partial(self.middleware, attach_to="response") + return partial( + self.middleware, attach_to="response", priority=priority + ) + + def finalize_middleware(self): + for route in self.router.routes: + request_middleware = Middleware.convert( + self.request_middleware, + self.named_request_middleware.get(route.name, deque()), + location=MiddlewareLocation.REQUEST, + ) + response_middleware = Middleware.convert( + self.response_middleware, + self.named_response_middleware.get(route.name, deque()), + location=MiddlewareLocation.RESPONSE, + ) + route.extra.request_middleware = sorted( + request_middleware, + key=attrgetter("order"), + reverse=True, + ) + route.extra.response_middleware = sorted( + response_middleware, + key=attrgetter("order"), + reverse=True, + )[::-1] diff --git a/sanic/request.py b/sanic/request.py index 927c124a..17016c97 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -51,7 +51,7 @@ from sanic.headers import ( parse_xforwarded, ) from sanic.http import Stage -from sanic.log import error_logger, logger +from sanic.log import deprecation, error_logger, logger from sanic.models.protocol_types import TransportProtocol from sanic.response import BaseHTTPResponse, HTTPResponse @@ -98,6 +98,7 @@ class Request: "_port", "_protocol", "_remote_addr", + "_request_middleware_started", "_scheme", "_socket", "_stream_id", @@ -121,7 +122,6 @@ class Request: "parsed_token", "raw_url", "responded", - "request_middleware_started", "route", "stream", "transport", @@ -173,7 +173,7 @@ class Request: self.parsed_not_grouped_args: DefaultDict[ Tuple[bool, bool, str, str], List[Tuple[str, str]] ] = defaultdict(list) - self.request_middleware_started = False + self._request_middleware_started = False self.responded: bool = False self.route: Optional[Route] = None self.stream: Optional[Stream] = None @@ -214,6 +214,16 @@ class Request: def generate_id(*_): return uuid.uuid4() + @property + def request_middleware_started(self): + deprecation( + "Request.request_middleware_started has been deprecated and will" + "be removed. You should set a flag on the request context using" + "either middleware or signals if you need this feature.", + 22.3, + ) + return self._request_middleware_started + @property def stream_id(self): """ @@ -319,9 +329,10 @@ class Request: response = await response # type: ignore # Run response middleware try: - response = await self.app._run_response_middleware( - self, response, request_name=self.name - ) + if self.route and self.route.extra.response_middleware: + response = await self.app._run_response_middleware( + self, response + ) except CancelledErrors: raise except Exception: diff --git a/sanic/signals.py b/sanic/signals.py index 80c6300b..302bec2d 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -187,7 +187,7 @@ class SignalRouter(BaseRouter): fail_not_found=fail_not_found and inline, reverse=reverse, ) - logger.debug(f"Dispatching signal: {event}") + logger.debug(f"Dispatching signal: {event}", extra={"verbosity": 1}) if inline: return await dispatch