This commit is contained in:
Adam Hopkins 2022-08-20 22:24:43 +03:00
parent beb5c62767
commit c7bac72137
No known key found for this signature in database
GPG Key ID: 9F85EE6C807303FB
10 changed files with 367 additions and 40 deletions

View File

@ -60,7 +60,7 @@ from sanic.exceptions import (
ServerError, ServerError,
URLBuildError, URLBuildError,
) )
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler, RequestManager
from sanic.helpers import _default from sanic.helpers import _default
from sanic.http import Stage from sanic.http import Stage
from sanic.log import ( from sanic.log import (
@ -705,6 +705,14 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
request: Request, request: Request,
exception: BaseException, exception: BaseException,
run_middleware: bool = True, run_middleware: bool = True,
): # no cov
raise NotImplementedError
async def _handle_exception(
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov ): # no cov
""" """
A handler that catches specific exceptions and outputs a response. A handler that catches specific exceptions and outputs a response.
@ -830,6 +838,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
:param request: HTTP Request object :param request: HTTP Request object
:return: Nothing :return: Nothing
""" """
async def _handle_request(self, request: Request): # no cov
await self.dispatch( await self.dispatch(
"http.lifecycle.handle", "http.lifecycle.handle",
inline=True, inline=True,

View File

@ -7,6 +7,7 @@ from urllib.parse import quote
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.handlers import RequestManager
from sanic.helpers import _default from sanic.helpers import _default
from sanic.http import Stage from sanic.http import Stage
from sanic.log import logger from sanic.log import logger
@ -230,8 +231,9 @@ class ASGIApp:
""" """
Handle the incoming request. Handle the incoming request.
""" """
manager = RequestManager.create(self.request)
try: try:
self.stage = Stage.HANDLER self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request) await manager.handle()
except Exception as e: except Exception as e:
await self.sanic_app.handle_exception(self.request, e) await manager.error(e)

View File

@ -1,16 +1,319 @@
from __future__ import annotations from __future__ import annotations
from functools import partial
from inspect import isawaitable
from traceback import format_exc
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from sanic_routing import Route
from sanic.errorpages import BaseRenderer, TextRenderer, exception_response from sanic.errorpages import BaseRenderer, TextRenderer, exception_response
from sanic.exceptions import ( from sanic.exceptions import (
HeaderNotFound, HeaderNotFound,
InvalidRangeType, InvalidRangeType,
RangeNotSatisfiable, RangeNotSatisfiable,
SanicException,
ServerError,
) )
from sanic.log import deprecation, error_logger from sanic.http.constants import Stage
from sanic.log import deprecation, error_logger, logger
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
from sanic.response import text from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream, text
class RequestHandler:
def __init__(self, func, request_middleware, response_middleware):
self.func = func
self.request_middleware = request_middleware
self.response_middleware = response_middleware
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
class RequestManager:
request: Request
def __init__(self, request: Request):
self.request_middleware_run = False
self.response_middleware_run = False
self.handler = self._noop
self.set_request(request)
@classmethod
def create(cls, request: Request) -> RequestManager:
return cls(request)
def set_request(self, request: Request):
request._manager = self
self.request = request
self.request_middleware = request.app.request_middleware
self.response_middleware = request.app.response_middleware
async def handle(self):
route = self.resolve_route()
if self.handler is None:
await self.error(
ServerError(
(
"'None' was returned while requesting a "
"handler from the router"
)
)
)
return
if (
self.request.stream
and self.request.stream.request_body
and not route.ctx.ignore_body
):
await self.receive_body()
await self.lifecycle(
partial(self.handler, self.request, **self.request.match_info)
)
async def lifecycle(self, handler):
response: Optional[BaseHTTPResponse] = None
if not self.request_middleware_run and self.request_middleware:
response = await self.run(self.run_request_middleware)
if not response:
# Run response handler
response = await self.run(handler)
if not self.response_middleware_run and self.response_middleware:
response = await self.run(
partial(self.run_response_middleware, response)
)
await self.cleanup(response)
async def run(self, operation) -> Optional[BaseHTTPResponse]:
try:
response = operation()
if isawaitable(response):
response = await response
except Exception as e:
response = await self.error(e)
return response
async def error(self, exception: Exception):
error_handler = self.request.app.error_handler
if (
self.request.stream is not None
and self.request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
handler = error_handler._lookup(
exception, self.request.name if self.request else None
)
if handler:
logger.warning(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"The response from your custom exception handler "
f"{handler.__name__} will not be sent to the client."
"Exception handlers should only be used to generate the "
"exception responses. If you would like to perform any "
"other action on a raised exception, consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
)
return
try:
await self.lifecycle(
partial(error_handler.response, self.request, exception)
)
except Exception as e:
await self.lifecycle(
partial(error_handler.default, self.request, e)
)
if isinstance(e, SanicException):
response = error_handler.default(self.request, e)
elif self.request.app.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
return response
return None
async def cleanup(self, response: Optional[BaseHTTPResponse]):
if self.request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if self.request.stream is not None:
response = self.request.stream.response
elif response is not None:
self.request.reset_response()
response = await self.request.respond(response) # type: ignore
elif not hasattr(self.handler, "is_websocket"):
response = self.request.stream.response # type: ignore
# Marked for cleanup and DRY with handle_request/handle_exception
# when ResponseStream is no longer supporder
if isinstance(response, BaseHTTPResponse):
# await self.dispatch(
# "http.lifecycle.response",
# inline=True,
# context={
# "request": self.request,
# "response": response,
# },
# )
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
await response(self.request) # type: ignore
# await self.dispatch(
# "http.lifecycle.response",
# inline=True,
# context={
# "request": self.request,
# "response": resp,
# },
# )
await response.eof() # type: ignore
else:
if not hasattr(self.handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
async def receive_body(self):
if hasattr(self.handler, "is_stream"):
# Streaming handler: lift the size limit
self.request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await self.request.receive_body()
async def run_request_middleware(self) -> Optional[BaseHTTPResponse]:
self.request._request_middleware_started = True
self.request_middleware_run = True
for middleware in self.request_middleware:
# await self.dispatch(
# "http.middleware.before",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
response = await self.run(partial(middleware, self.request))
# await self.dispatch(
# "http.middleware.after",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
if response:
return response
return None
async def run_response_middleware(
self, response: BaseHTTPResponse
) -> BaseHTTPResponse:
self.response_middleware_run = True
for middleware in self.response_middleware:
# await self.dispatch(
# "http.middleware.before",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
resp = await self.run(partial(middleware, self.request, response))
# await self.dispatch(
# "http.middleware.after",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
if resp:
return resp
return response
# try:
# middleware = (
# self.route and self.route.extra.response_middleware
# ) or self.app.response_middleware
# if middleware:
# response = await self.app._run_response_middleware(
# self, response, middleware
# )
# except CancelledErrors:
# raise
# except Exception:
# error_logger.exception(
# "Exception occurred in one of response middleware handlers"
# )
# return None
def resolve_route(self) -> Route:
# Fetch handler from router
route, handler, kwargs = self.request.app.router.get(
self.request.path,
self.request.method,
self.request.headers.getone("host", None),
)
self.request._match_info = {**kwargs}
self.request.route = route
self.handler = handler
if route.handler and route.handler.request_middleware:
self.request_middleware = route.handler.request_middleware
if route.handler and route.handler.response_middleware:
self.response_middleware = route.handler.response_middleware
return route
@staticmethod
def _noop(_):
...
class ErrorHandler: class ErrorHandler:

View File

@ -124,7 +124,8 @@ class Http(Stream, metaclass=TouchUpMeta):
self.stage = Stage.HANDLER self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
await self.request.manager.handle()
# Handler finished, response should've been sent # Handler finished, response should've been sent
if self.stage is Stage.HANDLER and not self.upgrade_websocket: if self.stage is Stage.HANDLER and not self.upgrade_websocket:
@ -246,6 +247,7 @@ class Http(Stream, metaclass=TouchUpMeta):
transport=self.protocol.transport, transport=self.protocol.transport,
app=self.protocol.app, app=self.protocol.app,
) )
self.protocol.request_handler.create(request)
self.protocol.request_class._current.set(request) self.protocol.request_class._current.set(request)
await self.dispatch( await self.dispatch(
"http.lifecycle.request", "http.lifecycle.request",
@ -419,12 +421,11 @@ class Http(Stream, metaclass=TouchUpMeta):
# From request and handler states we can respond, otherwise be silent # From request and handler states we can respond, otherwise be silent
if self.stage is Stage.HANDLER: if self.stage is Stage.HANDLER:
app = self.protocol.app
if self.request is None: if self.request is None:
self.create_empty_request() self.create_empty_request()
self.protocol.request_handler.create(self.request)
await app.handle_exception(self.request, exception) await self.request.manager.error(exception)
def create_empty_request(self) -> None: def create_empty_request(self) -> None:
""" """

View File

@ -33,9 +33,10 @@ class Middleware:
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
def __repr__(self) -> str: def __repr__(self) -> str:
name = getattr(self.func, "__name__", str(self.func))
return ( return (
f"{self.__class__.__name__}(" f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, " f"func=<function {name}>, "
f"priority={self.priority}, " f"priority={self.priority}, "
f"location={self.location.name})" f"location={self.location.name})"
) )

View File

@ -4,6 +4,7 @@ from operator import attrgetter
from typing import List from typing import List
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.handlers import RequestHandler
from sanic.middleware import Middleware, MiddlewareLocation from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware from sanic.models.futures import FutureMiddleware
from sanic.router import Router from sanic.router import Router
@ -104,19 +105,23 @@ class MiddlewareMixin(metaclass=SanicMeta):
self.named_response_middleware.get(route.name, deque()), self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE, location=MiddlewareLocation.RESPONSE,
) )
route.extra.request_middleware = deque(
sorted( route.handler = RequestHandler(
request_middleware, route.handler,
key=attrgetter("order"), deque(
reverse=True, sorted(
) request_middleware,
) key=attrgetter("order"),
route.extra.response_middleware = deque( reverse=True,
sorted( )
response_middleware, ),
key=attrgetter("order"), deque(
reverse=True, sorted(
)[::-1] response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
),
) )
request_middleware = Middleware.convert( request_middleware = Middleware.convert(
self.request_middleware, self.request_middleware,

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar from contextvars import ContextVar
from functools import partial
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -23,6 +24,7 @@ from sanic.models.http_types import Credentials
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.handlers import RequestManager
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.app import Sanic from sanic.app import Sanic
@ -37,7 +39,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url from httptools import parse_url
from httptools.parser.errors import HttpParserInvalidURLError from httptools.parser.errors import HttpParserInvalidURLError
from sanic.compat import CancelledErrors, Header from sanic.compat import Header
from sanic.constants import ( from sanic.constants import (
CACHEABLE_HTTP_METHODS, CACHEABLE_HTTP_METHODS,
DEFAULT_HTTP_CONTENT_TYPE, DEFAULT_HTTP_CONTENT_TYPE,
@ -99,6 +101,7 @@ class Request:
"_cookies", "_cookies",
"_id", "_id",
"_ip", "_ip",
"_manager",
"_parsed_url", "_parsed_url",
"_port", "_port",
"_protocol", "_protocol",
@ -182,6 +185,7 @@ class Request:
self.responded: bool = False self.responded: bool = False
self.route: Optional[Route] = None self.route: Optional[Route] = None
self.stream: Optional[Stream] = None self.stream: Optional[Stream] = None
self._manager: Optional[RequestManager] = None
self._cookies: Optional[Dict[str, str]] = None self._cookies: Optional[Dict[str, str]] = None
self._match_info: Dict[str, Any] = {} self._match_info: Dict[str, Any] = {}
self._protocol = None self._protocol = None
@ -243,6 +247,10 @@ class Request:
) )
return self._stream_id return self._stream_id
@property
def manager(self):
return self._manager
def reset_response(self): def reset_response(self):
try: try:
if ( if (
@ -333,19 +341,13 @@ class Request:
if isawaitable(response): if isawaitable(response):
response = await response # type: ignore response = await response # type: ignore
# Run response middleware # Run response middleware
try: if (
middleware = ( self._manager
self.route and self.route.extra.response_middleware and not self._manager.response_middleware_run
) or self.app.response_middleware and self._manager.response_middleware
if middleware: ):
response = await self.app._run_response_middleware( response = await self._manager.run(
self, response, middleware partial(self._manager.run_response_middleware, response)
)
except CancelledErrors:
raise
except Exception:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
) )
self.responded = True self.responded = True
return response return response

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from sanic.handlers import RequestManager
from sanic.http.constants import HTTP from sanic.http.constants import HTTP
from sanic.http.http3 import Http3 from sanic.http.http3 import Http3
from sanic.touchup.meta import TouchUpMeta from sanic.touchup.meta import TouchUpMeta
@ -53,7 +54,7 @@ class HttpProtocolMixin:
def _setup(self): def _setup(self):
self.request: Optional[Request] = None self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request self.request_handler = RequestManager
self.error_handler = self.app.error_handler self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT self.response_timeout = self.app.config.RESPONSE_TIMEOUT

View File

@ -150,8 +150,11 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch): def test_app_handle_request_handler_is_none(app, monkeypatch):
mock = Mock()
mock.handler = None
def mockreturn(*args, **kwargs): def mockreturn(*args, **kwargs):
return Mock(), None, {} return mock, None, {}
# Not sure how to make app.router.get() return None, so use mock here. # Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn) monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -1,4 +1,3 @@
from httpx import AsyncByteStream
from sanic_testing.reusable import ReusableClient from sanic_testing.reusable import ReusableClient
from sanic.response import json, text from sanic.response import json, text