Compare commits
	
		
			4 Commits
		
	
	
		
			sml-change
			...
			v21.9.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | af1d289a45 | ||
|   | b20b3cb417 | ||
|   | 45c22f9af2 | ||
|   | 71d845786d | 
| @@ -1 +1 @@ | |||||||
| __version__ = "21.9.1" | __version__ = "21.9.2" | ||||||
|   | |||||||
							
								
								
									
										24
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -173,18 +173,16 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): | |||||||
|         self.asgi = False |         self.asgi = False | ||||||
|         self.auto_reload = False |         self.auto_reload = False | ||||||
|         self.blueprints: Dict[str, Blueprint] = {} |         self.blueprints: Dict[str, Blueprint] = {} | ||||||
|         self.config = config or Config( |         self.config: Config = config or Config( | ||||||
|             load_env=load_env, env_prefix=env_prefix |             load_env=load_env, | ||||||
|  |             env_prefix=env_prefix, | ||||||
|  |             app=self, | ||||||
|         ) |         ) | ||||||
|         self.configure_logging = configure_logging |         self.configure_logging: bool = configure_logging | ||||||
|         self.ctx = ctx or SimpleNamespace() |         self.ctx: Any = ctx or SimpleNamespace() | ||||||
|         self.debug = None |         self.debug = False | ||||||
|         self.error_handler = error_handler or ErrorHandler( |         self.error_handler: ErrorHandler = error_handler or ErrorHandler() | ||||||
|             fallback=self.config.FALLBACK_ERROR_FORMAT, |         self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) | ||||||
|         ) |  | ||||||
|         self.is_running = False |  | ||||||
|         self.is_stopping = False |  | ||||||
|         self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) |  | ||||||
|         self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} |         self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} | ||||||
|         self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} |         self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} | ||||||
|         self.reload_dirs: Set[Path] = set() |         self.reload_dirs: Set[Path] = set() | ||||||
| @@ -1474,7 +1472,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): | |||||||
|     async def _startup(self): |     async def _startup(self): | ||||||
|         self.signalize() |         self.signalize() | ||||||
|         self.finalize() |         self.finalize() | ||||||
|         ErrorHandler.finalize(self.error_handler) |         ErrorHandler.finalize( | ||||||
|  |             self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT | ||||||
|  |         ) | ||||||
|         TouchUp.run(self) |         TouchUp.run(self) | ||||||
|  |  | ||||||
|     async def _server_event( |     async def _server_event( | ||||||
|   | |||||||
| @@ -1,7 +1,9 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
| from inspect import isclass | from inspect import isclass | ||||||
| from os import environ | from os import environ | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Any, Dict, Optional, Union | from typing import TYPE_CHECKING, Any, Dict, Optional, Union | ||||||
| from warnings import warn | from warnings import warn | ||||||
|  |  | ||||||
| from sanic.errorpages import check_error_format | from sanic.errorpages import check_error_format | ||||||
| @@ -10,6 +12,10 @@ from sanic.http import Http | |||||||
| from .utils import load_module_from_file_location, str_to_bool | from .utils import load_module_from_file_location, str_to_bool | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING:  # no cov | ||||||
|  |     from sanic import Sanic | ||||||
|  |  | ||||||
|  |  | ||||||
| SANIC_PREFIX = "SANIC_" | SANIC_PREFIX = "SANIC_" | ||||||
| BASE_LOGO = """ | BASE_LOGO = """ | ||||||
|  |  | ||||||
| @@ -71,11 +77,14 @@ class Config(dict): | |||||||
|         load_env: Optional[Union[bool, str]] = True, |         load_env: Optional[Union[bool, str]] = True, | ||||||
|         env_prefix: Optional[str] = SANIC_PREFIX, |         env_prefix: Optional[str] = SANIC_PREFIX, | ||||||
|         keep_alive: Optional[bool] = None, |         keep_alive: Optional[bool] = None, | ||||||
|  |         *, | ||||||
|  |         app: Optional[Sanic] = None, | ||||||
|     ): |     ): | ||||||
|         defaults = defaults or {} |         defaults = defaults or {} | ||||||
|         super().__init__({**DEFAULT_CONFIG, **defaults}) |         super().__init__({**DEFAULT_CONFIG, **defaults}) | ||||||
|  |  | ||||||
|         self.LOGO = BASE_LOGO |         self._app = app | ||||||
|  |         self._LOGO = "" | ||||||
|  |  | ||||||
|         if keep_alive is not None: |         if keep_alive is not None: | ||||||
|             self.KEEP_ALIVE = keep_alive |             self.KEEP_ALIVE = keep_alive | ||||||
| @@ -97,6 +106,7 @@ class Config(dict): | |||||||
|  |  | ||||||
|         self._configure_header_size() |         self._configure_header_size() | ||||||
|         self._check_error_format() |         self._check_error_format() | ||||||
|  |         self._init = True | ||||||
|  |  | ||||||
|     def __getattr__(self, attr): |     def __getattr__(self, attr): | ||||||
|         try: |         try: | ||||||
| @@ -104,8 +114,20 @@ class Config(dict): | |||||||
|         except KeyError as ke: |         except KeyError as ke: | ||||||
|             raise AttributeError(f"Config has no '{ke.args[0]}'") |             raise AttributeError(f"Config has no '{ke.args[0]}'") | ||||||
|  |  | ||||||
|     def __setattr__(self, attr, value): |     def __setattr__(self, attr, value) -> None: | ||||||
|         self[attr] = value |         self.update({attr: value}) | ||||||
|  |  | ||||||
|  |     def __setitem__(self, attr, value) -> None: | ||||||
|  |         self.update({attr: value}) | ||||||
|  |  | ||||||
|  |     def update(self, *other, **kwargs) -> None: | ||||||
|  |         other_mapping = {k: v for item in other for k, v in dict(item).items()} | ||||||
|  |         super().update(*other, **kwargs) | ||||||
|  |         for attr, value in {**other_mapping, **kwargs}.items(): | ||||||
|  |             self._post_set(attr, value) | ||||||
|  |  | ||||||
|  |     def _post_set(self, attr, value) -> None: | ||||||
|  |         if self.get("_init"): | ||||||
|             if attr in ( |             if attr in ( | ||||||
|                 "REQUEST_MAX_HEADER_SIZE", |                 "REQUEST_MAX_HEADER_SIZE", | ||||||
|                 "REQUEST_BUFFER_SIZE", |                 "REQUEST_BUFFER_SIZE", | ||||||
| @@ -114,6 +136,29 @@ class Config(dict): | |||||||
|                 self._configure_header_size() |                 self._configure_header_size() | ||||||
|             elif attr == "FALLBACK_ERROR_FORMAT": |             elif attr == "FALLBACK_ERROR_FORMAT": | ||||||
|                 self._check_error_format() |                 self._check_error_format() | ||||||
|  |                 if self.app and value != self.app.error_handler.fallback: | ||||||
|  |                     if self.app.error_handler.fallback != "auto": | ||||||
|  |                         warn( | ||||||
|  |                             "Overriding non-default ErrorHandler fallback " | ||||||
|  |                             "value. Changing from " | ||||||
|  |                             f"{self.app.error_handler.fallback} to {value}." | ||||||
|  |                         ) | ||||||
|  |                     self.app.error_handler.fallback = value | ||||||
|  |             elif attr == "LOGO": | ||||||
|  |                 self._LOGO = value | ||||||
|  |                 warn( | ||||||
|  |                     "Setting the config.LOGO is deprecated and will no longer " | ||||||
|  |                     "be supported starting in v22.6.", | ||||||
|  |                     DeprecationWarning, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def app(self): | ||||||
|  |         return self._app | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def LOGO(self): | ||||||
|  |         return self._LOGO | ||||||
|  |  | ||||||
|     def _configure_header_size(self): |     def _configure_header_size(self): | ||||||
|         Http.set_header_max_size( |         Http.set_header_max_size( | ||||||
|   | |||||||
| @@ -383,6 +383,7 @@ def exception_response( | |||||||
|     """ |     """ | ||||||
|     content_type = None |     content_type = None | ||||||
|  |  | ||||||
|  |     print("exception_response", fallback) | ||||||
|     if not renderer: |     if not renderer: | ||||||
|         # Make sure we have something set |         # Make sure we have something set | ||||||
|         renderer = base |         renderer = base | ||||||
| @@ -393,6 +394,7 @@ def exception_response( | |||||||
|             # from the route |             # from the route | ||||||
|             if request.route: |             if request.route: | ||||||
|                 try: |                 try: | ||||||
|  |                     if request.route.ctx.error_format: | ||||||
|                         render_format = request.route.ctx.error_format |                         render_format = request.route.ctx.error_format | ||||||
|                 except AttributeError: |                 except AttributeError: | ||||||
|                     ... |                     ... | ||||||
|   | |||||||
| @@ -38,7 +38,14 @@ class ErrorHandler: | |||||||
|         self.base = base |         self.base = base | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def finalize(cls, error_handler): |     def finalize(cls, error_handler, fallback: Optional[str] = None): | ||||||
|  |         if ( | ||||||
|  |             fallback | ||||||
|  |             and fallback != "auto" | ||||||
|  |             and error_handler.fallback == "auto" | ||||||
|  |         ): | ||||||
|  |             error_handler.fallback = fallback | ||||||
|  |  | ||||||
|         if not isinstance(error_handler, cls): |         if not isinstance(error_handler, cls): | ||||||
|             error_logger.warning( |             error_logger.warning( | ||||||
|                 f"Error handler is non-conforming: {type(error_handler)}" |                 f"Error handler is non-conforming: {type(error_handler)}" | ||||||
|   | |||||||
| @@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta): | |||||||
|         self.keep_alive = True |         self.keep_alive = True | ||||||
|         self.stage: Stage = Stage.IDLE |         self.stage: Stage = Stage.IDLE | ||||||
|         self.dispatch = self.protocol.app.dispatch |         self.dispatch = self.protocol.app.dispatch | ||||||
|         self.init_for_request() |  | ||||||
|  |  | ||||||
|     def init_for_request(self): |     def init_for_request(self): | ||||||
|         """Init/reset all per-request variables.""" |         """Init/reset all per-request variables.""" | ||||||
| @@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta): | |||||||
|         """ |         """ | ||||||
|         HTTP 1.1 connection handler |         HTTP 1.1 connection handler | ||||||
|         """ |         """ | ||||||
|         while True:  # As long as connection stays keep-alive |         # Handle requests while the connection stays reusable | ||||||
|  |         while self.keep_alive and self.stage is Stage.IDLE: | ||||||
|  |             self.init_for_request() | ||||||
|  |             # Wait for incoming bytes (in IDLE stage) | ||||||
|  |             if not self.recv_buffer: | ||||||
|  |                 await self._receive_more() | ||||||
|  |             self.stage = Stage.REQUEST | ||||||
|             try: |             try: | ||||||
|                 # Receive and handle a request |                 # Receive and handle a request | ||||||
|                 self.stage = Stage.REQUEST |  | ||||||
|                 self.response_func = self.http1_response_header |                 self.response_func = self.http1_response_header | ||||||
|  |  | ||||||
|                 await self.http1_request_header() |                 await self.http1_request_header() | ||||||
|  |  | ||||||
|  |                 self.stage = Stage.HANDLER | ||||||
|                 self.request.conn_info = self.protocol.conn_info |                 self.request.conn_info = self.protocol.conn_info | ||||||
|                 await self.protocol.request_handler(self.request) |                 await self.protocol.request_handler(self.request) | ||||||
|  |  | ||||||
| @@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta): | |||||||
|                 if self.response: |                 if self.response: | ||||||
|                     self.response.stream = None |                     self.response.stream = None | ||||||
|  |  | ||||||
|             # Exit and disconnect if no more requests can be taken |  | ||||||
|             if self.stage is not Stage.IDLE or not self.keep_alive: |  | ||||||
|                 break |  | ||||||
|  |  | ||||||
|             self.init_for_request() |  | ||||||
|  |  | ||||||
|             # Wait for the next request |  | ||||||
|             if not self.recv_buffer: |  | ||||||
|                 await self._receive_more() |  | ||||||
|  |  | ||||||
|     async def http1_request_header(self):  # no cov |     async def http1_request_header(self):  # no cov | ||||||
|         """ |         """ | ||||||
|         Receive and parse request header into self.request. |         Receive and parse request header into self.request. | ||||||
| @@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta): | |||||||
|  |  | ||||||
|         # Remove header and its trailing CRLF |         # Remove header and its trailing CRLF | ||||||
|         del buf[: pos + 4] |         del buf[: pos + 4] | ||||||
|         self.stage = Stage.HANDLER |  | ||||||
|         self.request, request.stream = request, self |         self.request, request.stream = request, self | ||||||
|         self.protocol.state["requests_count"] += 1 |         self.protocol.state["requests_count"] += 1 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -918,7 +918,7 @@ class RouteMixin: | |||||||
|  |  | ||||||
|         return route |         return route | ||||||
|  |  | ||||||
|     def _determine_error_format(self, handler) -> str: |     def _determine_error_format(self, handler) -> Optional[str]: | ||||||
|         if not isinstance(handler, CompositionView): |         if not isinstance(handler, CompositionView): | ||||||
|             try: |             try: | ||||||
|                 src = dedent(getsource(handler)) |                 src = dedent(getsource(handler)) | ||||||
| @@ -930,7 +930,7 @@ class RouteMixin: | |||||||
|             except (OSError, TypeError): |             except (OSError, TypeError): | ||||||
|                 ... |                 ... | ||||||
|  |  | ||||||
|         return "auto" |         return None | ||||||
|  |  | ||||||
|     def _get_response_types(self, node): |     def _get_response_types(self, node): | ||||||
|         types = set() |         types = set() | ||||||
|   | |||||||
| @@ -139,10 +139,9 @@ class Router(BaseRouter): | |||||||
|             route.ctx.stream = stream |             route.ctx.stream = stream | ||||||
|             route.ctx.hosts = hosts |             route.ctx.hosts = hosts | ||||||
|             route.ctx.static = static |             route.ctx.static = static | ||||||
|             route.ctx.error_format = ( |             route.ctx.error_format = error_format | ||||||
|                 error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|  |             if error_format: | ||||||
|                 check_error_format(route.ctx.error_format) |                 check_error_format(route.ctx.error_format) | ||||||
|  |  | ||||||
|             routes.append(route) |             routes.append(route) | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ from os import environ | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from tempfile import TemporaryDirectory | from tempfile import TemporaryDirectory | ||||||
| from textwrap import dedent | from textwrap import dedent | ||||||
|  | from unittest.mock import Mock | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| @@ -350,3 +351,40 @@ def test_update_from_lowercase_key(app): | |||||||
|     d = {"test_setting_value": 1} |     d = {"test_setting_value": 1} | ||||||
|     app.update_config(d) |     app.update_config(d) | ||||||
|     assert "test_setting_value" not in app.config |     assert "test_setting_value" not in app.config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_deprecation_notice_when_setting_logo(app): | ||||||
|  |     message = ( | ||||||
|  |         "Setting the config.LOGO is deprecated and will no longer be " | ||||||
|  |         "supported starting in v22.6." | ||||||
|  |     ) | ||||||
|  |     with pytest.warns(DeprecationWarning, match=message): | ||||||
|  |         app.config.LOGO = "My Custom Logo" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_config_set_methods(app, monkeypatch): | ||||||
|  |     post_set = Mock() | ||||||
|  |     monkeypatch.setattr(Config, "_post_set", post_set) | ||||||
|  |  | ||||||
|  |     app.config.FOO = 1 | ||||||
|  |     post_set.assert_called_once_with("FOO", 1) | ||||||
|  |     post_set.reset_mock() | ||||||
|  |  | ||||||
|  |     app.config["FOO"] = 2 | ||||||
|  |     post_set.assert_called_once_with("FOO", 2) | ||||||
|  |     post_set.reset_mock() | ||||||
|  |  | ||||||
|  |     app.config.update({"FOO": 3}) | ||||||
|  |     post_set.assert_called_once_with("FOO", 3) | ||||||
|  |     post_set.reset_mock() | ||||||
|  |  | ||||||
|  |     app.config.update([("FOO", 4)]) | ||||||
|  |     post_set.assert_called_once_with("FOO", 4) | ||||||
|  |     post_set.reset_mock() | ||||||
|  |  | ||||||
|  |     app.config.update(FOO=5) | ||||||
|  |     post_set.assert_called_once_with("FOO", 5) | ||||||
|  |     post_set.reset_mock() | ||||||
|  |  | ||||||
|  |     app.config.update_config({"FOO": 6}) | ||||||
|  |     post_set.assert_called_once_with("FOO", 6) | ||||||
|   | |||||||
| @@ -1,8 +1,10 @@ | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from sanic import Sanic | from sanic import Sanic | ||||||
|  | from sanic.config import Config | ||||||
| from sanic.errorpages import HTMLRenderer, exception_response | from sanic.errorpages import HTMLRenderer, exception_response | ||||||
| from sanic.exceptions import NotFound, SanicException | from sanic.exceptions import NotFound, SanicException | ||||||
|  | from sanic.handlers import ErrorHandler | ||||||
| from sanic.request import Request | from sanic.request import Request | ||||||
| from sanic.response import HTTPResponse, html, json, text | from sanic.response import HTTPResponse, html, json, text | ||||||
|  |  | ||||||
| @@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     assert response.content_type == expected |     assert response.content_type == expected | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_allow_fallback_error_format_set_main_process_start(app): | ||||||
|  |     @app.main_process_start | ||||||
|  |     async def start(app, _): | ||||||
|  |         app.config.FALLBACK_ERROR_FORMAT = "text" | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get("/error") | ||||||
|  |     assert request.app.error_handler.fallback == "text" | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_setting_fallback_to_non_default_raise_warning(app): | ||||||
|  |     app.error_handler = ErrorHandler(fallback="text") | ||||||
|  |  | ||||||
|  |     assert app.error_handler.fallback == "text" | ||||||
|  |  | ||||||
|  |     with pytest.warns( | ||||||
|  |         UserWarning, | ||||||
|  |         match=( | ||||||
|  |             "Overriding non-default ErrorHandler fallback value. " | ||||||
|  |             "Changing from text to auto." | ||||||
|  |         ), | ||||||
|  |     ): | ||||||
|  |         app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||||
|  |  | ||||||
|  |     assert app.error_handler.fallback == "auto" | ||||||
|  |  | ||||||
|  |     app.config.FALLBACK_ERROR_FORMAT = "text" | ||||||
|  |  | ||||||
|  |     with pytest.warns( | ||||||
|  |         UserWarning, | ||||||
|  |         match=( | ||||||
|  |             "Overriding non-default ErrorHandler fallback value. " | ||||||
|  |             "Changing from text to json." | ||||||
|  |         ), | ||||||
|  |     ): | ||||||
|  |         app.config.FALLBACK_ERROR_FORMAT = "json" | ||||||
|  |  | ||||||
|  |     assert app.error_handler.fallback == "json" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_allow_fallback_error_format_in_config_injection(): | ||||||
|  |     class MyConfig(Config): | ||||||
|  |         FALLBACK_ERROR_FORMAT = "text" | ||||||
|  |  | ||||||
|  |     app = Sanic("test", config=MyConfig()) | ||||||
|  |  | ||||||
|  |     @app.route("/error", methods=["GET", "POST"]) | ||||||
|  |     def err(request): | ||||||
|  |         raise Exception("something went wrong") | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get("/error") | ||||||
|  |     assert request.app.error_handler.fallback == "text" | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_allow_fallback_error_format_in_config_replacement(app): | ||||||
|  |     class MyConfig(Config): | ||||||
|  |         FALLBACK_ERROR_FORMAT = "text" | ||||||
|  |  | ||||||
|  |     app.config = MyConfig() | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get("/error") | ||||||
|  |     assert request.app.error_handler.fallback == "text" | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|   | |||||||
| @@ -1,109 +0,0 @@ | |||||||
| import asyncio |  | ||||||
|  |  | ||||||
| import httpcore |  | ||||||
| import httpx |  | ||||||
| import pytest |  | ||||||
|  |  | ||||||
| from sanic_testing.testing import SanicTestClient |  | ||||||
|  |  | ||||||
| from sanic import Sanic |  | ||||||
| from sanic.response import text |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection): |  | ||||||
|     async def arequest(self, *args, **kwargs): |  | ||||||
|         await asyncio.sleep(2) |  | ||||||
|         return await super().arequest(*args, **kwargs) |  | ||||||
|  |  | ||||||
|     async def _open_socket(self, *args, **kwargs): |  | ||||||
|         retval = await super()._open_socket(*args, **kwargs) |  | ||||||
|         if self._request_delay: |  | ||||||
|             await asyncio.sleep(self._request_delay) |  | ||||||
|         return retval |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool): |  | ||||||
|     def __init__(self, request_delay=None, *args, **kwargs): |  | ||||||
|         self._request_delay = request_delay |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|  |  | ||||||
|     async def _add_to_pool(self, connection, timeout): |  | ||||||
|         connection.__class__ = DelayableHTTPConnection |  | ||||||
|         connection._request_delay = self._request_delay |  | ||||||
|         await super()._add_to_pool(connection, timeout) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DelayableSanicSession(httpx.AsyncClient): |  | ||||||
|     def __init__(self, request_delay=None, *args, **kwargs) -> None: |  | ||||||
|         transport = DelayableSanicConnectionPool(request_delay=request_delay) |  | ||||||
|         super().__init__(transport=transport, *args, **kwargs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DelayableSanicTestClient(SanicTestClient): |  | ||||||
|     def __init__(self, app, request_delay=None): |  | ||||||
|         super().__init__(app) |  | ||||||
|         self._request_delay = request_delay |  | ||||||
|         self._loop = None |  | ||||||
|  |  | ||||||
|     def get_new_session(self): |  | ||||||
|         return DelayableSanicSession(request_delay=self._request_delay) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture |  | ||||||
| def request_no_timeout_app(): |  | ||||||
|     app = Sanic("test_request_no_timeout") |  | ||||||
|     app.config.REQUEST_TIMEOUT = 0.6 |  | ||||||
|  |  | ||||||
|     @app.route("/1") |  | ||||||
|     async def handler2(request): |  | ||||||
|         return text("OK") |  | ||||||
|  |  | ||||||
|     return app |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture |  | ||||||
| def request_timeout_default_app(): |  | ||||||
|     app = Sanic("test_request_timeout_default") |  | ||||||
|     app.config.REQUEST_TIMEOUT = 0.6 |  | ||||||
|  |  | ||||||
|     @app.route("/1") |  | ||||||
|     async def handler1(request): |  | ||||||
|         return text("OK") |  | ||||||
|  |  | ||||||
|     @app.websocket("/ws1") |  | ||||||
|     async def ws_handler1(request, ws): |  | ||||||
|         await ws.send("OK") |  | ||||||
|  |  | ||||||
|     return app |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_default_server_error_request_timeout(request_timeout_default_app): |  | ||||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) |  | ||||||
|     _, response = client.get("/1") |  | ||||||
|     assert response.status == 408 |  | ||||||
|     assert "Request Timeout" in response.text |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_default_server_error_request_dont_timeout(request_no_timeout_app): |  | ||||||
|     client = DelayableSanicTestClient(request_no_timeout_app, 0.2) |  | ||||||
|     _, response = client.get("/1") |  | ||||||
|     assert response.status == 200 |  | ||||||
|     assert response.text == "OK" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_default_server_error_websocket_request_timeout( |  | ||||||
|     request_timeout_default_app, |  | ||||||
| ): |  | ||||||
|  |  | ||||||
|     headers = { |  | ||||||
|         "Upgrade": "websocket", |  | ||||||
|         "Connection": "upgrade", |  | ||||||
|         "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", |  | ||||||
|         "Sec-WebSocket-Version": "13", |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) |  | ||||||
|     _, response = client.get("/ws1", headers=headers) |  | ||||||
|  |  | ||||||
|     assert response.status == 408 |  | ||||||
|     assert "Request Timeout" in response.text |  | ||||||
| @@ -26,6 +26,7 @@ def protocol(app, mock_transport): | |||||||
|     protocol = HttpProtocol(loop=loop, app=app) |     protocol = HttpProtocol(loop=loop, app=app) | ||||||
|     protocol.connection_made(mock_transport) |     protocol.connection_made(mock_transport) | ||||||
|     protocol._setup_connection() |     protocol._setup_connection() | ||||||
|  |     protocol._http.init_for_request() | ||||||
|     protocol._task = Mock(spec=asyncio.Task) |     protocol._task = Mock(spec=asyncio.Task) | ||||||
|     protocol._task.cancel = Mock() |     protocol._task.cancel = Mock() | ||||||
|     return protocol |     return protocol | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user