Compare commits

...

12 Commits

Author SHA1 Message Date
Adam Hopkins
38b4ccf2bc
Cleanup implementation 2022-09-19 21:34:50 +03:00
Adam Hopkins
8b970dd490
Merge branch 'main' of github.com:sanic-org/sanic into middleware-revamp 2022-09-19 16:04:38 +03:00
Adam Hopkins
c9be17e8da
Merge conflicts 2022-09-18 23:48:06 +03:00
Adam Hopkins
19f642b364
Add to tests 2022-09-15 18:46:09 +03:00
Adam Hopkins
c4c39cb082
Merge branch 'main' of github.com:sanic-org/sanic into middleware-revamp 2022-09-15 18:33:22 +03:00
Adam Hopkins
c7bac72137
WIP 2022-08-20 22:24:43 +03:00
Adam Hopkins
beb5c62767
Add global middleware ordering 2022-08-17 21:57:07 +03:00
Adam Hopkins
09b59d34fe
Fix typing error 2022-08-17 15:26:59 +03:00
Adam Hopkins
78bc475bb1
Add test case 2022-08-17 15:23:30 +03:00
Adam Hopkins
b59131504b
Merge branch 'main' into middleware-revamp 2022-08-17 14:17:34 +03:00
Adam Hopkins
782e0881e5
Slots to Middleware 2022-08-07 22:38:25 +03:00
Adam Hopkins
c72cbe4326
Begin middleware revamp 2022-08-07 22:31:26 +03:00
13 changed files with 671 additions and 370 deletions

View File

