Compare commits
	
		
			12 Commits
		
	
	
		
			remove-get
			...
			middleware
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 38b4ccf2bc | ||
|   | 8b970dd490 | ||
|   | c9be17e8da | ||
|   | 19f642b364 | ||
|   | c4c39cb082 | ||
|   | c7bac72137 | ||
|   | beb5c62767 | ||
|   | 09b59d34fe | ||
|   | 78bc475bb1 | ||
|   | b59131504b | ||
|   | 782e0881e5 | ||
|   | c72cbe4326 | 
							
								
								
									
										297
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										297
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -21,7 +21,6 @@ from functools import partial | ||||
| from inspect import isawaitable | ||||
| from os import environ | ||||
| from socket import socket | ||||
| from traceback import format_exc | ||||
| from types import SimpleNamespace | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
| @@ -54,12 +53,7 @@ from sanic.blueprint_group import BlueprintGroup | ||||
| from sanic.blueprints import Blueprint | ||||
| from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support | ||||
| from sanic.config import SANIC_PREFIX, Config | ||||
| from sanic.exceptions import ( | ||||
|     BadRequest, | ||||
|     SanicException, | ||||
|     ServerError, | ||||
|     URLBuildError, | ||||
| ) | ||||
| from sanic.exceptions import BadRequest, SanicException, URLBuildError | ||||
| from sanic.handlers import ErrorHandler | ||||
| from sanic.helpers import _default | ||||
| from sanic.http import Stage | ||||
| @@ -83,7 +77,7 @@ from sanic.models.futures import ( | ||||
| from sanic.models.handler_types import ListenerType, MiddlewareType | ||||
| from sanic.models.handler_types import Sanic as SanicVar | ||||
| from sanic.request import Request | ||||
| from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream | ||||
| from sanic.response import BaseHTTPResponse | ||||
| from sanic.router import Router | ||||
| from sanic.server.websockets.impl import ConnectionClosed | ||||
| from sanic.signals import Signal, SignalRouter | ||||
| @@ -709,265 +703,15 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|     # -------------------------------------------------------------------- # | ||||
|  | ||||
|     async def handle_exception( | ||||
|         self, request: Request, exception: BaseException | ||||
|         self, | ||||
|         request: Request, | ||||
|         exception: BaseException, | ||||
|         run_middleware: bool = True, | ||||
|     ):  # no cov | ||||
|         """ | ||||
|         A handler that catches specific exceptions and outputs a response. | ||||
|  | ||||
|         :param request: The current request object | ||||
|         :param exception: The exception that was raised | ||||
|         :raises ServerError: response 500 | ||||
|         """ | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.exception", | ||||
|             inline=True, | ||||
|             context={"request": request, "exception": exception}, | ||||
|         ) | ||||
|  | ||||
|         if ( | ||||
|             request.stream is not None | ||||
|             and request.stream.stage is not Stage.HANDLER | ||||
|         ): | ||||
|             error_logger.exception(exception, exc_info=True) | ||||
|             logger.error( | ||||
|                 "The error response will not be sent to the client for " | ||||
|                 f'the following exception:"{exception}". A previous response ' | ||||
|                 "has at least partially been sent." | ||||
|             ) | ||||
|  | ||||
|             handler = self.error_handler._lookup( | ||||
|                 exception, request.name if request else None | ||||
|             ) | ||||
|             if handler: | ||||
|                 logger.warning( | ||||
|                     "An error occurred while handling the request after at " | ||||
|                     "least some part of the response was sent to the client. " | ||||
|                     "The response from your custom exception handler " | ||||
|                     f"{handler.__name__} will not be sent to the client." | ||||
|                     "Exception handlers should only be used to generate the " | ||||
|                     "exception responses. If you would like to perform any " | ||||
|                     "other action on a raised exception, consider using a " | ||||
|                     "signal handler like " | ||||
|                     '`@app.signal("http.lifecycle.exception")`\n' | ||||
|                     "For further information, please see the docs: " | ||||
|                     "https://sanicframework.org/en/guide/advanced/" | ||||
|                     "signals.html", | ||||
|                 ) | ||||
|             return | ||||
|  | ||||
|         # -------------------------------------------- # | ||||
|         # Request Middleware | ||||
|         # -------------------------------------------- # | ||||
|         response = await self._run_request_middleware( | ||||
|             request, request_name=None | ||||
|         ) | ||||
|         # No middleware results | ||||
|         if not response: | ||||
|             try: | ||||
|                 response = self.error_handler.response(request, exception) | ||||
|                 if isawaitable(response): | ||||
|                     response = await response | ||||
|             except Exception as e: | ||||
|                 if isinstance(e, SanicException): | ||||
|                     response = self.error_handler.default(request, e) | ||||
|                 elif self.debug: | ||||
|                     response = HTTPResponse( | ||||
|                         ( | ||||
|                             f"Error while handling error: {e}\n" | ||||
|                             f"Stack: {format_exc()}" | ||||
|                         ), | ||||
|                         status=500, | ||||
|                     ) | ||||
|                 else: | ||||
|                     response = HTTPResponse( | ||||
|                         "An error occurred while handling an error", status=500 | ||||
|                     ) | ||||
|         if response is not None: | ||||
|             try: | ||||
|                 request.reset_response() | ||||
|                 response = await request.respond(response) | ||||
|             except BaseException: | ||||
|                 # Skip response middleware | ||||
|                 if request.stream: | ||||
|                     request.stream.respond(response) | ||||
|                 await response.send(end_stream=True) | ||||
|                 raise | ||||
|         else: | ||||
|             if request.stream: | ||||
|                 response = request.stream.response | ||||
|  | ||||
|         # Marked for cleanup and DRY with handle_request/handle_exception | ||||
|         # when ResponseStream is no longer supporder | ||||
|         if isinstance(response, BaseHTTPResponse): | ||||
|             await self.dispatch( | ||||
|                 "http.lifecycle.response", | ||||
|                 inline=True, | ||||
|                 context={ | ||||
|                     "request": request, | ||||
|                     "response": response, | ||||
|                 }, | ||||
|             ) | ||||
|             await response.send(end_stream=True) | ||||
|         elif isinstance(response, ResponseStream): | ||||
|             resp = await response(request) | ||||
|             await self.dispatch( | ||||
|                 "http.lifecycle.response", | ||||
|                 inline=True, | ||||
|                 context={ | ||||
|                     "request": request, | ||||
|                     "response": resp, | ||||
|                 }, | ||||
|             ) | ||||
|             await response.eof() | ||||
|         else: | ||||
|             raise ServerError( | ||||
|                 f"Invalid response type {response!r} (need HTTPResponse)" | ||||
|             ) | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     async def handle_request(self, request: Request):  # no cov | ||||
|         """Take a request from the HTTP Server and return a response object | ||||
|         to be sent back The HTTP Server only expects a response object, so | ||||
|         exception handling must be done here | ||||
|  | ||||
|         :param request: HTTP Request object | ||||
|         :return: Nothing | ||||
|         """ | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.handle", | ||||
|             inline=True, | ||||
|             context={"request": request}, | ||||
|         ) | ||||
|  | ||||
|         # Define `response` var here to remove warnings about | ||||
|         # allocation before assignment below. | ||||
|         response = None | ||||
|         try: | ||||
|  | ||||
|             await self.dispatch( | ||||
|                 "http.routing.before", | ||||
|                 inline=True, | ||||
|                 context={"request": request}, | ||||
|             ) | ||||
|             # Fetch handler from router | ||||
|             route, handler, kwargs = self.router.get( | ||||
|                 request.path, | ||||
|                 request.method, | ||||
|                 request.headers.getone("host", None), | ||||
|             ) | ||||
|  | ||||
|             request._match_info = {**kwargs} | ||||
|             request.route = route | ||||
|  | ||||
|             await self.dispatch( | ||||
|                 "http.routing.after", | ||||
|                 inline=True, | ||||
|                 context={ | ||||
|                     "request": request, | ||||
|                     "route": route, | ||||
|                     "kwargs": kwargs, | ||||
|                     "handler": handler, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|             if ( | ||||
|                 request.stream | ||||
|                 and request.stream.request_body | ||||
|                 and not route.ctx.ignore_body | ||||
|             ): | ||||
|  | ||||
|                 if hasattr(handler, "is_stream"): | ||||
|                     # Streaming handler: lift the size limit | ||||
|                     request.stream.request_max_size = float("inf") | ||||
|                 else: | ||||
|                     # Non-streaming handler: preload body | ||||
|                     await request.receive_body() | ||||
|  | ||||
|             # -------------------------------------------- # | ||||
|             # Request Middleware | ||||
|             # -------------------------------------------- # | ||||
|             response = await self._run_request_middleware( | ||||
|                 request, request_name=route.name | ||||
|             ) | ||||
|  | ||||
|             # No middleware results | ||||
|             if not response: | ||||
|                 # -------------------------------------------- # | ||||
|                 # Execute Handler | ||||
|                 # -------------------------------------------- # | ||||
|  | ||||
|                 if handler is None: | ||||
|                     raise ServerError( | ||||
|                         ( | ||||
|                             "'None' was returned while requesting a " | ||||
|                             "handler from the router" | ||||
|                         ) | ||||
|                     ) | ||||
|  | ||||
|                 # Run response handler | ||||
|                 await self.dispatch( | ||||
|                     "http.handler.before", | ||||
|                     inline=True, | ||||
|                     context={"request": request}, | ||||
|                 ) | ||||
|                 response = handler(request, **request.match_info) | ||||
|                 if isawaitable(response): | ||||
|                     response = await response | ||||
|                 await self.dispatch( | ||||
|                     "http.handler.after", | ||||
|                     inline=True, | ||||
|                     context={"request": request}, | ||||
|                 ) | ||||
|  | ||||
|             if request.responded: | ||||
|                 if response is not None: | ||||
|                     error_logger.error( | ||||
|                         "The response object returned by the route handler " | ||||
|                         "will not be sent to client. The request has already " | ||||
|                         "been responded to." | ||||
|                     ) | ||||
|                 if request.stream is not None: | ||||
|                     response = request.stream.response | ||||
|             elif response is not None: | ||||
|                 response = await request.respond(response) | ||||
|             elif not hasattr(handler, "is_websocket"): | ||||
|                 response = request.stream.response  # type: ignore | ||||
|  | ||||
|             # Marked for cleanup and DRY with handle_request/handle_exception | ||||
|             # when ResponseStream is no longer supporder | ||||
|             if isinstance(response, BaseHTTPResponse): | ||||
|                 await self.dispatch( | ||||
|                     "http.lifecycle.response", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": response, | ||||
|                     }, | ||||
|                 ) | ||||
|                 ... | ||||
|                 await response.send(end_stream=True) | ||||
|             elif isinstance(response, ResponseStream): | ||||
|                 resp = await response(request) | ||||
|                 await self.dispatch( | ||||
|                     "http.lifecycle.response", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": resp, | ||||
|                     }, | ||||
|                 ) | ||||
|                 await response.eof() | ||||
|             else: | ||||
|                 if not hasattr(handler, "is_websocket"): | ||||
|                     raise ServerError( | ||||
|                         f"Invalid response type {response!r} " | ||||
|                         "(need HTTPResponse)" | ||||
|                     ) | ||||
|  | ||||
|         except CancelledError: | ||||
|             raise | ||||
|         except Exception as e: | ||||
|             # Response Generation Failed | ||||
|             await self.handle_exception(request, e) | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     async def _websocket_handler( | ||||
|         self, handler, request, *args, subprotocols=None, **kwargs | ||||
| @@ -1036,20 +780,11 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|     # -------------------------------------------------------------------- # | ||||
|  | ||||
|     async def _run_request_middleware( | ||||
|         self, request, request_name=None | ||||
|         self, request, middleware_collection | ||||
|     ):  # no cov | ||||
|         # The if improves speed.  I don't know why | ||||
|         named_middleware = self.named_request_middleware.get( | ||||
|             request_name, deque() | ||||
|         ) | ||||
|         applicable_middleware = self.request_middleware + named_middleware | ||||
|         request._request_middleware_started = True | ||||
|  | ||||
|         # request.request_middleware_started is meant as a stop-gap solution | ||||
|         # until RFC 1630 is adopted | ||||
|         if applicable_middleware and not request.request_middleware_started: | ||||
|             request.request_middleware_started = True | ||||
|  | ||||
|             for middleware in applicable_middleware: | ||||
|         for middleware in middleware_collection: | ||||
|             await self.dispatch( | ||||
|                 "http.middleware.before", | ||||
|                 inline=True, | ||||
| @@ -1079,14 +814,9 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|         return None | ||||
|  | ||||
|     async def _run_response_middleware( | ||||
|         self, request, response, request_name=None | ||||
|         self, request, response, middleware_collection | ||||
|     ):  # no cov | ||||
|         named_middleware = self.named_response_middleware.get( | ||||
|             request_name, deque() | ||||
|         ) | ||||
|         applicable_middleware = self.response_middleware + named_middleware | ||||
|         if applicable_middleware: | ||||
|             for middleware in applicable_middleware: | ||||
|         for middleware in middleware_collection: | ||||
|             await self.dispatch( | ||||
|                 "http.middleware.before", | ||||
|                 inline=True, | ||||
| @@ -1528,6 +1258,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|         except FinalizationError as e: | ||||
|             if not Sanic.test_mode: | ||||
|                 raise e | ||||
|         self.finalize_middleware() | ||||
|  | ||||
|     def signalize(self, allow_fail_builtin=True): | ||||
|         self.signal_router.allow_fail_builtin = allow_fail_builtin | ||||
|   | ||||
| @@ -7,6 +7,7 @@ from urllib.parse import quote | ||||
|  | ||||
| from sanic.compat import Header | ||||
| from sanic.exceptions import ServerError | ||||
| from sanic.handlers import RequestManager | ||||
| from sanic.helpers import _default | ||||
| from sanic.http import Stage | ||||
| from sanic.log import logger | ||||
| @@ -230,8 +231,9 @@ class ASGIApp: | ||||
|         """ | ||||
|         Handle the incoming request. | ||||
|         """ | ||||
|         manager = RequestManager.create(self.request) | ||||
|         try: | ||||
|             self.stage = Stage.HANDLER | ||||
|             await self.sanic_app.handle_request(self.request) | ||||
|             await manager.handle() | ||||
|         except Exception as e: | ||||
|             await self.sanic_app.handle_exception(self.request, e) | ||||
|             await manager.error(e) | ||||
|   | ||||
| @@ -1,16 +1,317 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from functools import partial | ||||
| from inspect import isawaitable | ||||
| from traceback import format_exc | ||||
| from typing import Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| from sanic_routing import Route | ||||
|  | ||||
| from sanic.errorpages import BaseRenderer, TextRenderer, exception_response | ||||
| from sanic.exceptions import ( | ||||
|     HeaderNotFound, | ||||
|     InvalidRangeType, | ||||
|     RangeNotSatisfiable, | ||||
|     SanicException, | ||||
|     ServerError, | ||||
| ) | ||||
| from sanic.log import deprecation, error_logger | ||||
| from sanic.http.constants import Stage | ||||
| from sanic.log import deprecation, error_logger, logger | ||||
| from sanic.models.handler_types import RouteHandler | ||||
| from sanic.response import text | ||||
| from sanic.request import Request | ||||
| from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream, text | ||||
| from sanic.touchup import TouchUpMeta | ||||
|  | ||||
|  | ||||
| class RequestHandler: | ||||
|     def __init__(self, func, request_middleware, response_middleware): | ||||
|         self.func = func.func if isinstance(func, RequestHandler) else func | ||||
|         self.request_middleware = request_middleware | ||||
|         self.response_middleware = response_middleware | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         return self.func(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| class RequestManager(metaclass=TouchUpMeta): | ||||
|     __touchup__ = ( | ||||
|         "cleanup", | ||||
|         "run_request_middleware", | ||||
|         "run_response_middleware", | ||||
|     ) | ||||
|     __slots__ = ( | ||||
|         "handler", | ||||
|         "request_middleware_run", | ||||
|         "request_middleware", | ||||
|         "request", | ||||
|         "response_middleware_run", | ||||
|         "response_middleware", | ||||
|     ) | ||||
|     request: Request | ||||
|  | ||||
|     def __init__(self, request: Request): | ||||
|         self.request_middleware_run = False | ||||
|         self.response_middleware_run = False | ||||
|         self.handler = self._noop | ||||
|         self.set_request(request) | ||||
|  | ||||
|     @classmethod | ||||
|     def create(cls, request: Request) -> RequestManager: | ||||
|         return cls(request) | ||||
|  | ||||
|     def set_request(self, request: Request): | ||||
|         request._manager = self | ||||
|         self.request = request | ||||
|         self.request_middleware = request.app.request_middleware | ||||
|         self.response_middleware = request.app.response_middleware | ||||
|  | ||||
|     async def handle(self): | ||||
|         route = self.resolve_route() | ||||
|  | ||||
|         if self.handler is None: | ||||
|             await self.error( | ||||
|                 ServerError( | ||||
|                     ( | ||||
|                         "'None' was returned while requesting a " | ||||
|                         "handler from the router" | ||||
|                     ) | ||||
|                 ) | ||||
|             ) | ||||
|             return | ||||
|  | ||||
|         if ( | ||||
|             self.request.stream | ||||
|             and self.request.stream.request_body | ||||
|             and not route.ctx.ignore_body | ||||
|         ): | ||||
|             await self.receive_body() | ||||
|  | ||||
|         await self.lifecycle( | ||||
|             partial(self.handler, self.request, **self.request.match_info) | ||||
|         ) | ||||
|  | ||||
|     async def lifecycle(self, handler, raise_exception: bool = False): | ||||
|         response: Optional[BaseHTTPResponse] = None | ||||
|         if not self.request_middleware_run and self.request_middleware: | ||||
|             response = await self.run( | ||||
|                 self.run_request_middleware, raise_exception | ||||
|             ) | ||||
|  | ||||
|         if not response: | ||||
|             # Run response handler | ||||
|             response = await self.run(handler, raise_exception) | ||||
|  | ||||
|         if not self.response_middleware_run and self.response_middleware: | ||||
|             response = await self.run( | ||||
|                 partial(self.run_response_middleware, response), | ||||
|                 raise_exception, | ||||
|             ) | ||||
|  | ||||
|         await self.cleanup(response) | ||||
|  | ||||
|     async def run( | ||||
|         self, operation, raise_exception: bool = False | ||||
|     ) -> Optional[BaseHTTPResponse]: | ||||
|         try: | ||||
|             response = operation() | ||||
|             if isawaitable(response): | ||||
|                 response = await response | ||||
|         except Exception as e: | ||||
|             if raise_exception: | ||||
|                 raise | ||||
|             response = await self.error(e) | ||||
|         return response | ||||
|  | ||||
|     async def error(self, exception: Exception): | ||||
|         error_handler = self.request.app.error_handler | ||||
|         if ( | ||||
|             self.request.stream is not None | ||||
|             and self.request.stream.stage is not Stage.HANDLER | ||||
|         ): | ||||
|             error_logger.exception(exception, exc_info=True) | ||||
|             logger.error( | ||||
|                 "The error response will not be sent to the client for " | ||||
|                 f'the following exception:"{exception}". A previous response ' | ||||
|                 "has at least partially been sent." | ||||
|             ) | ||||
|  | ||||
|             handler = error_handler._lookup( | ||||
|                 exception, self.request.name if self.request else None | ||||
|             ) | ||||
|             if handler: | ||||
|                 logger.warning( | ||||
|                     "An error occurred while handling the request after at " | ||||
|                     "least some part of the response was sent to the client. " | ||||
|                     "The response from your custom exception handler " | ||||
|                     f"{handler.__name__} will not be sent to the client." | ||||
|                     "Exception handlers should only be used to generate the " | ||||
|                     "exception responses. If you would like to perform any " | ||||
|                     "other action on a raised exception, consider using a " | ||||
|                     "signal handler like " | ||||
|                     '`@app.signal("http.lifecycle.exception")`\n' | ||||
|                     "For further information, please see the docs: " | ||||
|                     "https://sanicframework.org/en/guide/advanced/" | ||||
|                     "signals.html", | ||||
|                 ) | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             await self.lifecycle( | ||||
|                 partial(error_handler.response, self.request, exception), True | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             if isinstance(e, SanicException): | ||||
|                 response = error_handler.default(self.request, e) | ||||
|             elif self.request.app.debug: | ||||
|                 response = HTTPResponse( | ||||
|                     ( | ||||
|                         f"Error while handling error: {e}\n" | ||||
|                         f"Stack: {format_exc()}" | ||||
|                     ), | ||||
|                     status=500, | ||||
|                 ) | ||||
|             else: | ||||
|                 error_logger.exception(e) | ||||
|                 response = HTTPResponse( | ||||
|                     "An error occurred while handling an error", status=500 | ||||
|                 ) | ||||
|             return response | ||||
|         return None | ||||
|  | ||||
|     async def cleanup(self, response: Optional[BaseHTTPResponse]): | ||||
|         if self.request.responded: | ||||
|             if response is not None: | ||||
|                 error_logger.error( | ||||
|                     "The response object returned by the route handler " | ||||
|                     "will not be sent to client. The request has already " | ||||
|                     "been responded to." | ||||
|                 ) | ||||
|             if self.request.stream is not None: | ||||
|                 response = self.request.stream.response | ||||
|         elif response is not None: | ||||
|             self.request.reset_response() | ||||
|             response = await self.request.respond(response)  # type: ignore | ||||
|         elif not hasattr(self.handler, "is_websocket"): | ||||
|             response = self.request.stream.response  # type: ignore | ||||
|  | ||||
|         if isinstance(response, BaseHTTPResponse): | ||||
|             await self.request.app.dispatch( | ||||
|                 "http.lifecycle.response", | ||||
|                 inline=True, | ||||
|                 context={"request": self.request, "response": response}, | ||||
|             ) | ||||
|             await response.send(end_stream=True) | ||||
|         elif isinstance(response, ResponseStream): | ||||
|             await response(self.request)  # type: ignore | ||||
|             await response.eof()  # type: ignore | ||||
|             await self.request.app.dispatch( | ||||
|                 "http.lifecycle.response", | ||||
|                 inline=True, | ||||
|                 context={"request": self.request, "response": response}, | ||||
|             ) | ||||
|         else: | ||||
|             if not hasattr(self.handler, "is_websocket"): | ||||
|                 raise ServerError( | ||||
|                     f"Invalid response type {response!r} " | ||||
|                     "(need HTTPResponse)" | ||||
|                 ) | ||||
|  | ||||
|     async def receive_body(self): | ||||
|         if hasattr(self.handler, "is_stream"): | ||||
|             # Streaming handler: lift the size limit | ||||
|             self.request.stream.request_max_size = float("inf") | ||||
|         else: | ||||
|             # Non-streaming handler: preload body | ||||
|             await self.request.receive_body() | ||||
|  | ||||
|     async def run_request_middleware(self) -> Optional[BaseHTTPResponse]: | ||||
|         self.request._request_middleware_started = True | ||||
|         self.request_middleware_run = True | ||||
|  | ||||
|         for middleware in self.request_middleware: | ||||
|             await self.request.app.dispatch( | ||||
|                 "http.middleware.before", | ||||
|                 inline=True, | ||||
|                 context={"request": self.request, "response": None}, | ||||
|                 condition={"attach_to": "request"}, | ||||
|             ) | ||||
|  | ||||
|             try: | ||||
|                 response = await self.run(partial(middleware, self.request)) | ||||
|             except Exception: | ||||
|                 error_logger.exception( | ||||
|                     "Exception occurred in one of request middleware handlers" | ||||
|                 ) | ||||
|                 raise | ||||
|  | ||||
|             await self.request.app.dispatch( | ||||
|                 "http.middleware.after", | ||||
|                 inline=True, | ||||
|                 context={"request": self.request, "response": None}, | ||||
|                 condition={"attach_to": "request"}, | ||||
|             ) | ||||
|  | ||||
|             if response: | ||||
|                 return response | ||||
|         return None | ||||
|  | ||||
|     async def run_response_middleware( | ||||
|         self, response: BaseHTTPResponse | ||||
|     ) -> BaseHTTPResponse: | ||||
|         self.response_middleware_run = True | ||||
|         for middleware in self.response_middleware: | ||||
|             await self.request.app.dispatch( | ||||
|                 "http.middleware.before", | ||||
|                 inline=True, | ||||
|                 context={"request": self.request, "response": None}, | ||||
|                 condition={"attach_to": "request"}, | ||||
|             ) | ||||
|  | ||||
|             try: | ||||
|                 resp = await self.run( | ||||
|                     partial(middleware, self.request, response), True | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 error_logger.exception( | ||||
|                     "Exception occurred in one of response middleware handlers" | ||||
|                 ) | ||||
|                 await self.error(e) | ||||
|                 resp = None | ||||
|  | ||||
|             await self.request.app.dispatch( | ||||
|                 "http.middleware.after", | ||||
|                 inline=True, | ||||
|                 context={"request": self.request, "response": None}, | ||||
|                 condition={"attach_to": "request"}, | ||||
|             ) | ||||
|  | ||||
|             if resp: | ||||
|                 return resp | ||||
|         return response | ||||
|  | ||||
|     def resolve_route(self) -> Route: | ||||
|         # Fetch handler from router | ||||
|         route, handler, kwargs = self.request.app.router.get( | ||||
|             self.request.path, | ||||
|             self.request.method, | ||||
|             self.request.headers.getone("host", None), | ||||
|         ) | ||||
|  | ||||
|         self.request._match_info = {**kwargs} | ||||
|         self.request.route = route | ||||
|         self.handler = handler | ||||
|  | ||||
|         if handler and handler.request_middleware: | ||||
|             self.request_middleware = handler.request_middleware | ||||
|  | ||||
|         if handler and handler.response_middleware: | ||||
|             self.response_middleware = handler.response_middleware | ||||
|  | ||||
|         return route | ||||
|  | ||||
|     @staticmethod | ||||
|     def _noop(_): | ||||
|         ... | ||||
|  | ||||
|  | ||||
| class ErrorHandler: | ||||
|   | ||||
| @@ -124,7 +124,8 @@ class Http(Stream, metaclass=TouchUpMeta): | ||||
|  | ||||
|                 self.stage = Stage.HANDLER | ||||
|                 self.request.conn_info = self.protocol.conn_info | ||||
|                 await self.protocol.request_handler(self.request) | ||||
|  | ||||
|                 await self.request.manager.handle() | ||||
|  | ||||
|                 # Handler finished, response should've been sent | ||||
|                 if self.stage is Stage.HANDLER and not self.upgrade_websocket: | ||||
| @@ -250,6 +251,7 @@ class Http(Stream, metaclass=TouchUpMeta): | ||||
|             transport=self.protocol.transport, | ||||
|             app=self.protocol.app, | ||||
|         ) | ||||
|         self.protocol.request_handler.create(request) | ||||
|         self.protocol.request_class._current.set(request) | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.request", | ||||
| @@ -423,12 +425,11 @@ class Http(Stream, metaclass=TouchUpMeta): | ||||
|  | ||||
|         # From request and handler states we can respond, otherwise be silent | ||||
|         if self.stage is Stage.HANDLER: | ||||
|             app = self.protocol.app | ||||
|  | ||||
|             if self.request is None: | ||||
|                 self.create_empty_request() | ||||
|                 self.protocol.request_handler.create(self.request) | ||||
|  | ||||
|             await app.handle_exception(self.request, exception) | ||||
|             await self.request.manager.error(exception) | ||||
|  | ||||
|     def create_empty_request(self) -> None: | ||||
|         """ | ||||
|   | ||||
							
								
								
									
										66
									
								
								sanic/middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								sanic/middleware.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from collections import deque | ||||
| from enum import IntEnum, auto | ||||
| from itertools import count | ||||
| from typing import Deque, Sequence, Union | ||||
|  | ||||
| from sanic.models.handler_types import MiddlewareType | ||||
|  | ||||
|  | ||||
| class MiddlewareLocation(IntEnum): | ||||
|     REQUEST = auto() | ||||
|     RESPONSE = auto() | ||||
|  | ||||
|  | ||||
| class Middleware: | ||||
|     _counter = count() | ||||
|  | ||||
|     __slots__ = ("func", "priority", "location", "definition") | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         func: MiddlewareType, | ||||
|         location: MiddlewareLocation = MiddlewareLocation.REQUEST, | ||||
|         priority: int = 0, | ||||
|     ) -> None: | ||||
|         self.func = func | ||||
|         self.priority = priority | ||||
|         self.location = location | ||||
|         self.definition = next(Middleware._counter) | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         return self.func(*args, **kwargs) | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         name = getattr(self.func, "__name__", str(self.func)) | ||||
|         return ( | ||||
|             f"{self.__class__.__name__}(" | ||||
|             f"func=<function {name}>, " | ||||
|             f"priority={self.priority}, " | ||||
|             f"location={self.location.name})" | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def order(self): | ||||
|         return (self.priority, -self.definition) | ||||
|  | ||||
|     @classmethod | ||||
|     def convert( | ||||
|         cls, | ||||
|         *middleware_collections: Sequence[Union[Middleware, MiddlewareType]], | ||||
|         location: MiddlewareLocation, | ||||
|     ) -> Deque[Middleware]: | ||||
|         return deque( | ||||
|             [ | ||||
|                 middleware | ||||
|                 if isinstance(middleware, Middleware) | ||||
|                 else Middleware(middleware, location) | ||||
|                 for collection in middleware_collections | ||||
|                 for middleware in collection | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     def reset_count(cls): | ||||
|         cls._counter = count() | ||||
| @@ -1,11 +1,18 @@ | ||||
| from collections import deque | ||||
| from functools import partial | ||||
| from operator import attrgetter | ||||
| from typing import List | ||||
|  | ||||
| from sanic.base.meta import SanicMeta | ||||
| from sanic.handlers import RequestHandler | ||||
| from sanic.middleware import Middleware, MiddlewareLocation | ||||
| from sanic.models.futures import FutureMiddleware | ||||
| from sanic.router import Router | ||||
|  | ||||
|  | ||||
| class MiddlewareMixin(metaclass=SanicMeta): | ||||
|     router: Router | ||||
|  | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         self._future_middleware: List[FutureMiddleware] = [] | ||||
|  | ||||
| @@ -13,7 +20,12 @@ class MiddlewareMixin(metaclass=SanicMeta): | ||||
|         raise NotImplementedError  # noqa | ||||
|  | ||||
|     def middleware( | ||||
|         self, middleware_or_request, attach_to="request", apply=True | ||||
|         self, | ||||
|         middleware_or_request, | ||||
|         attach_to="request", | ||||
|         apply=True, | ||||
|         *, | ||||
|         priority=0 | ||||
|     ): | ||||
|         """ | ||||
|         Decorate and register middleware to be called before a request | ||||
| @@ -30,6 +42,12 @@ class MiddlewareMixin(metaclass=SanicMeta): | ||||
|         def register_middleware(middleware, attach_to="request"): | ||||
|             nonlocal apply | ||||
|  | ||||
|             location = ( | ||||
|                 MiddlewareLocation.REQUEST | ||||
|                 if attach_to == "request" | ||||
|                 else MiddlewareLocation.RESPONSE | ||||
|             ) | ||||
|             middleware = Middleware(middleware, location, priority=priority) | ||||
|             future_middleware = FutureMiddleware(middleware, attach_to) | ||||
|             self._future_middleware.append(future_middleware) | ||||
|             if apply: | ||||
| @@ -46,7 +64,7 @@ class MiddlewareMixin(metaclass=SanicMeta): | ||||
|                 register_middleware, attach_to=middleware_or_request | ||||
|             ) | ||||
|  | ||||
|     def on_request(self, middleware=None): | ||||
|     def on_request(self, middleware=None, *, priority=0): | ||||
|         """Register a middleware to be called before a request is handled. | ||||
|  | ||||
|         This is the same as *@app.middleware('request')*. | ||||
| @@ -54,11 +72,13 @@ class MiddlewareMixin(metaclass=SanicMeta): | ||||
|         :param: middleware: A callable that takes in request. | ||||
|         """ | ||||
|         if callable(middleware): | ||||
|             return self.middleware(middleware, "request") | ||||
|             return self.middleware(middleware, "request", priority=priority) | ||||
|         else: | ||||
|             return partial(self.middleware, attach_to="request") | ||||
|             return partial( | ||||
|                 self.middleware, attach_to="request", priority=priority | ||||
|             ) | ||||
|  | ||||
|     def on_response(self, middleware=None): | ||||
|     def on_response(self, middleware=None, *, priority=0): | ||||
|         """Register a middleware to be called after a response is created. | ||||
|  | ||||
|         This is the same as *@app.middleware('response')*. | ||||
| @@ -67,6 +87,61 @@ class MiddlewareMixin(metaclass=SanicMeta): | ||||
|             A callable that takes in a request and its response. | ||||
|         """ | ||||
|         if callable(middleware): | ||||
|             return self.middleware(middleware, "response") | ||||
|             return self.middleware(middleware, "response", priority=priority) | ||||
|         else: | ||||
|             return partial(self.middleware, attach_to="response") | ||||
|             return partial( | ||||
|                 self.middleware, attach_to="response", priority=priority | ||||
|             ) | ||||
|  | ||||
|     def finalize_middleware(self): | ||||
|         for route in self.router.routes: | ||||
|             request_middleware = Middleware.convert( | ||||
|                 self.request_middleware, | ||||
|                 self.named_request_middleware.get(route.name, deque()), | ||||
|                 location=MiddlewareLocation.REQUEST, | ||||
|             ) | ||||
|             response_middleware = Middleware.convert( | ||||
|                 self.response_middleware, | ||||
|                 self.named_response_middleware.get(route.name, deque()), | ||||
|                 location=MiddlewareLocation.RESPONSE, | ||||
|             ) | ||||
|  | ||||
|             route.handler = RequestHandler( | ||||
|                 route.handler, | ||||
|                 deque( | ||||
|                     sorted( | ||||
|                         request_middleware, | ||||
|                         key=attrgetter("order"), | ||||
|                         reverse=True, | ||||
|                     ) | ||||
|                 ), | ||||
|                 deque( | ||||
|                     sorted( | ||||
|                         response_middleware, | ||||
|                         key=attrgetter("order"), | ||||
|                         reverse=True, | ||||
|                     )[::-1] | ||||
|                 ), | ||||
|             ) | ||||
|         request_middleware = Middleware.convert( | ||||
|             self.request_middleware, | ||||
|             location=MiddlewareLocation.REQUEST, | ||||
|         ) | ||||
|         response_middleware = Middleware.convert( | ||||
|             self.response_middleware, | ||||
|             location=MiddlewareLocation.RESPONSE, | ||||
|         ) | ||||
|         self.request_middleware = deque( | ||||
|             sorted( | ||||
|                 request_middleware, | ||||
|                 key=attrgetter("order"), | ||||
|                 reverse=True, | ||||
|             ) | ||||
|         ) | ||||
|         self.response_middleware = deque( | ||||
|             sorted( | ||||
|                 response_middleware, | ||||
|                 key=attrgetter("order"), | ||||
|                 reverse=True, | ||||
|             )[::-1] | ||||
|         ) | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from contextvars import ContextVar | ||||
| from functools import partial | ||||
| from inspect import isawaitable | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
| @@ -23,6 +24,7 @@ from sanic.models.http_types import Credentials | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from sanic.handlers import RequestManager | ||||
|     from sanic.server import ConnInfo | ||||
|     from sanic.app import Sanic | ||||
|  | ||||
| @@ -37,7 +39,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse | ||||
| from httptools import parse_url | ||||
| from httptools.parser.errors import HttpParserInvalidURLError | ||||
|  | ||||
| from sanic.compat import CancelledErrors, Header | ||||
| from sanic.compat import Header | ||||
| from sanic.constants import ( | ||||
|     CACHEABLE_HTTP_METHODS, | ||||
|     DEFAULT_HTTP_CONTENT_TYPE, | ||||
| @@ -56,7 +58,7 @@ from sanic.headers import ( | ||||
|     parse_xforwarded, | ||||
| ) | ||||
| from sanic.http import Stage | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.log import deprecation, error_logger, logger | ||||
| from sanic.models.protocol_types import TransportProtocol | ||||
| from sanic.response import BaseHTTPResponse, HTTPResponse | ||||
|  | ||||
| @@ -99,10 +101,12 @@ class Request: | ||||
|         "_cookies", | ||||
|         "_id", | ||||
|         "_ip", | ||||
|         "_manager", | ||||
|         "_parsed_url", | ||||
|         "_port", | ||||
|         "_protocol", | ||||
|         "_remote_addr", | ||||
|         "_request_middleware_started", | ||||
|         "_scheme", | ||||
|         "_socket", | ||||
|         "_stream_id", | ||||
| @@ -126,7 +130,6 @@ class Request: | ||||
|         "parsed_token", | ||||
|         "raw_url", | ||||
|         "responded", | ||||
|         "request_middleware_started", | ||||
|         "route", | ||||
|         "stream", | ||||
|         "transport", | ||||
| @@ -178,10 +181,11 @@ class Request: | ||||
|         self.parsed_not_grouped_args: DefaultDict[ | ||||
|             Tuple[bool, bool, str, str], List[Tuple[str, str]] | ||||
|         ] = defaultdict(list) | ||||
|         self.request_middleware_started = False | ||||
|         self._request_middleware_started = False | ||||
|         self.responded: bool = False | ||||
|         self.route: Optional[Route] = None | ||||
|         self.stream: Optional[Stream] = None | ||||
|         self._manager: Optional[RequestManager] = None | ||||
|         self._cookies: Optional[Dict[str, str]] = None | ||||
|         self._match_info: Dict[str, Any] = {} | ||||
|         self._protocol = None | ||||
| @@ -219,6 +223,16 @@ class Request: | ||||
|     def generate_id(*_): | ||||
|         return uuid.uuid4() | ||||
|  | ||||
|     @property | ||||
|     def request_middleware_started(self): | ||||
|         deprecation( | ||||
|             "Request.request_middleware_started has been deprecated and will" | ||||
|             "be removed. You should set a flag on the request context using" | ||||
|             "either middleware or signals if you need this feature.", | ||||
|             22.3, | ||||
|         ) | ||||
|         return self._request_middleware_started | ||||
|  | ||||
|     @property | ||||
|     def stream_id(self): | ||||
|         """ | ||||
| @@ -233,6 +247,10 @@ class Request: | ||||
|             ) | ||||
|         return self._stream_id | ||||
|  | ||||
|     @property | ||||
|     def manager(self): | ||||
|         return self._manager | ||||
|  | ||||
|     def reset_response(self): | ||||
|         try: | ||||
|             if ( | ||||
| @@ -323,15 +341,13 @@ class Request: | ||||
|             if isawaitable(response): | ||||
|                 response = await response  # type: ignore | ||||
|         # Run response middleware | ||||
|         try: | ||||
|             response = await self.app._run_response_middleware( | ||||
|                 self, response, request_name=self.name | ||||
|             ) | ||||
|         except CancelledErrors: | ||||
|             raise | ||||
|         except Exception: | ||||
|             error_logger.exception( | ||||
|                 "Exception occurred in one of response middleware handlers" | ||||
|         if ( | ||||
|             self._manager | ||||
|             and not self._manager.response_middleware_run | ||||
|             and self._manager.response_middleware | ||||
|         ): | ||||
|             response = await self._manager.run( | ||||
|                 partial(self._manager.run_response_middleware, response) | ||||
|             ) | ||||
|         self.responded = True | ||||
|         return response | ||||
|   | ||||
| @@ -13,6 +13,7 @@ from sanic_routing.route import Route | ||||
| from sanic.constants import HTTP_METHODS | ||||
| from sanic.errorpages import check_error_format | ||||
| from sanic.exceptions import MethodNotAllowed, NotFound, SanicException | ||||
| from sanic.handlers import RequestHandler | ||||
| from sanic.models.handler_types import RouteHandler | ||||
|  | ||||
|  | ||||
| @@ -31,9 +32,11 @@ class Router(BaseRouter): | ||||
|  | ||||
|     def _get( | ||||
|         self, path: str, method: str, host: Optional[str] | ||||
|     ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: | ||||
|     ) -> Tuple[Route, RequestHandler, Dict[str, Any]]: | ||||
|         try: | ||||
|             return self.resolve( | ||||
|             # We know this will always be RequestHandler, so we can ignore | ||||
|             # typing issue here | ||||
|             return self.resolve(  # type: ignore | ||||
|                 path=path, | ||||
|                 method=method, | ||||
|                 extra={"host": host} if host else None, | ||||
| @@ -50,7 +53,7 @@ class Router(BaseRouter): | ||||
|     @lru_cache(maxsize=ROUTER_CACHE_SIZE) | ||||
|     def get(  # type: ignore | ||||
|         self, path: str, method: str, host: Optional[str] | ||||
|     ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: | ||||
|     ) -> Tuple[Route, RequestHandler, Dict[str, Any]]: | ||||
|         """ | ||||
|         Retrieve a `Route` object containing the details about how to handle | ||||
|         a response for a given request | ||||
| @@ -59,7 +62,7 @@ class Router(BaseRouter): | ||||
|         :type request: Request | ||||
|         :return: details needed for handling the request and returning the | ||||
|             correct response | ||||
|         :rtype: Tuple[ Route, RouteHandler, Dict[str, Any]] | ||||
|         :rtype: Tuple[ Route, RequestHandler, Dict[str, Any]] | ||||
|         """ | ||||
|         return self._get(path, method, host) | ||||
|  | ||||
| @@ -114,7 +117,7 @@ class Router(BaseRouter): | ||||
|  | ||||
|         params = dict( | ||||
|             path=uri, | ||||
|             handler=handler, | ||||
|             handler=RequestHandler(handler, [], []), | ||||
|             methods=frozenset(map(str, methods)) if methods else None, | ||||
|             name=name, | ||||
|             strict=strict_slashes, | ||||
|   | ||||
| @@ -2,6 +2,7 @@ from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| from sanic.handlers import RequestManager | ||||
| from sanic.http.constants import HTTP | ||||
| from sanic.http.http3 import Http3 | ||||
| from sanic.touchup.meta import TouchUpMeta | ||||
| @@ -57,7 +58,7 @@ class HttpProtocolMixin: | ||||
|     def _setup(self): | ||||
|         self.request: Optional[Request] = None | ||||
|         self.access_log = self.app.config.ACCESS_LOG | ||||
|         self.request_handler = self.app.handle_request | ||||
|         self.request_handler = RequestManager | ||||
|         self.error_handler = self.app.error_handler | ||||
|         self.request_timeout = self.app.config.REQUEST_TIMEOUT | ||||
|         self.response_timeout = self.app.config.RESPONSE_TIMEOUT | ||||
|   | ||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @@ -84,7 +84,7 @@ ujson = "ujson>=1.35" + env_dependency | ||||
| uvloop = "uvloop>=0.5.3" + env_dependency | ||||
| types_ujson = "types-ujson" + env_dependency | ||||
| requirements = [ | ||||
|     "sanic-routing>=22.3.0,<22.6.0", | ||||
|     "sanic-routing>=22.8.0", | ||||
|     "httptools>=0.0.10", | ||||
|     uvloop, | ||||
|     ujson, | ||||
|   | ||||
| @@ -152,8 +152,11 @@ def test_app_route_raise_value_error(app: Sanic): | ||||
|  | ||||
|  | ||||
| def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch): | ||||
|     mock = Mock() | ||||
|     mock.handler = None | ||||
|  | ||||
|     def mockreturn(*args, **kwargs): | ||||
|         return Mock(), None, {} | ||||
|         return mock, None, {} | ||||
|  | ||||
|     monkeypatch.setattr(app.router, "get", mockreturn) | ||||
|  | ||||
|   | ||||
| @@ -2,12 +2,22 @@ import logging | ||||
|  | ||||
| from asyncio import CancelledError | ||||
| from itertools import count | ||||
| from unittest.mock import Mock | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from sanic.exceptions import NotFound | ||||
| from sanic.middleware import Middleware | ||||
| from sanic.request import Request | ||||
| from sanic.response import HTTPResponse, json, text | ||||
|  | ||||
|  | ||||
| @pytest.fixture(autouse=True) | ||||
| def reset_middleware(): | ||||
|     yield | ||||
|     Middleware.reset_count() | ||||
|  | ||||
|  | ||||
| # ------------------------------------------------------------ # | ||||
| #  GET | ||||
| # ------------------------------------------------------------ # | ||||
| @@ -183,7 +193,7 @@ def test_middleware_response_raise_exception(app, caplog): | ||||
|     with caplog.at_level(logging.ERROR): | ||||
|         reqrequest, response = app.test_client.get("/fail") | ||||
|  | ||||
|     assert response.status == 404 | ||||
|     assert response.status == 500 | ||||
|     # 404 errors are not logged | ||||
|     assert ( | ||||
|         "sanic.error", | ||||
| @@ -318,6 +328,15 @@ def test_middleware_return_response(app): | ||||
|         resp1 = await request.respond() | ||||
|         return resp1 | ||||
|  | ||||
|     _, response = app.test_client.get("/") | ||||
|     app.test_client.get("/") | ||||
|     assert response_middleware_run_count == 1 | ||||
|     assert request_middleware_run_count == 1 | ||||
|  | ||||
|  | ||||
| def test_middleware_object(): | ||||
|     mock = Mock() | ||||
|     middleware = Middleware(mock) | ||||
|     middleware(1, 2, 3, answer=42) | ||||
|  | ||||
|     mock.assert_called_once_with(1, 2, 3, answer=42) | ||||
|     assert middleware.order == (0, 0) | ||||
|   | ||||
							
								
								
									
										83
									
								
								tests/test_middleware_priority.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								tests/test_middleware_priority.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | ||||
| from functools import partial | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.response import json | ||||
|  | ||||
|  | ||||
| PRIORITY_TEST_CASES = ( | ||||
|     ([0, 1, 2], [1, 1, 1]), | ||||
|     ([0, 1, 2], [1, 1, None]), | ||||
|     ([0, 1, 2], [1, None, None]), | ||||
|     ([0, 1, 2], [2, 1, None]), | ||||
|     ([0, 1, 2], [2, 2, None]), | ||||
|     ([0, 1, 2], [3, 2, 1]), | ||||
|     ([0, 1, 2], [None, None, None]), | ||||
|     ([0, 2, 1], [1, None, 1]), | ||||
|     ([0, 2, 1], [2, None, 1]), | ||||
|     ([0, 2, 1], [2, None, 2]), | ||||
|     ([0, 2, 1], [3, 1, 2]), | ||||
|     ([1, 0, 2], [1, 2, None]), | ||||
|     ([1, 0, 2], [2, 3, 1]), | ||||
|     ([1, 0, 2], [None, 1, None]), | ||||
|     ([1, 2, 0], [1, 3, 2]), | ||||
|     ([1, 2, 0], [None, 1, 1]), | ||||
|     ([1, 2, 0], [None, 2, 1]), | ||||
|     ([1, 2, 0], [None, 2, 2]), | ||||
|     ([2, 0, 1], [1, None, 2]), | ||||
|     ([2, 0, 1], [2, 1, 3]), | ||||
|     ([2, 0, 1], [None, None, 1]), | ||||
|     ([2, 1, 0], [1, 2, 3]), | ||||
|     ([2, 1, 0], [None, 1, 2]), | ||||
| ) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "expected,priorities", | ||||
|     PRIORITY_TEST_CASES, | ||||
| ) | ||||
| def test_request_middleware_order_priority(app: Sanic, expected, priorities): | ||||
|     order = [] | ||||
|  | ||||
|     def add_ident(request, ident): | ||||
|         order.append(ident) | ||||
|  | ||||
|     @app.get("/") | ||||
|     def handler(request): | ||||
|         return json(None) | ||||
|  | ||||
|     for ident, priority in enumerate(priorities): | ||||
|         kwargs = {} | ||||
|         if priority is not None: | ||||
|             kwargs["priority"] = priority | ||||
|         app.on_request(partial(add_ident, ident=ident), **kwargs) | ||||
|  | ||||
|     app.test_client.get("/") | ||||
|  | ||||
|     assert order == expected | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "expected,priorities", | ||||
|     PRIORITY_TEST_CASES, | ||||
| ) | ||||
| def test_response_middleware_order_priority(app: Sanic, expected, priorities): | ||||
|     order = [] | ||||
|  | ||||
|     def add_ident(request, response, ident): | ||||
|         order.append(ident) | ||||
|  | ||||
|     @app.get("/") | ||||
|     def handler(request): | ||||
|         return json(None) | ||||
|  | ||||
|     for ident, priority in enumerate(priorities): | ||||
|         kwargs = {} | ||||
|         if priority is not None: | ||||
|             kwargs["priority"] = priority | ||||
|         app.on_response(partial(add_ident, ident=ident), **kwargs) | ||||
|  | ||||
|     app.test_client.get("/") | ||||
|  | ||||
|     assert order[::-1] == expected | ||||
		Reference in New Issue
	
	Block a user