Compare commits
	
		
			1 Commits
		
	
	
		
			middleware
			...
			v21.9.3
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 8673021ad4 | 
| @@ -1 +1 @@ | ||||
| __version__ = "21.9.0" | ||||
| __version__ = "21.9.3" | ||||
|   | ||||
							
								
								
									
										19
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -173,18 +173,18 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): | ||||
|         self.asgi = False | ||||
|         self.auto_reload = False | ||||
|         self.blueprints: Dict[str, Blueprint] = {} | ||||
|         self.config = config or Config( | ||||
|             load_env=load_env, env_prefix=env_prefix | ||||
|         self.config: Config = config or Config( | ||||
|             load_env=load_env, | ||||
|             env_prefix=env_prefix, | ||||
|             app=self, | ||||
|         ) | ||||
|         self.configure_logging = configure_logging | ||||
|         self.ctx = ctx or SimpleNamespace() | ||||
|         self.configure_logging: bool = configure_logging | ||||
|         self.ctx: Any = ctx or SimpleNamespace() | ||||
|         self.debug = None | ||||
|         self.error_handler = error_handler or ErrorHandler( | ||||
|             fallback=self.config.FALLBACK_ERROR_FORMAT, | ||||
|         ) | ||||
|         self.error_handler: ErrorHandler = error_handler or ErrorHandler() | ||||
|         self.is_running = False | ||||
|         self.is_stopping = False | ||||
|         self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) | ||||
|         self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) | ||||
|         self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} | ||||
|         self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} | ||||
|         self.reload_dirs: Set[Path] = set() | ||||
| @@ -1474,6 +1474,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): | ||||
|     async def _startup(self): | ||||
|         self.signalize() | ||||
|         self.finalize() | ||||
|         ErrorHandler.finalize( | ||||
|             self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT | ||||
|         ) | ||||
|         TouchUp.run(self) | ||||
|  | ||||
|     async def _server_event( | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from inspect import isclass | ||||
| from os import environ | ||||
| from pathlib import Path | ||||
| from typing import Any, Dict, Optional, Union | ||||
| from typing import TYPE_CHECKING, Any, Dict, Optional, Union | ||||
| from warnings import warn | ||||
|  | ||||
| from sanic.errorpages import check_error_format | ||||
| @@ -10,6 +12,10 @@ from sanic.http import Http | ||||
| from .utils import load_module_from_file_location, str_to_bool | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING:  # no cov | ||||
|     from sanic import Sanic | ||||
|  | ||||
|  | ||||
| SANIC_PREFIX = "SANIC_" | ||||
| BASE_LOGO = """ | ||||
|  | ||||
| @@ -71,11 +77,14 @@ class Config(dict): | ||||
|         load_env: Optional[Union[bool, str]] = True, | ||||
|         env_prefix: Optional[str] = SANIC_PREFIX, | ||||
|         keep_alive: Optional[bool] = None, | ||||
|         *, | ||||
|         app: Optional[Sanic] = None, | ||||
|     ): | ||||
|         defaults = defaults or {} | ||||
|         super().__init__({**DEFAULT_CONFIG, **defaults}) | ||||
|  | ||||
|         self.LOGO = BASE_LOGO | ||||
|         self._app = app | ||||
|         self._LOGO = BASE_LOGO | ||||
|  | ||||
|         if keep_alive is not None: | ||||
|             self.KEEP_ALIVE = keep_alive | ||||
| @@ -97,6 +106,7 @@ class Config(dict): | ||||
|  | ||||
|         self._configure_header_size() | ||||
|         self._check_error_format() | ||||
|         self._init = True | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         try: | ||||
| @@ -104,16 +114,51 @@ class Config(dict): | ||||
|         except KeyError as ke: | ||||
|             raise AttributeError(f"Config has no '{ke.args[0]}'") | ||||
|  | ||||
|     def __setattr__(self, attr, value): | ||||
|         self[attr] = value | ||||
|         if attr in ( | ||||
|             "REQUEST_MAX_HEADER_SIZE", | ||||
|             "REQUEST_BUFFER_SIZE", | ||||
|             "REQUEST_MAX_SIZE", | ||||
|         ): | ||||
|             self._configure_header_size() | ||||
|         elif attr == "FALLBACK_ERROR_FORMAT": | ||||
|             self._check_error_format() | ||||
|     def __setattr__(self, attr, value) -> None: | ||||
|         self.update({attr: value}) | ||||
|  | ||||
|     def __setitem__(self, attr, value) -> None: | ||||
|         self.update({attr: value}) | ||||
|  | ||||
|     def update(self, *other, **kwargs) -> None: | ||||
|         other_mapping = {k: v for item in other for k, v in dict(item).items()} | ||||
|         super().update(*other, **kwargs) | ||||
|         for attr, value in {**other_mapping, **kwargs}.items(): | ||||
|             self._post_set(attr, value) | ||||
|  | ||||
|     def _post_set(self, attr, value) -> None: | ||||
|         if self.get("_init"): | ||||
|             if attr in ( | ||||
|                 "REQUEST_MAX_HEADER_SIZE", | ||||
|                 "REQUEST_BUFFER_SIZE", | ||||
|                 "REQUEST_MAX_SIZE", | ||||
|             ): | ||||
|                 self._configure_header_size() | ||||
|             elif attr == "FALLBACK_ERROR_FORMAT": | ||||
|                 self._check_error_format() | ||||
|                 if self.app and value != self.app.error_handler.fallback: | ||||
|                     if self.app.error_handler.fallback != "auto": | ||||
|                         warn( | ||||
|                             "Overriding non-default ErrorHandler fallback " | ||||
|                             "value. Changing from " | ||||
|                             f"{self.app.error_handler.fallback} to {value}." | ||||
|                         ) | ||||
|                     self.app.error_handler.fallback = value | ||||
|             elif attr == "LOGO": | ||||
|                 self._LOGO = value | ||||
|                 warn( | ||||
|                     "Setting the config.LOGO is deprecated and will no longer " | ||||
|                     "be supported starting in v22.6.", | ||||
|                     DeprecationWarning, | ||||
|                 ) | ||||
|  | ||||
|     @property | ||||
|     def app(self): | ||||
|         return self._app | ||||
|  | ||||
|     @property | ||||
|     def LOGO(self): | ||||
|         return self._LOGO | ||||
|  | ||||
|     def _configure_header_size(self): | ||||
|         Http.set_header_max_size( | ||||
|   | ||||
| @@ -393,7 +393,8 @@ def exception_response( | ||||
|             # from the route | ||||
|             if request.route: | ||||
|                 try: | ||||
|                     render_format = request.route.ctx.error_format | ||||
|                     if request.route.ctx.error_format: | ||||
|                         render_format = request.route.ctx.error_format | ||||
|                 except AttributeError: | ||||
|                     ... | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| from inspect import signature | ||||
| from typing import Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response | ||||
| @@ -25,7 +26,9 @@ class ErrorHandler: | ||||
|     """ | ||||
|  | ||||
|     # Beginning in v22.3, the base renderer will be TextRenderer | ||||
|     def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): | ||||
|     def __init__( | ||||
|         self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer | ||||
|     ): | ||||
|         self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] | ||||
|         self.cached_handlers: Dict[ | ||||
|             Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] | ||||
| @@ -34,6 +37,41 @@ class ErrorHandler: | ||||
|         self.fallback = fallback | ||||
|         self.base = base | ||||
|  | ||||
|     @classmethod | ||||
|     def finalize(cls, error_handler, fallback: Optional[str] = None): | ||||
|         if ( | ||||
|             fallback | ||||
|             and fallback != "auto" | ||||
|             and error_handler.fallback == "auto" | ||||
|         ): | ||||
|             error_handler.fallback = fallback | ||||
|  | ||||
|         if not isinstance(error_handler, cls): | ||||
|             error_logger.warning( | ||||
|                 f"Error handler is non-conforming: {type(error_handler)}" | ||||
|             ) | ||||
|  | ||||
|         sig = signature(error_handler.lookup) | ||||
|         if len(sig.parameters) == 1: | ||||
|             error_logger.warning( | ||||
|                 DeprecationWarning( | ||||
|                     "You are using a deprecated error handler. The lookup " | ||||
|                     "method should accept two positional parameters: " | ||||
|                     "(exception, route_name: Optional[str]). " | ||||
|                     "Until you upgrade your ErrorHandler.lookup, Blueprint " | ||||
|                     "specific exceptions will not work properly. Beginning " | ||||
|                     "in v22.3, the legacy style lookup method will not " | ||||
|                     "work at all." | ||||
|                 ), | ||||
|             ) | ||||
|             error_handler._lookup = error_handler._legacy_lookup | ||||
|  | ||||
|     def _full_lookup(self, exception, route_name: Optional[str] = None): | ||||
|         return self.lookup(exception, route_name) | ||||
|  | ||||
|     def _legacy_lookup(self, exception, route_name: Optional[str] = None): | ||||
|         return self.lookup(exception) | ||||
|  | ||||
|     def add(self, exception, handler, route_names: Optional[List[str]] = None): | ||||
|         """ | ||||
|         Add a new exception handler to an already existing handler object. | ||||
| @@ -56,7 +94,7 @@ class ErrorHandler: | ||||
|         else: | ||||
|             self.cached_handlers[(exception, None)] = handler | ||||
|  | ||||
|     def lookup(self, exception, route_name: Optional[str]): | ||||
|     def lookup(self, exception, route_name: Optional[str] = None): | ||||
|         """ | ||||
|         Lookup the existing instance of :class:`ErrorHandler` and fetch the | ||||
|         registered handler for a specific type of exception. | ||||
| @@ -94,6 +132,8 @@ class ErrorHandler: | ||||
|         handler = None | ||||
|         return handler | ||||
|  | ||||
|     _lookup = _full_lookup | ||||
|  | ||||
|     def response(self, request, exception): | ||||
|         """Fetches and executes an exception handler and returns a response | ||||
|         object | ||||
| @@ -109,7 +149,7 @@ class ErrorHandler: | ||||
|             or registered handler for that type of exception. | ||||
|         """ | ||||
|         route_name = request.name if request else None | ||||
|         handler = self.lookup(exception, route_name) | ||||
|         handler = self._lookup(exception, route_name) | ||||
|         response = None | ||||
|         try: | ||||
|             if handler: | ||||
|   | ||||
| @@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta): | ||||
|         self.keep_alive = True | ||||
|         self.stage: Stage = Stage.IDLE | ||||
|         self.dispatch = self.protocol.app.dispatch | ||||
|         self.init_for_request() | ||||
|  | ||||
|     def init_for_request(self): | ||||
|         """Init/reset all per-request variables.""" | ||||
| @@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta): | ||||
|         """ | ||||
|         HTTP 1.1 connection handler | ||||
|         """ | ||||
|         while True:  # As long as connection stays keep-alive | ||||
|         # Handle requests while the connection stays reusable | ||||
|         while self.keep_alive and self.stage is Stage.IDLE: | ||||
|             self.init_for_request() | ||||
|             # Wait for incoming bytes (in IDLE stage) | ||||
|             if not self.recv_buffer: | ||||
|                 await self._receive_more() | ||||
|             self.stage = Stage.REQUEST | ||||
|             try: | ||||
|                 # Receive and handle a request | ||||
|                 self.stage = Stage.REQUEST | ||||
|                 self.response_func = self.http1_response_header | ||||
|  | ||||
|                 await self.http1_request_header() | ||||
|  | ||||
|                 self.stage = Stage.HANDLER | ||||
|                 self.request.conn_info = self.protocol.conn_info | ||||
|                 await self.protocol.request_handler(self.request) | ||||
|  | ||||
| @@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta): | ||||
|                 if self.response: | ||||
|                     self.response.stream = None | ||||
|  | ||||
|             # Exit and disconnect if no more requests can be taken | ||||
|             if self.stage is not Stage.IDLE or not self.keep_alive: | ||||
|                 break | ||||
|  | ||||
|             self.init_for_request() | ||||
|  | ||||
|             # Wait for the next request | ||||
|             if not self.recv_buffer: | ||||
|                 await self._receive_more() | ||||
|  | ||||
|     async def http1_request_header(self):  # no cov | ||||
|         """ | ||||
|         Receive and parse request header into self.request. | ||||
| @@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta): | ||||
|  | ||||
|         # Remove header and its trailing CRLF | ||||
|         del buf[: pos + 4] | ||||
|         self.stage = Stage.HANDLER | ||||
|         self.request, request.stream = request, self | ||||
|         self.protocol.state["requests_count"] += 1 | ||||
|  | ||||
|   | ||||
| @@ -918,7 +918,7 @@ class RouteMixin: | ||||
|  | ||||
|         return route | ||||
|  | ||||
|     def _determine_error_format(self, handler) -> str: | ||||
|     def _determine_error_format(self, handler) -> Optional[str]: | ||||
|         if not isinstance(handler, CompositionView): | ||||
|             try: | ||||
|                 src = dedent(getsource(handler)) | ||||
| @@ -930,7 +930,7 @@ class RouteMixin: | ||||
|             except (OSError, TypeError): | ||||
|                 ... | ||||
|  | ||||
|         return "auto" | ||||
|         return None | ||||
|  | ||||
|     def _get_response_types(self, node): | ||||
|         types = set() | ||||
|   | ||||
| @@ -139,11 +139,10 @@ class Router(BaseRouter): | ||||
|             route.ctx.stream = stream | ||||
|             route.ctx.hosts = hosts | ||||
|             route.ctx.static = static | ||||
|             route.ctx.error_format = ( | ||||
|                 error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT | ||||
|             ) | ||||
|             route.ctx.error_format = error_format | ||||
|  | ||||
|             check_error_format(route.ctx.error_format) | ||||
|             if error_format: | ||||
|                 check_error_format(route.ctx.error_format) | ||||
|  | ||||
|             routes.append(route) | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,7 @@ from os import environ | ||||
| from pathlib import Path | ||||
| from tempfile import TemporaryDirectory | ||||
| from textwrap import dedent | ||||
| from unittest.mock import Mock | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| @@ -350,3 +351,40 @@ def test_update_from_lowercase_key(app): | ||||
|     d = {"test_setting_value": 1} | ||||
|     app.update_config(d) | ||||
|     assert "test_setting_value" not in app.config | ||||
|  | ||||
|  | ||||
| def test_deprecation_notice_when_setting_logo(app): | ||||
|     message = ( | ||||
|         "Setting the config.LOGO is deprecated and will no longer be " | ||||
|         "supported starting in v22.6." | ||||
|     ) | ||||
|     with pytest.warns(DeprecationWarning, match=message): | ||||
|         app.config.LOGO = "My Custom Logo" | ||||
|  | ||||
|  | ||||
| def test_config_set_methods(app, monkeypatch): | ||||
|     post_set = Mock() | ||||
|     monkeypatch.setattr(Config, "_post_set", post_set) | ||||
|  | ||||
|     app.config.FOO = 1 | ||||
|     post_set.assert_called_once_with("FOO", 1) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config["FOO"] = 2 | ||||
|     post_set.assert_called_once_with("FOO", 2) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update({"FOO": 3}) | ||||
|     post_set.assert_called_once_with("FOO", 3) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update([("FOO", 4)]) | ||||
|     post_set.assert_called_once_with("FOO", 4) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update(FOO=5) | ||||
|     post_set.assert_called_once_with("FOO", 5) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update_config({"FOO": 6}) | ||||
|     post_set.assert_called_once_with("FOO", 6) | ||||
|   | ||||
| @@ -1,8 +1,10 @@ | ||||
| import pytest | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.config import Config | ||||
| from sanic.errorpages import HTMLRenderer, exception_response | ||||
| from sanic.exceptions import NotFound, SanicException | ||||
| from sanic.handlers import ErrorHandler | ||||
| from sanic.request import Request | ||||
| from sanic.response import HTTPResponse, html, json, text | ||||
|  | ||||
| @@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected): | ||||
|         ) | ||||
|  | ||||
|     assert response.content_type == expected | ||||
|  | ||||
|  | ||||
| def test_allow_fallback_error_format_set_main_process_start(app): | ||||
|     @app.main_process_start | ||||
|     async def start(app, _): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     request, response = app.test_client.get("/error") | ||||
|     assert request.app.error_handler.fallback == "text" | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_setting_fallback_to_non_default_raise_warning(app): | ||||
|     app.error_handler = ErrorHandler(fallback="text") | ||||
|  | ||||
|     assert app.error_handler.fallback == "text" | ||||
|  | ||||
|     with pytest.warns( | ||||
|         UserWarning, | ||||
|         match=( | ||||
|             "Overriding non-default ErrorHandler fallback value. " | ||||
|             "Changing from text to auto." | ||||
|         ), | ||||
|     ): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||
|  | ||||
|     assert app.error_handler.fallback == "auto" | ||||
|  | ||||
|     app.config.FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     with pytest.warns( | ||||
|         UserWarning, | ||||
|         match=( | ||||
|             "Overriding non-default ErrorHandler fallback value. " | ||||
|             "Changing from text to json." | ||||
|         ), | ||||
|     ): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "json" | ||||
|  | ||||
|     assert app.error_handler.fallback == "json" | ||||
|  | ||||
|  | ||||
| def test_allow_fallback_error_format_in_config_injection(): | ||||
|     class MyConfig(Config): | ||||
|         FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     app = Sanic("test", config=MyConfig()) | ||||
|  | ||||
|     @app.route("/error", methods=["GET", "POST"]) | ||||
|     def err(request): | ||||
|         raise Exception("something went wrong") | ||||
|  | ||||
|     request, response = app.test_client.get("/error") | ||||
|     assert request.app.error_handler.fallback == "text" | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_allow_fallback_error_format_in_config_replacement(app): | ||||
|     class MyConfig(Config): | ||||
|         FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     app.config = MyConfig() | ||||
|  | ||||
|     request, response = app.test_client.get("/error") | ||||
|     assert request.app.error_handler.fallback == "text" | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import warnings | ||||
| import pytest | ||||
|  | ||||
| from bs4 import BeautifulSoup | ||||
| from websockets.version import version as websockets_version | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.exceptions import ( | ||||
| @@ -16,7 +17,6 @@ from sanic.exceptions import ( | ||||
|     abort, | ||||
| ) | ||||
| from sanic.response import text | ||||
| from websockets.version import version as websockets_version | ||||
|  | ||||
|  | ||||
| class SanicExceptionTestException(Exception): | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import asyncio | ||||
| import logging | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| @@ -206,3 +207,23 @@ def test_exception_handler_processed_request_middleware(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/8") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "Done." | ||||
|  | ||||
|  | ||||
| def test_single_arg_exception_handler_notice(exception_handler_app, caplog): | ||||
|     class CustomErrorHandler(ErrorHandler): | ||||
|         def lookup(self, exception): | ||||
|             return super().lookup(exception, None) | ||||
|  | ||||
|     exception_handler_app.error_handler = CustomErrorHandler() | ||||
|  | ||||
|     with caplog.at_level(logging.WARNING): | ||||
|         _, response = exception_handler_app.test_client.get("/1") | ||||
|  | ||||
|     assert caplog.records[0].message == ( | ||||
|         "You are using a deprecated error handler. The lookup method should " | ||||
|         "accept two positional parameters: (exception, route_name: " | ||||
|         "Optional[str]). Until you upgrade your ErrorHandler.lookup, " | ||||
|         "Blueprint specific exceptions will not work properly. Beginning in " | ||||
|         "v22.3, the legacy style lookup method will not work at all." | ||||
|     ) | ||||
|     assert response.status == 400 | ||||
|   | ||||
| @@ -1,109 +0,0 @@ | ||||
| import asyncio | ||||
|  | ||||
| import httpcore | ||||
| import httpx | ||||
| import pytest | ||||
|  | ||||
| from sanic_testing.testing import SanicTestClient | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.response import text | ||||
|  | ||||
|  | ||||
| class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection): | ||||
|     async def arequest(self, *args, **kwargs): | ||||
|         await asyncio.sleep(2) | ||||
|         return await super().arequest(*args, **kwargs) | ||||
|  | ||||
|     async def _open_socket(self, *args, **kwargs): | ||||
|         retval = await super()._open_socket(*args, **kwargs) | ||||
|         if self._request_delay: | ||||
|             await asyncio.sleep(self._request_delay) | ||||
|         return retval | ||||
|  | ||||
|  | ||||
| class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool): | ||||
|     def __init__(self, request_delay=None, *args, **kwargs): | ||||
|         self._request_delay = request_delay | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     async def _add_to_pool(self, connection, timeout): | ||||
|         connection.__class__ = DelayableHTTPConnection | ||||
|         connection._request_delay = self._request_delay | ||||
|         await super()._add_to_pool(connection, timeout) | ||||
|  | ||||
|  | ||||
| class DelayableSanicSession(httpx.AsyncClient): | ||||
|     def __init__(self, request_delay=None, *args, **kwargs) -> None: | ||||
|         transport = DelayableSanicConnectionPool(request_delay=request_delay) | ||||
|         super().__init__(transport=transport, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class DelayableSanicTestClient(SanicTestClient): | ||||
|     def __init__(self, app, request_delay=None): | ||||
|         super().__init__(app) | ||||
|         self._request_delay = request_delay | ||||
|         self._loop = None | ||||
|  | ||||
|     def get_new_session(self): | ||||
|         return DelayableSanicSession(request_delay=self._request_delay) | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def request_no_timeout_app(): | ||||
|     app = Sanic("test_request_no_timeout") | ||||
|     app.config.REQUEST_TIMEOUT = 0.6 | ||||
|  | ||||
|     @app.route("/1") | ||||
|     async def handler2(request): | ||||
|         return text("OK") | ||||
|  | ||||
|     return app | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def request_timeout_default_app(): | ||||
|     app = Sanic("test_request_timeout_default") | ||||
|     app.config.REQUEST_TIMEOUT = 0.6 | ||||
|  | ||||
|     @app.route("/1") | ||||
|     async def handler1(request): | ||||
|         return text("OK") | ||||
|  | ||||
|     @app.websocket("/ws1") | ||||
|     async def ws_handler1(request, ws): | ||||
|         await ws.send("OK") | ||||
|  | ||||
|     return app | ||||
|  | ||||
|  | ||||
| def test_default_server_error_request_timeout(request_timeout_default_app): | ||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) | ||||
|     _, response = client.get("/1") | ||||
|     assert response.status == 408 | ||||
|     assert "Request Timeout" in response.text | ||||
|  | ||||
|  | ||||
| def test_default_server_error_request_dont_timeout(request_no_timeout_app): | ||||
|     client = DelayableSanicTestClient(request_no_timeout_app, 0.2) | ||||
|     _, response = client.get("/1") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "OK" | ||||
|  | ||||
|  | ||||
| def test_default_server_error_websocket_request_timeout( | ||||
|     request_timeout_default_app, | ||||
| ): | ||||
|  | ||||
|     headers = { | ||||
|         "Upgrade": "websocket", | ||||
|         "Connection": "upgrade", | ||||
|         "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", | ||||
|         "Sec-WebSocket-Version": "13", | ||||
|     } | ||||
|  | ||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) | ||||
|     _, response = client.get("/ws1", headers=headers) | ||||
|  | ||||
|     assert response.status == 408 | ||||
|     assert "Request Timeout" in response.text | ||||
| @@ -26,6 +26,7 @@ def protocol(app, mock_transport): | ||||
|     protocol = HttpProtocol(loop=loop, app=app) | ||||
|     protocol.connection_made(mock_transport) | ||||
|     protocol._setup_connection() | ||||
|     protocol._http.init_for_request() | ||||
|     protocol._task = Mock(spec=asyncio.Task) | ||||
|     protocol._task.cancel = Mock() | ||||
|     return protocol | ||||
|   | ||||
		Reference in New Issue
	
	Block a user