This commit is contained in:
Adam Hopkins 2022-08-20 22:24:43 +03:00
parent beb5c62767
commit c7bac72137
No known key found for this signature in database
GPG Key ID: 9F85EE6C807303FB
10 changed files with 367 additions and 40 deletions

View File

@ -60,7 +60,7 @@ from sanic.exceptions import (
ServerError,
URLBuildError,
)
from sanic.handlers import ErrorHandler
from sanic.handlers import ErrorHandler, RequestManager
from sanic.helpers import _default
from sanic.http import Stage
from sanic.log import (
@ -705,6 +705,14 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
raise NotImplementedError
async def _handle_exception(
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
@ -830,6 +838,8 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
:param request: HTTP Request object
:return: Nothing
"""
async def _handle_request(self, request: Request): # no cov
await self.dispatch(
"http.lifecycle.handle",
inline=True,

View File

@ -7,6 +7,7 @@ from urllib.parse import quote
from sanic.compat import Header
from sanic.exceptions import ServerError
from sanic.handlers import RequestManager
from sanic.helpers import _default
from sanic.http import Stage
from sanic.log import logger
@ -230,8 +231,9 @@ class ASGIApp:
"""
Handle the incoming request.
"""
manager = RequestManager.create(self.request)
try:
self.stage = Stage.HANDLER
await self.sanic_app.handle_request(self.request)
await manager.handle()
except Exception as e:
await self.sanic_app.handle_exception(self.request, e)
await manager.error(e)

View File

@ -1,16 +1,319 @@
from __future__ import annotations
from functools import partial
from inspect import isawaitable
from traceback import format_exc
from typing import Dict, List, Optional, Tuple, Type
from sanic_routing import Route
from sanic.errorpages import BaseRenderer, TextRenderer, exception_response
from sanic.exceptions import (
HeaderNotFound,
InvalidRangeType,
RangeNotSatisfiable,
SanicException,
ServerError,
)
from sanic.log import deprecation, error_logger
from sanic.http.constants import Stage
from sanic.log import deprecation, error_logger, logger
from sanic.models.handler_types import RouteHandler
from sanic.response import text
from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream, text
class RequestHandler:
def __init__(self, func, request_middleware, response_middleware):
self.func = func
self.request_middleware = request_middleware
self.response_middleware = response_middleware
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
class RequestManager:
request: Request
def __init__(self, request: Request):
self.request_middleware_run = False
self.response_middleware_run = False
self.handler = self._noop
self.set_request(request)
@classmethod
def create(cls, request: Request) -> RequestManager:
return cls(request)
def set_request(self, request: Request):
request._manager = self
self.request = request
self.request_middleware = request.app.request_middleware
self.response_middleware = request.app.response_middleware
async def handle(self):
route = self.resolve_route()
if self.handler is None:
await self.error(
ServerError(
(
"'None' was returned while requesting a "
"handler from the router"
)
)
)
return
if (
self.request.stream
and self.request.stream.request_body
and not route.ctx.ignore_body
):
await self.receive_body()
await self.lifecycle(
partial(self.handler, self.request, **self.request.match_info)
)
async def lifecycle(self, handler):
response: Optional[BaseHTTPResponse] = None
if not self.request_middleware_run and self.request_middleware:
response = await self.run(self.run_request_middleware)
if not response:
# Run response handler
response = await self.run(handler)
if not self.response_middleware_run and self.response_middleware:
response = await self.run(
partial(self.run_response_middleware, response)
)
await self.cleanup(response)
async def run(self, operation) -> Optional[BaseHTTPResponse]:
try:
response = operation()
if isawaitable(response):
response = await response
except Exception as e:
response = await self.error(e)
return response
async def error(self, exception: Exception):
error_handler = self.request.app.error_handler
if (
self.request.stream is not None
and self.request.stream.stage is not Stage.HANDLER
):
error_logger.exception(exception, exc_info=True)
logger.error(
"The error response will not be sent to the client for "
f'the following exception:"{exception}". A previous response '
"has at least partially been sent."
)
handler = error_handler._lookup(
exception, self.request.name if self.request else None
)
if handler:
logger.warning(
"An error occurred while handling the request after at "
"least some part of the response was sent to the client. "
"The response from your custom exception handler "
f"{handler.__name__} will not be sent to the client."
"Exception handlers should only be used to generate the "
"exception responses. If you would like to perform any "
"other action on a raised exception, consider using a "
"signal handler like "
'`@app.signal("http.lifecycle.exception")`\n'
"For further information, please see the docs: "
"https://sanicframework.org/en/guide/advanced/"
"signals.html",
)
return
try:
await self.lifecycle(
partial(error_handler.response, self.request, exception)
)
except Exception as e:
await self.lifecycle(
partial(error_handler.default, self.request, e)
)
if isinstance(e, SanicException):
response = error_handler.default(self.request, e)
elif self.request.app.debug:
response = HTTPResponse(
(
f"Error while handling error: {e}\n"
f"Stack: {format_exc()}"
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error", status=500
)
return response
return None
async def cleanup(self, response: Optional[BaseHTTPResponse]):
if self.request.responded:
if response is not None:
error_logger.error(
"The response object returned by the route handler "
"will not be sent to client. The request has already "
"been responded to."
)
if self.request.stream is not None:
response = self.request.stream.response
elif response is not None:
self.request.reset_response()
response = await self.request.respond(response) # type: ignore
elif not hasattr(self.handler, "is_websocket"):
response = self.request.stream.response # type: ignore
# Marked for cleanup and DRY with handle_request/handle_exception
# when ResponseStream is no longer supporder
if isinstance(response, BaseHTTPResponse):
# await self.dispatch(
# "http.lifecycle.response",
# inline=True,
# context={
# "request": self.request,
# "response": response,
# },
# )
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
await response(self.request) # type: ignore
# await self.dispatch(
# "http.lifecycle.response",
# inline=True,
# context={
# "request": self.request,
# "response": resp,
# },
# )
await response.eof() # type: ignore
else:
if not hasattr(self.handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
)
async def receive_body(self):
if hasattr(self.handler, "is_stream"):
# Streaming handler: lift the size limit
self.request.stream.request_max_size = float("inf")
else:
# Non-streaming handler: preload body
await self.request.receive_body()
async def run_request_middleware(self) -> Optional[BaseHTTPResponse]:
self.request._request_middleware_started = True
self.request_middleware_run = True
for middleware in self.request_middleware:
# await self.dispatch(
# "http.middleware.before",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
response = await self.run(partial(middleware, self.request))
# await self.dispatch(
# "http.middleware.after",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
if response:
return response
return None
async def run_response_middleware(
self, response: BaseHTTPResponse
) -> BaseHTTPResponse:
self.response_middleware_run = True
for middleware in self.response_middleware:
# await self.dispatch(
# "http.middleware.before",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
resp = await self.run(partial(middleware, self.request, response))
# await self.dispatch(
# "http.middleware.after",
# inline=True,
# context={
# "request": request,
# "response": None,
# },
# condition={"attach_to": "request"},
# )
if resp:
return resp
return response
# try:
# middleware = (
# self.route and self.route.extra.response_middleware
# ) or self.app.response_middleware
# if middleware:
# response = await self.app._run_response_middleware(
# self, response, middleware
# )
# except CancelledErrors:
# raise
# except Exception:
# error_logger.exception(
# "Exception occurred in one of response middleware handlers"
# )
# return None
def resolve_route(self) -> Route:
# Fetch handler from router
route, handler, kwargs = self.request.app.router.get(
self.request.path,
self.request.method,
self.request.headers.getone("host", None),
)
self.request._match_info = {**kwargs}
self.request.route = route
self.handler = handler
if route.handler and route.handler.request_middleware:
self.request_middleware = route.handler.request_middleware
if route.handler and route.handler.response_middleware:
self.response_middleware = route.handler.response_middleware
return route
@staticmethod
def _noop(_):
...
class ErrorHandler:

View File

@ -124,7 +124,8 @@ class Http(Stream, metaclass=TouchUpMeta):
self.stage = Stage.HANDLER
self.request.conn_info = self.protocol.conn_info
await self.protocol.request_handler(self.request)
await self.request.manager.handle()
# Handler finished, response should've been sent
if self.stage is Stage.HANDLER and not self.upgrade_websocket:
@ -246,6 +247,7 @@ class Http(Stream, metaclass=TouchUpMeta):
transport=self.protocol.transport,
app=self.protocol.app,
)
self.protocol.request_handler.create(request)
self.protocol.request_class._current.set(request)
await self.dispatch(
"http.lifecycle.request",
@ -419,12 +421,11 @@ class Http(Stream, metaclass=TouchUpMeta):
# From request and handler states we can respond, otherwise be silent
if self.stage is Stage.HANDLER:
app = self.protocol.app
if self.request is None:
self.create_empty_request()
self.protocol.request_handler.create(self.request)
await app.handle_exception(self.request, exception)
await self.request.manager.error(exception)
def create_empty_request(self) -> None:
"""

View File

@ -33,9 +33,10 @@ class Middleware:
return self.func(*args, **kwargs)
def __repr__(self) -> str:
name = getattr(self.func, "__name__", str(self.func))
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"func=<function {name}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)

View File

@ -4,6 +4,7 @@ from operator import attrgetter
from typing import List
from sanic.base.meta import SanicMeta
from sanic.handlers import RequestHandler
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.models.futures import FutureMiddleware
from sanic.router import Router
@ -104,19 +105,23 @@ class MiddlewareMixin(metaclass=SanicMeta):
self.named_response_middleware.get(route.name, deque()),
location=MiddlewareLocation.RESPONSE,
)
route.extra.request_middleware = deque(
route.handler = RequestHandler(
route.handler,
deque(
sorted(
request_middleware,
key=attrgetter("order"),
reverse=True,
)
)
route.extra.response_middleware = deque(
),
deque(
sorted(
response_middleware,
key=attrgetter("order"),
reverse=True,
)[::-1]
),
)
request_middleware = Middleware.convert(
self.request_middleware,

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from contextvars import ContextVar
from functools import partial
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
@ -23,6 +24,7 @@ from sanic.models.http_types import Credentials
if TYPE_CHECKING:
from sanic.handlers import RequestManager
from sanic.server import ConnInfo
from sanic.app import Sanic
@ -37,7 +39,7 @@ from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse
from httptools import parse_url
from httptools.parser.errors import HttpParserInvalidURLError
from sanic.compat import CancelledErrors, Header
from sanic.compat import Header
from sanic.constants import (
CACHEABLE_HTTP_METHODS,
DEFAULT_HTTP_CONTENT_TYPE,
@ -99,6 +101,7 @@ class Request:
"_cookies",
"_id",
"_ip",
"_manager",
"_parsed_url",
"_port",
"_protocol",
@ -182,6 +185,7 @@ class Request:
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
self._manager: Optional[RequestManager] = None
self._cookies: Optional[Dict[str, str]] = None
self._match_info: Dict[str, Any] = {}
self._protocol = None
@ -243,6 +247,10 @@ class Request:
)
return self._stream_id
@property
def manager(self):
return self._manager
def reset_response(self):
try:
if (
@ -333,19 +341,13 @@ class Request:
if isawaitable(response):
response = await response # type: ignore
# Run response middleware
try:
middleware = (
self.route and self.route.extra.response_middleware
) or self.app.response_middleware
if middleware:
response = await self.app._run_response_middleware(
self, response, middleware
)
except CancelledErrors:
raise
except Exception:
error_logger.exception(
"Exception occurred in one of response middleware handlers"
if (
self._manager
and not self._manager.response_middleware_run
and self._manager.response_middleware
):
response = await self._manager.run(
partial(self._manager.run_response_middleware, response)
)
self.responded = True
return response

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from sanic.handlers import RequestManager
from sanic.http.constants import HTTP
from sanic.http.http3 import Http3
from sanic.touchup.meta import TouchUpMeta
@ -53,7 +54,7 @@ class HttpProtocolMixin:
def _setup(self):
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.request_handler = RequestManager
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT

View File

@ -150,8 +150,11 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch):
mock = Mock()
mock.handler = None
def mockreturn(*args, **kwargs):
return Mock(), None, {}
return mock, None, {}
# Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -1,4 +1,3 @@
from httpx import AsyncByteStream
from sanic_testing.reusable import ReusableClient
from sanic.response import json, text