@ -21,7 +21,6 @@ from functools import partial
from inspect import isawaitable from inspect import isawaitable
from os import environ from os import environ
from socket import socket from socket import socket
from traceback import format_exc
from types import SimpleNamespace from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -54,12 +53,7 @@ from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support
from sanic.config import SANIC_PREFIX, Config from sanic.config import SANIC_PREFIX, Config
from sanic.exceptions import ( from sanic.exceptions import BadRequest, SanicException, URLBuildError
BadRequest,
SanicException,
ServerError,
URLBuildError,
)
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.helpers import _default from sanic.helpers import _default
from sanic.http import Stage from sanic.http import Stage
@ -83,7 +77,7 @@ from sanic.models.futures import (
from sanic.models.handler_types import ListenerType, MiddlewareType from sanic.models.handler_types import ListenerType, MiddlewareType
from sanic.models.handler_types import Sanic as SanicVar from sanic.models.handler_types import Sanic as SanicVar
from sanic.request import Request from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream from sanic.response import BaseHTTPResponse
from sanic.router import Router from sanic.router import Router
from sanic.server.websockets.impl import ConnectionClosed from sanic.server.websockets.impl import ConnectionClosed
from sanic.signals import Signal, SignalRouter from sanic.signals import Signal, SignalRouter
@ -709,265 +703,15 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def handle_exception( async def handle_exception(
self, request: Request, exception: BaseException self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov ): # no cov
""" raise NotImplementedError
A handler that catches specific exceptions and outputs a response.
:param request: The current request object
:param exception: The exception that was raised
:raises ServerError: response 500
"""
await self.dispatch(
"http.lifecycle.exception",
inline=True,
context={"request": request, "exception": exception},
)
if (
request.stream is not None
and request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
handler = self.error_handler._lookup(
exception, request.name if request else None
)
if handler:
logger.warning(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"The response from your custom exception handler "
f"{handler.__name__} will not be sent to the client."
"Exception handlers should only be used to generate the "
"exception responses. If you would like to perform any "
"other action on a raised exception, consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
)
return
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
# No middleware results
if not response:
try:
response = self.error_handler.response(request, exception)
if isawaitable(response):
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(request, e)
elif self.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
if response is not None:
try:
request.reset_response()
response = await request.respond(response)
except BaseException:
# Skip response middleware
if request.stream:
request.stream.respond(response)
await response.send(end_stream=True)
raise
else:
if request.stream:
response = request.stream.response
# Marked for cleanup and DRY with handle_request/handle_exception
# when ResponseStream is no longer supporder
if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": resp,
},
)
await response.eof()
else:
raise ServerError(
f"Invalid response type {response!r} (need HTTPResponse)"
)
async def handle_request(self, request: Request): # no cov async def handle_request(self, request: Request): # no cov
"""Take a request from the HTTP Server and return a response object raise NotImplementedError
to be sent back The HTTP Server only expects a response object, so
exception handling must be done here
:param request: HTTP Request object
:return: Nothing
"""
await self.dispatch(
"http.lifecycle.handle",
inline=True,
context={"request": request},
)
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
try:
await self.dispatch(
"http.routing.before",
inline=True,
context={"request": request},
)
# Fetch handler from router
route, handler, kwargs = self.router.get(
request.path,
request.method,
request.headers.getone("host", None),
)
request._match_info = {**kwargs}
request.route = route
await self.dispatch(
"http.routing.after",
inline=True,
context={
"request": request,
"route": route,
"kwargs": kwargs,
"handler": handler,
},
)
if (
request.stream
and request.stream.request_body
and not route.ctx.ignore_body
):
if hasattr(handler, "is_stream"):
# Streaming handler: lift the size limit
request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await request.receive_body()
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=route.name
)
# No middleware results
if not response:
# -------------------------------------------- #
# Execute Handler
# -------------------------------------------- #
if handler is None:
raise ServerError(
(
"'None' was returned while requesting a "
"handler from the router"
)
)
# Run response handler
await self.dispatch(
"http.handler.before",
inline=True,
context={"request": request},
)
response = handler(request, **request.match_info)
if isawaitable(response):
response = await response
await self.dispatch(
"http.handler.after",
inline=True,
context={"request": request},
)
if request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response)
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore
# Marked for cleanup and DRY with handle_request/handle_exception
# when ResponseStream is no longer supporder
if isinstance(response, BaseHTTPResponse):
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": response,
},
)
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
await self.dispatch(
"http.lifecycle.response",
inline=True,
context={
"request": request,
"response": resp,
},
)
await response.eof()
else:
if not hasattr(handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
except CancelledError:
raise
except Exception as e:
# Response Generation Failed
await self.handle_exception(request, e)
async def _websocket_handler( async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs self, handler, request, *args, subprotocols=None, **kwargs
@ -1036,86 +780,72 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def _run_request_middleware( async def _run_request_middleware(
self, request, request_name=None self, request, middleware_collection
): # no cov ): # no cov
# The if improves speed. I don't know why request._request_middleware_started = True
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
# request.request_middleware_started is meant as a stop-gap solution for middleware in middleware_collection:
# until RFC 1630 is adopted await self.dispatch(
if applicable_middleware and not request.request_middleware_started: "http.middleware.before",
request.request_middleware_started = True inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in applicable_middleware: response = middleware(request)
await self.dispatch( if isawaitable(response):
"http.middleware.before", response = await response
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
response = middleware(request) await self.dispatch(
if isawaitable(response): "http.middleware.after",
response = await response inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch( if response:
"http.middleware.after", return response
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
if response:
return response
return None return None
async def _run_response_middleware( async def _run_response_middleware(
self, request, response, request_name=None self, request, response, middleware_collection
): # no cov ): # no cov
named_middleware = self.named_response_middleware.get( for middleware in middleware_collection:
request_name, deque() await self.dispatch(
) "http.middleware.before",
applicable_middleware = self.response_middleware + named_middleware inline=True,
if applicable_middleware: context={
for middleware in applicable_middleware: "request": request,
await self.dispatch( "response": response,
"http.middleware.before", },
inline=True, condition={"attach_to": "response"},
context={ )
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
_response = middleware(request, response) _response = middleware(request, response)
if isawaitable(_response): if isawaitable(_response):
_response = await _response _response = await _response
await self.dispatch( await self.dispatch(
"http.middleware.after", "http.middleware.after",
inline=True, inline=True,
context={ context={
"request": request, "request": request,
"response": _response if _response else response, "response": _response if _response else response,
}, },
condition={"attach_to": "response"}, condition={"attach_to": "response"},
) )
if _response: if _response:
response = _response response = _response
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response) response = request.stream.respond(response)
break break
return response return response
def _build_endpoint_name(self, *parts): def _build_endpoint_name(self, *parts):
@ -1528,6 +1258,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
except FinalizationError as e: except FinalizationError as e:
if not Sanic.test_mode: if not Sanic.test_mode:
raise e raise e
self.finalize_middleware()
def signalize(self, allow_fail_builtin=True): def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin self.signal_router.allow_fail_builtin = allow_fail_builtin

View File

@ -7,6 +7,7 @@ from urllib.parse import quote
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.handlers import RequestManager
from sanic.helpers import _default from sanic.helpers import _default
from sanic.http import Stage from sanic.http import Stage
from sanic.log import logger from sanic.log import logger
@ -230,8 +231,9 @@ class ASGIApp:
""" """
Handle the incoming request. Handle the incoming request.
""" """
manager = RequestManager.create(self.request)
try: try:
self.stage = Stage.HANDLER self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request) await manager.handle()
except Exception as e: except Exception as e:
await self.sanic_app.handle_exception(self.request, e) await manager.error(e)

View File

@ -1,16 +1,317 @@
from __future__ import annotations from __future__ import annotations
from functools import partial
from inspect import isawaitable
from traceback import format_exc
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from sanic_routing import Route
from sanic.errorpages import BaseRenderer, TextRenderer, exception_response from sanic.errorpages import BaseRenderer, TextRenderer, exception_response
from sanic.exceptions import ( from sanic.exceptions import (
HeaderNotFound, HeaderNotFound,
InvalidRangeType, InvalidRangeType,
RangeNotSatisfiable, RangeNotSatisfiable,
SanicException,
ServerError,
) )
from sanic.log import deprecation, error_logger from sanic.http.constants import Stage
from sanic.log import deprecation, error_logger, logger
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
from sanic.response import text from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream, text
from sanic.touchup import TouchUpMeta
class RequestHandler:
def __init__(self, func, request_middleware, response_middleware):
self.func = func.func if isinstance(func, RequestHandler) else func
self.request_middleware = request_middleware
self.response_middleware = response_middleware
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
class RequestManager(metaclass=TouchUpMeta):
__touchup__ = (
"cleanup",
"run_request_middleware",
"run_response_middleware",
)
__slots__ = (
"handler",
"request_middleware_run",
"request_middleware",
"request",
"response_middleware_run",
"response_middleware",
)
request: Request
def __init__(self, request: Request):
self.request_middleware_run = False
self.response_middleware_run = False
self.handler = self._noop
self.set_request(request)
@classmethod
def create(cls, request: Request) -> RequestManager:
return cls(request)
def set_request(self, request: Request):
request._manager = self
self.request = request
self.request_middleware = request.app.request_middleware
self.response_middleware = request.app.response_middleware
async def handle(self):
route = self.resolve_route()
if self.handler is None:
await self.error(
ServerError(
(
"'None' was returned while requesting a "
"handler from the router"
)
)
)
return
if (
self.request.stream
and self.request.stream.request_body
and not route.ctx.ignore_body
):
await self.receive_body()
await self.lifecycle(
partial(self.handler, self.request, **self.request.match_info)
)
async def lifecycle(self, handler, raise_exception: bool = False):
response: Optional[BaseHTTPResponse] = None
if not self.request_middleware_run and self.request_middleware:
response = await self.run(
self.run_request_middleware, raise_exception
)
if not response:
# Run response handler
response = await self.run(handler, raise_exception)
if not self.response_middleware_run and self.response_middleware:
response = await self.run(
partial(self.run_response_middleware, response),
raise_exception,
)
await self.cleanup(response)
async def run(
self, operation, raise_exception: bool = False
) -> Optional[BaseHTTPResponse]:
try:
response = operation()
if isawaitable(response):
response = await response
except Exception as e:
if raise_exception:
raise
response = await self.error(e)
return response
async def error(self, exception: Exception):
error_handler = self.request.app.error_handler
if (
self.request.stream is not None
and self.request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
handler = error_handler._lookup(
exception, self.request.name if self.request else None
)
if handler:
logger.warning(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"The response from your custom exception handler "
f"{handler.__name__} will not be sent to the client."
"Exception handlers should only be used to generate the "
"exception responses. If you would like to perform any "
"other action on a raised exception, consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
)
return
try:
await self.lifecycle(
partial(error_handler.response, self.request, exception), True
)
except Exception as e:
if isinstance(e, SanicException):
response = error_handler.default(self.request, e)
elif self.request.app.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
error_logger.exception(e)
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
return response
return None
async def cleanup(self, response: Optional[BaseHTTPResponse]):
if self.request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if self.request.stream is not None:
response = self.request.stream.response
elif response is not None:
self.request.reset_response()
response = await self.request.respond(response) # type: ignore
elif not hasattr(self.handler, "is_websocket"):
response = self.request.stream.response # type: ignore
if isinstance(response, BaseHTTPResponse):
await self.request.app.dispatch(
"http.lifecycle.response",
inline=True,
context={"request": self.request, "response": response},
)
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
await response(self.request) # type: ignore
await response.eof() # type: ignore
await self.request.app.dispatch(
"http.lifecycle.response",
inline=True,
context={"request": self.request, "response": response},
)
else:
if not hasattr(self.handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
async def receive_body(self):
if hasattr(self.handler, "is_stream"):
# Streaming handler: lift the size limit
self.request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await self.request.receive_body()
async def run_request_middleware(self) -> Optional[BaseHTTPResponse]:
self.request._request_middleware_started = True
self.request_middleware_run = True
for middleware in self.request_middleware:
await self.request.app.dispatch(
"http.middleware.before",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
try:
response = await self.run(partial(middleware, self.request))
except Exception:
error_logger.exception(
"Exception occurred in one of request middleware handlers"
)
raise
await self.request.app.dispatch(
"http.middleware.after",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
if response:
return response
return None
async def run_response_middleware(
self, response: BaseHTTPResponse
) -> BaseHTTPResponse:
self.response_middleware_run = True
for middleware in self.response_middleware:
await self.request.app.dispatch(
"http.middleware.before",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
try:
resp = await self.run(
partial(middleware, self.request, response), True
)
except Exception as e:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
)
await self.error(e)
resp = None
await self.request.app.dispatch(
"http.middleware.after",
inline=True,
context={"request": self.request, "response": None},
condition={"attach_to": "request"},
)
if resp:
return resp
return response
def resolve_route(self) -> Route:
# Fetch handler from router
route, handler, kwargs = self.request.app.router.get(
self.request.path,
self.request.method,
self.request.headers.getone("host", None),
)
self.request._match_info = {**kwargs}
self.request.route = route
self.handler = handler
if handler and handler.request_middleware:
self.request_middleware = handler.request_middleware
if handler and handler.response_middleware:
self.response_middleware = handler.response_middleware
return route
@staticmethod
def _noop(_):
...
class ErrorHandler: class ErrorHandler:

View File

@ -124,7 +124,8 @@ class Http(Stream, metaclass=TouchUpMeta):
self.stage = Stage.HANDLER self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
await self.request.manager.handle()
# Handler finished, response should've been sent # Handler finished, response should've been sent
if self.stage is Stage.HANDLER and not self.upgrade_websocket: if self.stage is Stage.HANDLER and not self.upgrade_websocket:
@ -250,6 +251,7 @@ class Http(Stream, metaclass=TouchUpMeta):
transport=self.protocol.transport, transport=self.protocol.transport,
app=self.protocol.app, app=self.protocol.app,
) )
self.protocol.request_handler.create(request)
self.protocol.request_class._current.set(request) self.protocol.request_class._current.set(request)
await self.dispatch( await self.dispatch(
"http.lifecycle.request", "http.lifecycle.request",
@ -423,12 +425,11 @@ class Http(Stream, metaclass=TouchUpMeta):
# From request and handler states we can respond, otherwise be silent # From request and handler states we can respond, otherwise be silent
if self.stage is Stage.HANDLER: if self.stage is Stage.HANDLER:
app = self.protocol.app
if self.request is None: if self.request is None:
self.create_empty_request() self.create_empty_request()
self.protocol.request_handler.create(self.request)
await app.handle_exception(self.request, exception) await self.request.manager.error(exception)
def create_empty_request(self) -> None: def create_empty_request(self) -> None:
""" """

66
sanic/middleware.py Normal file
View File

@ -0,0 +1,66 @@
from __future__ import annotations
from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Sequence, Union
from sanic.models.handler_types import MiddlewareType
class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()
class Middleware:
_counter = count()
__slots__ = ("func", "priority", "location", "definition")
def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation = MiddlewareLocation.REQUEST,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware._counter)
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __repr__(self) -> str:
name = getattr(self.func, "__name__", str(self.func))
return (
f"{self.__class__.__name__}("
f"func=<function {name}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)
@property
def order(self):
return (self.priority, -self.definition)
@classmethod
def convert(
cls,
*middleware_collections: Sequence[Union[Middleware, MiddlewareType]],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)
@classmethod
def reset_count(cls):
cls._counter = count()

View File

@ -1,11 +1,18 @@
from collections import deque
from functools import partial from functools import partial
from operator import attrgetter
from typing import List from typing import List
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.handlers import RequestHandler
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware from sanic.models.futures import FutureMiddleware
from sanic.router import Router
class MiddlewareMixin(metaclass=SanicMeta): class MiddlewareMixin(metaclass=SanicMeta):
router: Router
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
self._future_middleware: List[FutureMiddleware] = [] self._future_middleware: List[FutureMiddleware] = []
@ -13,7 +20,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
raise NotImplementedError # noqa raise NotImplementedError # noqa
def middleware( def middleware(
self, middleware_or_request, attach_to="request", apply=True self,
middleware_or_request,
attach_to="request",
apply=True,
*,
priority=0
): ):
""" """
Decorate and register middleware to be called before a request Decorate and register middleware to be called before a request
@ -30,6 +42,12 @@ class MiddlewareMixin(metaclass=SanicMeta):
def register_middleware(middleware, attach_to="request"): def register_middleware(middleware, attach_to="request"):
nonlocal apply nonlocal apply
location = (
MiddlewareLocation.REQUEST
if attach_to == "request"
else MiddlewareLocation.RESPONSE
)
middleware = Middleware(middleware, location, priority=priority)
future_middleware = FutureMiddleware(middleware, attach_to) future_middleware = FutureMiddleware(middleware, attach_to)
self._future_middleware.append(future_middleware) self._future_middleware.append(future_middleware)
if apply: if apply:
@ -46,7 +64,7 @@ class MiddlewareMixin(metaclass=SanicMeta):
register_middleware, attach_to=middleware_or_request register_middleware, attach_to=middleware_or_request
) )
def on_request(self, middleware=None): def on_request(self, middleware=None, *, priority=0):
"""Register a middleware to be called before a request is handled. """Register a middleware to be called before a request is handled.
This is the same as *@app.middleware('request')*. This is the same as *@app.middleware('request')*.
@ -54,11 +72,13 @@ class MiddlewareMixin(metaclass=SanicMeta):
:param: middleware: A callable that takes in request. :param: middleware: A callable that takes in request.
""" """
if callable(middleware): if callable(middleware):
return self.middleware(middleware, "request") return self.middleware(middleware, "request", priority=priority)
else: else:
return partial(self.middleware, attach_to="request") return partial(
self.middleware, attach_to="request", priority=priority
)
def on_response(self, middleware=None): def on_response(self, middleware=None, *, priority=0):
"""Register a middleware to be called after a response is created. """Register a middleware to be called after a response is created.
This is the same as *@app.middleware('response')*. This is the same as *@app.middleware('response')*.
@ -67,6 +87,61 @@ class MiddlewareMixin(metaclass=SanicMeta):
A callable that takes in a request and its response. A callable that takes in a request and its response.
""" """
if callable(middleware): if callable(middleware):
return self.middleware(middleware, "response") return self.middleware(middleware, "response", priority=priority)
else: else:
return partial(self.middleware, attach_to="response") return partial(
self.middleware, attach_to="response", priority=priority
)
def finalize_middleware(self):
for route in self.router.routes:
request_middleware = Middleware.convert(
self.request_middleware,
self.named_request_middleware.get(route.name, deque()),
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.handler = RequestHandler(
route.handler,
deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
),
deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
),
)
request_middleware = Middleware.convert(
self.request_middleware,
location=MiddlewareLocation.REQUEST,
)
response_middleware = Middleware.convert(
self.response_middleware,
location=MiddlewareLocation.RESPONSE,
)
self.request_middleware = deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
self.response_middleware = deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
)

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar from contextvars import ContextVar
from functools import partial
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -23,6 +24,7 @@ from sanic.models.http_types import Credentials
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.handlers import RequestManager
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.app import Sanic from sanic.app import Sanic
@ -37,7 +39,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url from httptools import parse_url
from httptools.parser.errors import HttpParserInvalidURLError from httptools.parser.errors import HttpParserInvalidURLError
from sanic.compat import CancelledErrors, Header from sanic.compat import Header
from sanic.constants import ( from sanic.constants import (
CACHEABLE_HTTP_METHODS, CACHEABLE_HTTP_METHODS,
DEFAULT_HTTP_CONTENT_TYPE, DEFAULT_HTTP_CONTENT_TYPE,
@ -56,7 +58,7 @@ from sanic.headers import (
parse_xforwarded, parse_xforwarded,
) )
from sanic.http import Stage from sanic.http import Stage
from sanic.log import error_logger, logger from sanic.log import deprecation, error_logger, logger
from sanic.models.protocol_types import TransportProtocol from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse from sanic.response import BaseHTTPResponse, HTTPResponse
@ -99,10 +101,12 @@ class Request:
"_cookies", "_cookies",
"_id", "_id",
"_ip", "_ip",
"_manager",
"_parsed_url", "_parsed_url",
"_port", "_port",
"_protocol", "_protocol",
"_remote_addr", "_remote_addr",
"_request_middleware_started",
"_scheme", "_scheme",
"_socket", "_socket",
"_stream_id", "_stream_id",
@ -126,7 +130,6 @@ class Request:
"parsed_token", "parsed_token",
"raw_url", "raw_url",
"responded", "responded",
"request_middleware_started",
"route", "route",
"stream", "stream",
"transport", "transport",
@ -178,10 +181,11 @@ class Request:
self.parsed_not_grouped_args: DefaultDict[ self.parsed_not_grouped_args: DefaultDict[
Tuple[bool, bool, str, str], List[Tuple[str, str]] Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list) ] = defaultdict(list)
self.request_middleware_started = False self._request_middleware_started = False
self.responded: bool = False self.responded: bool = False
self.route: Optional[Route] = None self.route: Optional[Route] = None
self.stream: Optional[Stream] = None self.stream: Optional[Stream] = None
self._manager: Optional[RequestManager] = None
self._cookies: Optional[Dict[str, str]] = None self._cookies: Optional[Dict[str, str]] = None
self._match_info: Dict[str, Any] = {} self._match_info: Dict[str, Any] = {}
self._protocol = None self._protocol = None
@ -219,6 +223,16 @@ class Request:
def generate_id(*_): def generate_id(*_):
return uuid.uuid4() return uuid.uuid4()
@property
def request_middleware_started(self):
deprecation(
"Request.request_middleware_started has been deprecated and will"
"be removed. You should set a flag on the request context using"
"either middleware or signals if you need this feature.",
22.3,
)
return self._request_middleware_started
@property @property
def stream_id(self): def stream_id(self):
""" """
@ -233,6 +247,10 @@ class Request:
) )
return self._stream_id return self._stream_id
@property
def manager(self):
return self._manager
def reset_response(self): def reset_response(self):
try: try:
if ( if (
@ -323,15 +341,13 @@ class Request:
if isawaitable(response): if isawaitable(response):
response = await response # type: ignore response = await response # type: ignore
# Run response middleware # Run response middleware
try: if (
response = await self.app._run_response_middleware( self._manager
self, response, request_name=self.name and not self._manager.response_middleware_run
) and self._manager.response_middleware
except CancelledErrors: ):
raise response = await self._manager.run(
except Exception: partial(self._manager.run_response_middleware, response)
error_logger.exception(
"Exception occurred in one of response middleware handlers"
) )
self.responded = True self.responded = True
return response return response

View File

@ -13,6 +13,7 @@ from sanic_routing.route import Route
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format from sanic.errorpages import check_error_format
from sanic.exceptions import MethodNotAllowed, NotFound, SanicException from sanic.exceptions import MethodNotAllowed, NotFound, SanicException
from sanic.handlers import RequestHandler
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
@ -31,9 +32,11 @@ class Router(BaseRouter):
def _get( def _get(
self, path: str, method: str, host: Optional[str] self, path: str, method: str, host: Optional[str]
) -> Tuple[Route, RouteHandler, Dict[str, Any]]: ) -> Tuple[Route, RequestHandler, Dict[str, Any]]:
try: try:
return self.resolve( # We know this will always be RequestHandler, so we can ignore
# typing issue here
return self.resolve( # type: ignore
path=path, path=path,
method=method, method=method,
extra={"host": host} if host else None, extra={"host": host} if host else None,
@ -50,7 +53,7 @@ class Router(BaseRouter):
@lru_cache(maxsize=ROUTER_CACHE_SIZE) @lru_cache(maxsize=ROUTER_CACHE_SIZE)
def get( # type: ignore def get( # type: ignore
self, path: str, method: str, host: Optional[str] self, path: str, method: str, host: Optional[str]
) -> Tuple[Route, RouteHandler, Dict[str, Any]]: ) -> Tuple[Route, RequestHandler, Dict[str, Any]]:
""" """
Retrieve a `Route` object containing the details about how to handle Retrieve a `Route` object containing the details about how to handle
a response for a given request a response for a given request
@ -59,7 +62,7 @@ class Router(BaseRouter):
:type request: Request :type request: Request
:return: details needed for handling the request and returning the :return: details needed for handling the request and returning the
correct response correct response
:rtype: Tuple[ Route, RouteHandler, Dict[str, Any]] :rtype: Tuple[ Route, RequestHandler, Dict[str, Any]]
""" """
return self._get(path, method, host) return self._get(path, method, host)
@ -114,7 +117,7 @@ class Router(BaseRouter):
params = dict( params = dict(
path=uri, path=uri,
handler=handler, handler=RequestHandler(handler, [], []),
methods=frozenset(map(str, methods)) if methods else None, methods=frozenset(map(str, methods)) if methods else None,
name=name, name=name,
strict=strict_slashes, strict=strict_slashes,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from sanic.handlers import RequestManager
from sanic.http.constants import HTTP from sanic.http.constants import HTTP
from sanic.http.http3 import Http3 from sanic.http.http3 import Http3
from sanic.touchup.meta import TouchUpMeta from sanic.touchup.meta import TouchUpMeta
@ -57,7 +58,7 @@ class HttpProtocolMixin:
def _setup(self): def _setup(self):
self.request: Optional[Request] = None self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request self.request_handler = RequestManager
self.error_handler = self.app.error_handler self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT self.response_timeout = self.app.config.RESPONSE_TIMEOUT

View File

@ -84,7 +84,7 @@ ujson = "ujson>=1.35" + env_dependency
uvloop = "uvloop>=0.5.3" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency
types_ujson = "types-ujson" + env_dependency types_ujson = "types-ujson" + env_dependency
requirements = [ requirements = [
"sanic-routing>=22.3.0,<22.6.0", "sanic-routing>=22.8.0",
"httptools>=0.0.10", "httptools>=0.0.10",
uvloop, uvloop,
ujson, ujson,

View File

@ -152,8 +152,11 @@ def test_app_route_raise_value_error(app: Sanic):
def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch): def test_app_handle_request_handler_is_none(app: Sanic, monkeypatch):
mock = Mock()
mock.handler = None
def mockreturn(*args, **kwargs): def mockreturn(*args, **kwargs):
return Mock(), None, {} return mock, None, {}
monkeypatch.setattr(app.router, "get", mockreturn) monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -2,12 +2,22 @@ import logging
from asyncio import CancelledError from asyncio import CancelledError
from itertools import count from itertools import count
from unittest.mock import Mock
import pytest
from sanic.exceptions import NotFound from sanic.exceptions import NotFound
from sanic.middleware import Middleware
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, json, text from sanic.response import HTTPResponse, json, text
@pytest.fixture(autouse=True)
def reset_middleware():
yield
Middleware.reset_count()
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
# GET # GET
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
@ -183,7 +193,7 @@ def test_middleware_response_raise_exception(app, caplog):
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
reqrequest, response = app.test_client.get("/fail") reqrequest, response = app.test_client.get("/fail")
assert response.status == 404 assert response.status == 500
# 404 errors are not logged # 404 errors are not logged
assert ( assert (
"sanic.error", "sanic.error",
@ -318,6 +328,15 @@ def test_middleware_return_response(app):
resp1 = await request.respond() resp1 = await request.respond()
return resp1 return resp1
_, response = app.test_client.get("/") app.test_client.get("/")
assert response_middleware_run_count == 1 assert response_middleware_run_count == 1
assert request_middleware_run_count == 1 assert request_middleware_run_count == 1
def test_middleware_object():
mock = Mock()
middleware = Middleware(mock)
middleware(1, 2, 3, answer=42)
mock.assert_called_once_with(1, 2, 3, answer=42)
assert middleware.order == (0, 0)

View File

@ -0,0 +1,83 @@
from functools import partial
import pytest
from sanic import Sanic
from sanic.response import json
PRIORITY_TEST_CASES = (
([0, 1, 2], [1, 1, 1]),
([0, 1, 2], [1, 1, None]),
([0, 1, 2], [1, None, None]),
([0, 1, 2], [2, 1, None]),
([0, 1, 2], [2, 2, None]),
([0, 1, 2], [3, 2, 1]),
([0, 1, 2], [None, None, None]),
([0, 2, 1], [1, None, 1]),
([0, 2, 1], [2, None, 1]),
([0, 2, 1], [2, None, 2]),
([0, 2, 1], [3, 1, 2]),
([1, 0, 2], [1, 2, None]),
([1, 0, 2], [2, 3, 1]),
([1, 0, 2], [None, 1, None]),
([1, 2, 0], [1, 3, 2]),
([1, 2, 0], [None, 1, 1]),
([1, 2, 0], [None, 2, 1]),
([1, 2, 0], [None, 2, 2]),
([2, 0, 1], [1, None, 2]),
([2, 0, 1], [2, 1, 3]),
([2, 0, 1], [None, None, 1]),
([2, 1, 0], [1, 2, 3]),
([2, 1, 0], [None, 1, 2]),
)
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_request_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_request(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order == expected
@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
)
def test_response_middleware_order_priority(app: Sanic, expected, priorities):
order = []
def add_ident(request, response, ident):
order.append(ident)
@app.get("/")
def handler(request):
return json(None)
for ident, priority in enumerate(priorities):
kwargs = {}
if priority is not None:
kwargs["priority"] = priority
app.on_response(partial(add_ident, ident=ident), **kwargs)
app.test_client.get("/")
assert order[::-1] == expected