Compare commits
	
		
			50 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | af1d289a45 | ||
|   | b20b3cb417 | ||
|   | 45c22f9af2 | ||
|   | 71d845786d | ||
|   | 5e12edbc38 | ||
|   | 50a606adee | ||
|   | f995612073 | ||
|   | bc08383acd | ||
|   | b83a1a184c | ||
|   | 59dd6814f8 | ||
|   | f7abf3db1b | ||
|   | cf1d2148ac | ||
|   | b5f2bd9b0e | ||
|   | ba2670e99c | ||
|   | 6ffc4d9756 | ||
|   | 595d2c76ac | ||
|   | d9796e9b1e | ||
|   | 404c5f9f9e | ||
|   | a937e08ef0 | ||
|   | ef4f058a6c | ||
|   | 69c5dde9bf | ||
|   | 945885d501 | ||
|   | 9d0b54c90d | ||
|   | 2e5c288fea | ||
|   | f32ef20b74 | ||
|   | e2eefaac55 | ||
|   | e1cfbf0fd9 | ||
|   | 08c5689441 | ||
|   | 8dbda247d6 | ||
|   | 71a631237d | ||
|   | e22ff3828b | ||
|   | b1b12e004e | ||
|   | 5308fec354 | ||
|   | 0ba57d4701 | ||
|   | 54ca6a6178 | ||
|   | 7dd4a78cf2 | ||
|   | 52ff49512a | ||
|   | 4732b6bdfa | ||
|   | c3b6fa1bba | ||
|   | 94d496afe1 | ||
|   | 7b7a572f9b | ||
|   | 1b8cb742f9 | ||
|   | 3492d180a8 | ||
|   | 021da38373 | ||
|   | ac784759d5 | ||
|   | 36eda2cd62 | ||
|   | 08a4b3013f | ||
|   | 1dd0332e8b | ||
|   | a90877ac31 | ||
|   | 8b7ea27a48 | 
| @@ -10,3 +10,15 @@ exclude_patterns: | ||||
|   - "examples/" | ||||
|   - "hack/" | ||||
|   - "scripts/" | ||||
|   - "tests/" | ||||
| checks: | ||||
|   argument-count: | ||||
|     enabled: false | ||||
|   file-lines: | ||||
|     config: | ||||
|       threshold: 1000 | ||||
|   method-count: | ||||
|     config: | ||||
|       threshold: 40 | ||||
|   complex-logic: | ||||
|     enabled: false | ||||
|   | ||||
							
								
								
									
										62
									
								
								.github/workflows/pr-windows.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/pr-windows.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,34 +1,34 @@ | ||||
| # name: Run Unit Tests on Windows | ||||
| # on: | ||||
| #   pull_request: | ||||
| #     branches: | ||||
| #       - main | ||||
| name: Run Unit Tests on Windows | ||||
| on: | ||||
|   pull_request: | ||||
|     branches: | ||||
|       - main | ||||
|  | ||||
| # jobs: | ||||
| #   testsOnWindows: | ||||
| #     name: ut-${{ matrix.config.tox-env }} | ||||
| #     runs-on: windows-latest | ||||
| #     strategy: | ||||
| #       fail-fast: false | ||||
| #       matrix: | ||||
| #         config: | ||||
| #           - { python-version: 3.7, tox-env: py37-no-ext } | ||||
| #           - { python-version: 3.8, tox-env: py38-no-ext } | ||||
| #           - { python-version: 3.9, tox-env: py39-no-ext } | ||||
| #           - { python-version: pypy-3.7, tox-env: pypy37-no-ext } | ||||
| jobs: | ||||
|   testsOnWindows: | ||||
|     name: ut-${{ matrix.config.tox-env }} | ||||
|     runs-on: windows-latest | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         config: | ||||
|           - { python-version: 3.7, tox-env: py37-no-ext } | ||||
|           - { python-version: 3.8, tox-env: py38-no-ext } | ||||
|           - { python-version: 3.9, tox-env: py39-no-ext } | ||||
|           - { python-version: pypy-3.7, tox-env: pypy37-no-ext } | ||||
|  | ||||
| #     steps: | ||||
| #       - name: Checkout Repository | ||||
| #         uses: actions/checkout@v2 | ||||
|     steps: | ||||
|       - name: Checkout Repository | ||||
|         uses: actions/checkout@v2 | ||||
|  | ||||
| #       - name: Run Unit Tests | ||||
| #         uses: ahopkins/custom-actions@pip-extra-args | ||||
| #         with: | ||||
| #           python-version: ${{ matrix.config.python-version }} | ||||
| #           test-infra-tool: tox | ||||
| #           test-infra-version: latest | ||||
| #           action: tests | ||||
| #           test-additional-args: "-e=${{ matrix.config.tox-env }}" | ||||
| #           experimental-ignore-error: "true" | ||||
| #           command-timeout: "600000" | ||||
| #           pip-extra-args: "--user" | ||||
|       - name: Run Unit Tests | ||||
|         uses: ahopkins/custom-actions@pip-extra-args | ||||
|         with: | ||||
|           python-version: ${{ matrix.config.python-version }} | ||||
|           test-infra-tool: tox | ||||
|           test-infra-version: latest | ||||
|           action: tests | ||||
|           test-additional-args: "-e=${{ matrix.config.tox-env }}" | ||||
|           experimental-ignore-error: "true" | ||||
|           command-timeout: "600000" | ||||
|           pip-extra-args: "--user" | ||||
|   | ||||
| @@ -1,3 +1,22 @@ | ||||
| .. note:: | ||||
|  | ||||
|   From v21.9, CHANGELOG files are maintained in ``./docs/sanic/releases`` | ||||
|  | ||||
| Version 21.6.1 | ||||
| -------------- | ||||
|  | ||||
| Bugfixes | ||||
| ******** | ||||
|  | ||||
|   * `#2178 <https://github.com/sanic-org/sanic/pull/2178>`_ | ||||
|     Update sanic-routing to allow for better splitting of complex URI templates | ||||
|   * `#2183 <https://github.com/sanic-org/sanic/pull/2183>`_ | ||||
|     Proper handling of chunked request bodies to resolve phantom 503 in logs | ||||
|   * `#2181 <https://github.com/sanic-org/sanic/pull/2181>`_ | ||||
|     Resolve regression in exception logging | ||||
|   * `#2201 <https://github.com/sanic-org/sanic/pull/2201>`_ | ||||
|     Cleanup request info in pipelined requests | ||||
|  | ||||
| Version 21.6.0 | ||||
| -------------- | ||||
|  | ||||
|   | ||||
| @@ -19,7 +19,7 @@ a virtual environment already set up, then run: | ||||
|  | ||||
| .. code-block:: bash | ||||
|  | ||||
|    pip3 install -e . ".[dev]" | ||||
|    pip install -e ".[dev]" | ||||
|  | ||||
| Dependency Changes | ||||
| ------------------ | ||||
|   | ||||
							
								
								
									
										12
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								README.rst
									
									
									
									
									
								
							| @@ -77,17 +77,7 @@ The goal of the project is to provide a simple way to get up and running a highl | ||||
| Sponsor | ||||
| ------- | ||||
|  | ||||
| |Try CodeStream| | ||||
|  | ||||
| .. |Try CodeStream| image:: https://alt-images.codestream.com/codestream_logo_sanicorg.png | ||||
|    :target: https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner | ||||
|    :alt: Try CodeStream | ||||
|  | ||||
| Manage pull requests and conduct code reviews in your IDE with full source-tree context. Comment on any line, not just the diffs. Use jump-to-definition, your favorite keybindings, and code intelligence with more of your workflow. | ||||
|  | ||||
| `Learn More <https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner>`_ | ||||
|  | ||||
| Thank you to our sponsor. Check out `open collective <https://opencollective.com/sanic-org>`_ to learn more about helping to fund Sanic. | ||||
| Check out `open collective <https://opencollective.com/sanic-org>`_ to learn more about helping to fund Sanic. | ||||
|  | ||||
| Installation | ||||
| ------------ | ||||
|   | ||||
							
								
								
									
										20
									
								
								docs/conf.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								docs/conf.py
									
									
									
									
									
								
							| @@ -10,10 +10,8 @@ | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| # Add support for auto-doc | ||||
| import recommonmark | ||||
|  | ||||
| from recommonmark.transform import AutoStructify | ||||
| # Add support for auto-doc | ||||
|  | ||||
|  | ||||
| # Ensure that sanic is present in the path, to allow sphinx-apidoc to | ||||
| @@ -26,7 +24,7 @@ import sanic | ||||
|  | ||||
| # -- General configuration ------------------------------------------------ | ||||
|  | ||||
| extensions = ["sphinx.ext.autodoc", "recommonmark"] | ||||
| extensions = ["sphinx.ext.autodoc", "m2r2"] | ||||
|  | ||||
| templates_path = ["_templates"] | ||||
|  | ||||
| @@ -162,20 +160,6 @@ autodoc_default_options = { | ||||
|     "member-order": "groupwise", | ||||
| } | ||||
|  | ||||
|  | ||||
| # app setup hook | ||||
| def setup(app): | ||||
|     app.add_config_value( | ||||
|         "recommonmark_config", | ||||
|         { | ||||
|             "enable_eval_rst": True, | ||||
|             "enable_auto_doc_ref": False, | ||||
|         }, | ||||
|         True, | ||||
|     ) | ||||
|     app.add_transform(AutoStructify) | ||||
|  | ||||
|  | ||||
| html_theme_options = { | ||||
|     "style_external_links": False, | ||||
| } | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| 📜 Changelog | ||||
| ============ | ||||
|  | ||||
| .. mdinclude:: ./releases/21.9.md | ||||
|  | ||||
| .. include:: ../../CHANGELOG.rst | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| ♥️ Contributing | ||||
| ============== | ||||
| =============== | ||||
|  | ||||
| .. include:: ../../CONTRIBUTING.rst | ||||
|   | ||||
							
								
								
									
										40
									
								
								docs/sanic/releases/21.9.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								docs/sanic/releases/21.9.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| ## Version 21.9 | ||||
|  | ||||
| ### Features | ||||
| - [#2158](https://github.com/sanic-org/sanic/pull/2158), [#2248](https://github.com/sanic-org/sanic/pull/2248) Complete overhaul of I/O to websockets | ||||
| - [#2160](https://github.com/sanic-org/sanic/pull/2160) Add new 17 signals into server and request lifecycles | ||||
| - [#2162](https://github.com/sanic-org/sanic/pull/2162) Smarter `auto` fallback formatting upon exception | ||||
| - [#2184](https://github.com/sanic-org/sanic/pull/2184) Introduce implementation for copying a Blueprint | ||||
| - [#2200](https://github.com/sanic-org/sanic/pull/2200) Accept header parsing | ||||
| - [#2207](https://github.com/sanic-org/sanic/pull/2207) Log remote address if available | ||||
| - [#2209](https://github.com/sanic-org/sanic/pull/2209) Add convenience methods to BP groups | ||||
| - [#2216](https://github.com/sanic-org/sanic/pull/2216) Add default messages to SanicExceptions | ||||
| - [#2225](https://github.com/sanic-org/sanic/pull/2225) Type annotation convenience for annotated handlers with path parameters | ||||
| - [#2236](https://github.com/sanic-org/sanic/pull/2236) Allow Falsey (but not-None) responses from route handlers | ||||
| - [#2238](https://github.com/sanic-org/sanic/pull/2238) Add `exception` decorator to Blueprint Groups | ||||
| - [#2244](https://github.com/sanic-org/sanic/pull/2244) Explicit static directive for serving file or dir (ex: `static(..., resource_type="file")`) | ||||
| - [#2245](https://github.com/sanic-org/sanic/pull/2245) Close HTTP loop when connection task cancelled | ||||
|  | ||||
| ### Bugfixes | ||||
| - [#2188](https://github.com/sanic-org/sanic/pull/2188) Fix the handling of the end of a chunked request | ||||
| - [#2195](https://github.com/sanic-org/sanic/pull/2195) Resolve unexpected error handling on static requests | ||||
| - [#2208](https://github.com/sanic-org/sanic/pull/2208) Make blueprint-based exceptions attach and trigger in a more intuitive manner | ||||
| - [#2211](https://github.com/sanic-org/sanic/pull/2211) Fixed for handling exceptions of asgi app call | ||||
| - [#2213](https://github.com/sanic-org/sanic/pull/2213) Fix bug where ws exceptions not being logged | ||||
| - [#2231](https://github.com/sanic-org/sanic/pull/2231) Cleaner closing of tasks by using `abort()` in strategic places to avoid dangling sockets | ||||
| - [#2247](https://github.com/sanic-org/sanic/pull/2247) Fix logging of auto-reload status in debug mode | ||||
| - [#2246](https://github.com/sanic-org/sanic/pull/2246) Account for BP with exception handler but no routes | ||||
|  | ||||
| ### Developer infrastructure   | ||||
| - [#2194](https://github.com/sanic-org/sanic/pull/2194) HTTP unit tests with raw client | ||||
| - [#2199](https://github.com/sanic-org/sanic/pull/2199) Switch to codeclimate | ||||
| - [#2214](https://github.com/sanic-org/sanic/pull/2214) Try Reopening Windows Tests | ||||
| - [#2229](https://github.com/sanic-org/sanic/pull/2229) Refactor `HttpProtocol` into a base class | ||||
| - [#2230](https://github.com/sanic-org/sanic/pull/2230) Refactor `server.py` into multi-file module | ||||
|  | ||||
| ### Miscellaneous | ||||
| - [#2173](https://github.com/sanic-org/sanic/pull/2173) Remove Duplicated Dependencies and PEP 517 Support  | ||||
| - [#2193](https://github.com/sanic-org/sanic/pull/2193), [#2196](https://github.com/sanic-org/sanic/pull/2196), [#2217](https://github.com/sanic-org/sanic/pull/2217) Type annotation changes | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -1,29 +1,44 @@ | ||||
| from sanic import Sanic | ||||
| from sanic import response | ||||
| from signal import signal, SIGINT | ||||
| import asyncio | ||||
|  | ||||
| from signal import SIGINT, signal | ||||
|  | ||||
| import uvloop | ||||
|  | ||||
| from sanic import Sanic, response | ||||
| from sanic.server import AsyncioServer | ||||
|  | ||||
|  | ||||
| app = Sanic(__name__) | ||||
|  | ||||
| @app.listener('after_server_start') | ||||
|  | ||||
| @app.listener("after_server_start") | ||||
| async def after_start_test(app, loop): | ||||
|     print("Async Server Started!") | ||||
|  | ||||
|  | ||||
| @app.route("/") | ||||
| async def test(request): | ||||
|     return response.json({"answer": "42"}) | ||||
|  | ||||
|  | ||||
| asyncio.set_event_loop(uvloop.new_event_loop()) | ||||
| serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True) | ||||
| serv_coro = app.create_server( | ||||
|     host="0.0.0.0", port=8000, return_asyncio_server=True | ||||
| ) | ||||
| loop = asyncio.get_event_loop() | ||||
| serv_task = asyncio.ensure_future(serv_coro, loop=loop) | ||||
| signal(SIGINT, lambda s, f: loop.stop()) | ||||
| server = loop.run_until_complete(serv_task) | ||||
| server: AsyncioServer = loop.run_until_complete(serv_task)  # type: ignore | ||||
| server.startup() | ||||
|  | ||||
| # When using app.run(), this actually triggers before the serv_coro. | ||||
| # But, in this example, we are using the convenience method, even if it is | ||||
| # out of order. | ||||
| server.before_start() | ||||
| server.after_start() | ||||
| try: | ||||
|     loop.run_forever() | ||||
| except KeyboardInterrupt as e: | ||||
| except KeyboardInterrupt: | ||||
|     loop.stop() | ||||
| finally: | ||||
|     server.before_stop() | ||||
|   | ||||
| @@ -1,13 +1,14 @@ | ||||
| from sanic import Sanic | ||||
| from sanic.response import file | ||||
| from sanic.response import redirect | ||||
|  | ||||
| app = Sanic(__name__) | ||||
|  | ||||
|  | ||||
| @app.route('/') | ||||
| async def index(request): | ||||
|     return await file('websocket.html') | ||||
| app.static('index.html', "websocket.html") | ||||
|  | ||||
| @app.route('/') | ||||
| def index(request): | ||||
|     return redirect("index.html") | ||||
|  | ||||
| @app.websocket('/feed') | ||||
| async def feed(request, ws): | ||||
|   | ||||
							
								
								
									
										3
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| [build-system] | ||||
| requires = ["setuptools", "wheel"] | ||||
| build-backend = "setuptools.build_meta" | ||||
| @@ -1 +1 @@ | ||||
| __version__ = "21.6.1" | ||||
| __version__ = "21.9.2" | ||||
|   | ||||
							
								
								
									
										332
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										332
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -1,9 +1,12 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import logging | ||||
| import logging.config | ||||
| import os | ||||
| import re | ||||
|  | ||||
| from asyncio import ( | ||||
|     AbstractEventLoop, | ||||
|     CancelledError, | ||||
|     Protocol, | ||||
|     ensure_future, | ||||
| @@ -21,6 +24,7 @@ from traceback import format_exc | ||||
| from types import SimpleNamespace | ||||
| from typing import ( | ||||
|     Any, | ||||
|     AnyStr, | ||||
|     Awaitable, | ||||
|     Callable, | ||||
|     Coroutine, | ||||
| @@ -30,6 +34,7 @@ from typing import ( | ||||
|     List, | ||||
|     Optional, | ||||
|     Set, | ||||
|     Tuple, | ||||
|     Type, | ||||
|     Union, | ||||
| ) | ||||
| @@ -69,20 +74,29 @@ from sanic.router import Router | ||||
| from sanic.server import AsyncioServer, HttpProtocol | ||||
| from sanic.server import Signal as ServerSignal | ||||
| from sanic.server import serve, serve_multiple, serve_single | ||||
| from sanic.server.protocols.websocket_protocol import WebSocketProtocol | ||||
| from sanic.server.websockets.impl import ConnectionClosed | ||||
| from sanic.signals import Signal, SignalRouter | ||||
| from sanic.websocket import ConnectionClosed, WebSocketProtocol | ||||
| from sanic.touchup import TouchUp, TouchUpMeta | ||||
|  | ||||
|  | ||||
| class Sanic(BaseSanic): | ||||
| class Sanic(BaseSanic, metaclass=TouchUpMeta): | ||||
|     """ | ||||
|     The main application instance | ||||
|     """ | ||||
|  | ||||
|     __touchup__ = ( | ||||
|         "handle_request", | ||||
|         "handle_exception", | ||||
|         "_run_response_middleware", | ||||
|         "_run_request_middleware", | ||||
|     ) | ||||
|     __fake_slots__ = ( | ||||
|         "_asgi_app", | ||||
|         "_app_registry", | ||||
|         "_asgi_client", | ||||
|         "_blueprint_order", | ||||
|         "_delayed_tasks", | ||||
|         "_future_routes", | ||||
|         "_future_statics", | ||||
|         "_future_middleware", | ||||
| @@ -137,7 +151,7 @@ class Sanic(BaseSanic): | ||||
|         log_config: Optional[Dict[str, Any]] = None, | ||||
|         configure_logging: bool = True, | ||||
|         register: Optional[bool] = None, | ||||
|         dumps: Optional[Callable[..., str]] = None, | ||||
|         dumps: Optional[Callable[..., AnyStr]] = None, | ||||
|     ) -> None: | ||||
|         super().__init__(name=name) | ||||
|  | ||||
| @@ -153,21 +167,22 @@ class Sanic(BaseSanic): | ||||
|  | ||||
|         self._asgi_client = None | ||||
|         self._blueprint_order: List[Blueprint] = [] | ||||
|         self._delayed_tasks: List[str] = [] | ||||
|         self._test_client = None | ||||
|         self._test_manager = None | ||||
|         self.asgi = False | ||||
|         self.auto_reload = False | ||||
|         self.blueprints: Dict[str, Blueprint] = {} | ||||
|         self.config = config or Config( | ||||
|             load_env=load_env, env_prefix=env_prefix | ||||
|         self.config: Config = config or Config( | ||||
|             load_env=load_env, | ||||
|             env_prefix=env_prefix, | ||||
|             app=self, | ||||
|         ) | ||||
|         self.configure_logging = configure_logging | ||||
|         self.ctx = ctx or SimpleNamespace() | ||||
|         self.debug = None | ||||
|         self.error_handler = error_handler or ErrorHandler() | ||||
|         self.is_running = False | ||||
|         self.is_stopping = False | ||||
|         self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) | ||||
|         self.configure_logging: bool = configure_logging | ||||
|         self.ctx: Any = ctx or SimpleNamespace() | ||||
|         self.debug = False | ||||
|         self.error_handler: ErrorHandler = error_handler or ErrorHandler() | ||||
|         self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) | ||||
|         self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} | ||||
|         self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} | ||||
|         self.reload_dirs: Set[Path] = set() | ||||
| @@ -190,9 +205,10 @@ class Sanic(BaseSanic): | ||||
|             self.__class__.register_app(self) | ||||
|  | ||||
|         self.router.ctx.app = self | ||||
|         self.signal_router.ctx.app = self | ||||
|  | ||||
|         if dumps: | ||||
|             BaseHTTPResponse._dumps = dumps | ||||
|             BaseHTTPResponse._dumps = dumps  # type: ignore | ||||
|  | ||||
|     @property | ||||
|     def loop(self): | ||||
| @@ -230,9 +246,12 @@ class Sanic(BaseSanic): | ||||
|             loop = self.loop  # Will raise SanicError if loop is not started | ||||
|             self._loop_add_task(task, self, loop) | ||||
|         except SanicException: | ||||
|             self.listener("before_server_start")( | ||||
|                 partial(self._loop_add_task, task) | ||||
|             ) | ||||
|             task_name = f"sanic.delayed_task.{hash(task)}" | ||||
|             if not self._delayed_tasks: | ||||
|                 self.after_server_start(partial(self.dispatch_delayed_tasks)) | ||||
|  | ||||
|             self.signal(task_name)(partial(self.run_delayed_task, task=task)) | ||||
|             self._delayed_tasks.append(task_name) | ||||
|  | ||||
|     def register_listener(self, listener: Callable, event: str) -> Any: | ||||
|         """ | ||||
| @@ -244,12 +263,20 @@ class Sanic(BaseSanic): | ||||
|         """ | ||||
|  | ||||
|         try: | ||||
|             _event = ListenerEvent(event) | ||||
|         except ValueError: | ||||
|             valid = ", ".join(ListenerEvent.__members__.values()) | ||||
|             _event = ListenerEvent[event.upper()] | ||||
|         except (ValueError, AttributeError): | ||||
|             valid = ", ".join( | ||||
|                 map(lambda x: x.lower(), ListenerEvent.__members__.keys()) | ||||
|             ) | ||||
|             raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") | ||||
|  | ||||
|         self.listeners[_event].append(listener) | ||||
|         if "." in _event: | ||||
|             self.signal(_event.value)( | ||||
|                 partial(self._listener, listener=listener) | ||||
|             ) | ||||
|         else: | ||||
|             self.listeners[_event.value].append(listener) | ||||
|  | ||||
|         return listener | ||||
|  | ||||
|     def register_middleware(self, middleware, attach_to: str = "request"): | ||||
| @@ -308,7 +335,11 @@ class Sanic(BaseSanic): | ||||
|                     self.named_response_middleware[_rn].appendleft(middleware) | ||||
|         return middleware | ||||
|  | ||||
|     def _apply_exception_handler(self, handler: FutureException): | ||||
|     def _apply_exception_handler( | ||||
|         self, | ||||
|         handler: FutureException, | ||||
|         route_names: Optional[List[str]] = None, | ||||
|     ): | ||||
|         """Decorate a function to be registered as a handler for exceptions | ||||
|  | ||||
|         :param exceptions: exceptions | ||||
| @@ -318,9 +349,9 @@ class Sanic(BaseSanic): | ||||
|         for exception in handler.exceptions: | ||||
|             if isinstance(exception, (tuple, list)): | ||||
|                 for e in exception: | ||||
|                     self.error_handler.add(e, handler.handler) | ||||
|                     self.error_handler.add(e, handler.handler, route_names) | ||||
|             else: | ||||
|                 self.error_handler.add(exception, handler.handler) | ||||
|                 self.error_handler.add(exception, handler.handler, route_names) | ||||
|         return handler.handler | ||||
|  | ||||
|     def _apply_listener(self, listener: FutureListener): | ||||
| @@ -377,11 +408,17 @@ class Sanic(BaseSanic): | ||||
|         *, | ||||
|         condition: Optional[Dict[str, str]] = None, | ||||
|         context: Optional[Dict[str, Any]] = None, | ||||
|         fail_not_found: bool = True, | ||||
|         inline: bool = False, | ||||
|         reverse: bool = False, | ||||
|     ) -> Coroutine[Any, Any, Awaitable[Any]]: | ||||
|         return self.signal_router.dispatch( | ||||
|             event, | ||||
|             context=context, | ||||
|             condition=condition, | ||||
|             inline=inline, | ||||
|             reverse=reverse, | ||||
|             fail_not_found=fail_not_found, | ||||
|         ) | ||||
|  | ||||
|     async def event( | ||||
| @@ -411,7 +448,13 @@ class Sanic(BaseSanic): | ||||
|  | ||||
|         self.websocket_enabled = enable | ||||
|  | ||||
|     def blueprint(self, blueprint, **options): | ||||
|     def blueprint( | ||||
|         self, | ||||
|         blueprint: Union[ | ||||
|             Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup | ||||
|         ], | ||||
|         **options: Any, | ||||
|     ): | ||||
|         """Register a blueprint on the application. | ||||
|  | ||||
|         :param blueprint: Blueprint object or (list, tuple) thereof | ||||
| @@ -651,7 +694,7 @@ class Sanic(BaseSanic): | ||||
|  | ||||
|     async def handle_exception( | ||||
|         self, request: Request, exception: BaseException | ||||
|     ): | ||||
|     ):  # no cov | ||||
|         """ | ||||
|         A handler that catches specific exceptions and outputs a response. | ||||
|  | ||||
| @@ -661,6 +704,12 @@ class Sanic(BaseSanic): | ||||
|         :type exception: BaseException | ||||
|         :raises ServerError: response 500 | ||||
|         """ | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.exception", | ||||
|             inline=True, | ||||
|             context={"request": request, "exception": exception}, | ||||
|         ) | ||||
|  | ||||
|         # -------------------------------------------- # | ||||
|         # Request Middleware | ||||
|         # -------------------------------------------- # | ||||
| @@ -707,7 +756,7 @@ class Sanic(BaseSanic): | ||||
|                 f"Invalid response type {response!r} (need HTTPResponse)" | ||||
|             ) | ||||
|  | ||||
|     async def handle_request(self, request: Request): | ||||
|     async def handle_request(self, request: Request):  # no cov | ||||
|         """Take a request from the HTTP Server and return a response object | ||||
|         to be sent back The HTTP Server only expects a response object, so | ||||
|         exception handling must be done here | ||||
| @@ -715,10 +764,22 @@ class Sanic(BaseSanic): | ||||
|         :param request: HTTP Request object | ||||
|         :return: Nothing | ||||
|         """ | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.handle", | ||||
|             inline=True, | ||||
|             context={"request": request}, | ||||
|         ) | ||||
|  | ||||
|         # Define `response` var here to remove warnings about | ||||
|         # allocation before assignment below. | ||||
|         response = None | ||||
|         try: | ||||
|  | ||||
|             await self.dispatch( | ||||
|                 "http.routing.before", | ||||
|                 inline=True, | ||||
|                 context={"request": request}, | ||||
|             ) | ||||
|             # Fetch handler from router | ||||
|             route, handler, kwargs = self.router.get( | ||||
|                 request.path, | ||||
| @@ -726,19 +787,29 @@ class Sanic(BaseSanic): | ||||
|                 request.headers.getone("host", None), | ||||
|             ) | ||||
|  | ||||
|             request._match_info = kwargs | ||||
|             request._match_info = {**kwargs} | ||||
|             request.route = route | ||||
|  | ||||
|             await self.dispatch( | ||||
|                 "http.routing.after", | ||||
|                 inline=True, | ||||
|                 context={ | ||||
|                     "request": request, | ||||
|                     "route": route, | ||||
|                     "kwargs": kwargs, | ||||
|                     "handler": handler, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|             if ( | ||||
|                 request.stream.request_body  # type: ignore | ||||
|                 request.stream | ||||
|                 and request.stream.request_body | ||||
|                 and not route.ctx.ignore_body | ||||
|             ): | ||||
|  | ||||
|                 if hasattr(handler, "is_stream"): | ||||
|                     # Streaming handler: lift the size limit | ||||
|                     request.stream.request_max_size = float(  # type: ignore | ||||
|                         "inf" | ||||
|                     ) | ||||
|                     request.stream.request_max_size = float("inf") | ||||
|                 else: | ||||
|                     # Non-streaming handler: preload body | ||||
|                     await request.receive_body() | ||||
| @@ -765,17 +836,25 @@ class Sanic(BaseSanic): | ||||
|                     ) | ||||
|  | ||||
|                 # Run response handler | ||||
|                 response = handler(request, **kwargs) | ||||
|                 response = handler(request, **request.match_info) | ||||
|                 if isawaitable(response): | ||||
|                     response = await response | ||||
|  | ||||
|             if response: | ||||
|             if response is not None: | ||||
|                 response = await request.respond(response) | ||||
|             elif not hasattr(handler, "is_websocket"): | ||||
|                 response = request.stream.response  # type: ignore | ||||
|  | ||||
|             # Make sure that response is finished / run StreamingHTTP callback | ||||
|             if isinstance(response, BaseHTTPResponse): | ||||
|                 await self.dispatch( | ||||
|                     "http.lifecycle.response", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": response, | ||||
|                     }, | ||||
|                 ) | ||||
|                 await response.send(end_stream=True) | ||||
|             else: | ||||
|                 if not hasattr(handler, "is_websocket"): | ||||
| @@ -793,23 +872,11 @@ class Sanic(BaseSanic): | ||||
|     async def _websocket_handler( | ||||
|         self, handler, request, *args, subprotocols=None, **kwargs | ||||
|     ): | ||||
|         request.app = self | ||||
|         if not getattr(handler, "__blueprintname__", False): | ||||
|             request._name = handler.__name__ | ||||
|         else: | ||||
|             request._name = ( | ||||
|                 getattr(handler, "__blueprintname__", "") + handler.__name__ | ||||
|             ) | ||||
|  | ||||
|             pass | ||||
|  | ||||
|         if self.asgi: | ||||
|             ws = request.transport.get_websocket_connection() | ||||
|             await ws.accept(subprotocols) | ||||
|         else: | ||||
|             protocol = request.transport.get_protocol() | ||||
|             protocol.app = self | ||||
|  | ||||
|             ws = await protocol.websocket_handshake(request, subprotocols) | ||||
|  | ||||
|         # schedule the application handler | ||||
| @@ -817,13 +884,19 @@ class Sanic(BaseSanic): | ||||
|         # needs to be cancelled due to the server being stopped | ||||
|         fut = ensure_future(handler(request, ws, *args, **kwargs)) | ||||
|         self.websocket_tasks.add(fut) | ||||
|         cancelled = False | ||||
|         try: | ||||
|             await fut | ||||
|         except Exception as e: | ||||
|             self.error_handler.log(request, e) | ||||
|         except (CancelledError, ConnectionClosed): | ||||
|             pass | ||||
|             cancelled = True | ||||
|         finally: | ||||
|             self.websocket_tasks.remove(fut) | ||||
|             await ws.close() | ||||
|             if cancelled: | ||||
|                 ws.end_connection(1000) | ||||
|             else: | ||||
|                 await ws.close() | ||||
|  | ||||
|     # -------------------------------------------------------------------- # | ||||
|     # Testing | ||||
| @@ -869,7 +942,7 @@ class Sanic(BaseSanic): | ||||
|         *, | ||||
|         debug: bool = False, | ||||
|         auto_reload: Optional[bool] = None, | ||||
|         ssl: Union[dict, SSLContext, None] = None, | ||||
|         ssl: Union[Dict[str, str], SSLContext, None] = None, | ||||
|         sock: Optional[socket] = None, | ||||
|         workers: int = 1, | ||||
|         protocol: Optional[Type[Protocol]] = None, | ||||
| @@ -999,7 +1072,7 @@ class Sanic(BaseSanic): | ||||
|         port: Optional[int] = None, | ||||
|         *, | ||||
|         debug: bool = False, | ||||
|         ssl: Union[dict, SSLContext, None] = None, | ||||
|         ssl: Union[Dict[str, str], SSLContext, None] = None, | ||||
|         sock: Optional[socket] = None, | ||||
|         protocol: Type[Protocol] = None, | ||||
|         backlog: int = 100, | ||||
| @@ -1071,11 +1144,6 @@ class Sanic(BaseSanic): | ||||
|             run_async=return_asyncio_server, | ||||
|         ) | ||||
|  | ||||
|         # Trigger before_start events | ||||
|         await self.trigger_events( | ||||
|             server_settings.get("before_start", []), | ||||
|             server_settings.get("loop"), | ||||
|         ) | ||||
|         main_start = server_settings.pop("main_start", None) | ||||
|         main_stop = server_settings.pop("main_stop", None) | ||||
|         if main_start or main_stop: | ||||
| @@ -1088,17 +1156,9 @@ class Sanic(BaseSanic): | ||||
|             asyncio_server_kwargs=asyncio_server_kwargs, **server_settings | ||||
|         ) | ||||
|  | ||||
|     async def trigger_events(self, events, loop): | ||||
|         """Trigger events (functions or async) | ||||
|         :param events: one or more sync or async functions to execute | ||||
|         :param loop: event loop | ||||
|         """ | ||||
|         for event in events: | ||||
|             result = event(loop) | ||||
|             if isawaitable(result): | ||||
|                 await result | ||||
|  | ||||
|     async def _run_request_middleware(self, request, request_name=None): | ||||
|     async def _run_request_middleware( | ||||
|         self, request, request_name=None | ||||
|     ):  # no cov | ||||
|         # The if improves speed.  I don't know why | ||||
|         named_middleware = self.named_request_middleware.get( | ||||
|             request_name, deque() | ||||
| @@ -1111,25 +1171,67 @@ class Sanic(BaseSanic): | ||||
|             request.request_middleware_started = True | ||||
|  | ||||
|             for middleware in applicable_middleware: | ||||
|                 await self.dispatch( | ||||
|                     "http.middleware.before", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": None, | ||||
|                     }, | ||||
|                     condition={"attach_to": "request"}, | ||||
|                 ) | ||||
|  | ||||
|                 response = middleware(request) | ||||
|                 if isawaitable(response): | ||||
|                     response = await response | ||||
|  | ||||
|                 await self.dispatch( | ||||
|                     "http.middleware.after", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": None, | ||||
|                     }, | ||||
|                     condition={"attach_to": "request"}, | ||||
|                 ) | ||||
|  | ||||
|                 if response: | ||||
|                     return response | ||||
|         return None | ||||
|  | ||||
|     async def _run_response_middleware( | ||||
|         self, request, response, request_name=None | ||||
|     ): | ||||
|     ):  # no cov | ||||
|         named_middleware = self.named_response_middleware.get( | ||||
|             request_name, deque() | ||||
|         ) | ||||
|         applicable_middleware = self.response_middleware + named_middleware | ||||
|         if applicable_middleware: | ||||
|             for middleware in applicable_middleware: | ||||
|                 await self.dispatch( | ||||
|                     "http.middleware.before", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": response, | ||||
|                     }, | ||||
|                     condition={"attach_to": "response"}, | ||||
|                 ) | ||||
|  | ||||
|                 _response = middleware(request, response) | ||||
|                 if isawaitable(_response): | ||||
|                     _response = await _response | ||||
|  | ||||
|                 await self.dispatch( | ||||
|                     "http.middleware.after", | ||||
|                     inline=True, | ||||
|                     context={ | ||||
|                         "request": request, | ||||
|                         "response": _response if _response else response, | ||||
|                     }, | ||||
|                     condition={"attach_to": "response"}, | ||||
|                 ) | ||||
|  | ||||
|                 if _response: | ||||
|                     response = _response | ||||
|                     if isinstance(response, BaseHTTPResponse): | ||||
| @@ -1155,10 +1257,6 @@ class Sanic(BaseSanic): | ||||
|     ): | ||||
|         """Helper function used by `run` and `create_server`.""" | ||||
|  | ||||
|         self.listeners["before_server_start"] = [ | ||||
|             self.finalize | ||||
|         ] + self.listeners["before_server_start"] | ||||
|  | ||||
|         if isinstance(ssl, dict): | ||||
|             # try common aliaseses | ||||
|             cert = ssl.get("cert") or ssl.get("certificate") | ||||
| @@ -1195,10 +1293,6 @@ class Sanic(BaseSanic): | ||||
|         # Register start/stop events | ||||
|  | ||||
|         for event_name, settings_name, reverse in ( | ||||
|             ("before_server_start", "before_start", False), | ||||
|             ("after_server_start", "after_start", False), | ||||
|             ("before_server_stop", "before_stop", True), | ||||
|             ("after_server_stop", "after_stop", True), | ||||
|             ("main_process_start", "main_start", False), | ||||
|             ("main_process_stop", "main_stop", True), | ||||
|         ): | ||||
| @@ -1236,7 +1330,8 @@ class Sanic(BaseSanic): | ||||
|                 logger.info(f"Goin' Fast @ {proto}://{host}:{port}") | ||||
|  | ||||
|         debug_mode = "enabled" if self.debug else "disabled" | ||||
|         logger.debug("Sanic auto-reload: enabled") | ||||
|         reload_mode = "enabled" if auto_reload else "disabled" | ||||
|         logger.debug(f"Sanic auto-reload: {reload_mode}") | ||||
|         logger.debug(f"Sanic debug mode: {debug_mode}") | ||||
|  | ||||
|         return server_settings | ||||
| @@ -1246,20 +1341,44 @@ class Sanic(BaseSanic): | ||||
|         return ".".join(parts) | ||||
|  | ||||
|     @classmethod | ||||
|     def _loop_add_task(cls, task, app, loop): | ||||
|     def _prep_task(cls, task, app, loop): | ||||
|         if callable(task): | ||||
|             try: | ||||
|                 loop.create_task(task(app)) | ||||
|                 task = task(app) | ||||
|             except TypeError: | ||||
|                 loop.create_task(task()) | ||||
|         else: | ||||
|             loop.create_task(task) | ||||
|                 task = task() | ||||
|  | ||||
|         return task | ||||
|  | ||||
|     @classmethod | ||||
|     def _loop_add_task(cls, task, app, loop): | ||||
|         prepped = cls._prep_task(task, app, loop) | ||||
|         loop.create_task(prepped) | ||||
|  | ||||
|     @classmethod | ||||
|     def _cancel_websocket_tasks(cls, app, loop): | ||||
|         for task in app.websocket_tasks: | ||||
|             task.cancel() | ||||
|  | ||||
|     @staticmethod | ||||
|     async def dispatch_delayed_tasks(app, loop): | ||||
|         for name in app._delayed_tasks: | ||||
|             await app.dispatch(name, context={"app": app, "loop": loop}) | ||||
|         app._delayed_tasks.clear() | ||||
|  | ||||
|     @staticmethod | ||||
|     async def run_delayed_task(app, loop, task): | ||||
|         prepped = app._prep_task(task, app, loop) | ||||
|         await prepped | ||||
|  | ||||
|     @staticmethod | ||||
|     async def _listener( | ||||
|         app: Sanic, loop: AbstractEventLoop, listener: ListenerType | ||||
|     ): | ||||
|         maybe_coro = listener(app, loop) | ||||
|         if maybe_coro and isawaitable(maybe_coro): | ||||
|             await maybe_coro | ||||
|  | ||||
|     # -------------------------------------------------------------------- # | ||||
|     # ASGI | ||||
|     # -------------------------------------------------------------------- # | ||||
| @@ -1333,15 +1452,54 @@ class Sanic(BaseSanic): | ||||
|             raise SanicException(f'Sanic app name "{name}" not found.') | ||||
|  | ||||
|     # -------------------------------------------------------------------- # | ||||
|     # Static methods | ||||
|     # Lifecycle | ||||
|     # -------------------------------------------------------------------- # | ||||
|  | ||||
|     @staticmethod | ||||
|     async def finalize(app, _): | ||||
|     def finalize(self): | ||||
|         try: | ||||
|             app.router.finalize() | ||||
|             if app.signal_router.routes: | ||||
|                 app.signal_router.finalize()  # noqa | ||||
|             self.router.finalize() | ||||
|         except FinalizationError as e: | ||||
|             if not Sanic.test_mode: | ||||
|                 raise e  # noqa | ||||
|                 raise e | ||||
|  | ||||
|     def signalize(self): | ||||
|         try: | ||||
|             self.signal_router.finalize() | ||||
|         except FinalizationError as e: | ||||
|             if not Sanic.test_mode: | ||||
|                 raise e | ||||
|  | ||||
|     async def _startup(self): | ||||
|         self.signalize() | ||||
|         self.finalize() | ||||
|         ErrorHandler.finalize( | ||||
|             self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT | ||||
|         ) | ||||
|         TouchUp.run(self) | ||||
|  | ||||
|     async def _server_event( | ||||
|         self, | ||||
|         concern: str, | ||||
|         action: str, | ||||
|         loop: Optional[AbstractEventLoop] = None, | ||||
|     ) -> None: | ||||
|         event = f"server.{concern}.{action}" | ||||
|         if action not in ("before", "after") or concern not in ( | ||||
|             "init", | ||||
|             "shutdown", | ||||
|         ): | ||||
|             raise SanicException(f"Invalid server event: {event}") | ||||
|         logger.debug(f"Triggering server events: {event}") | ||||
|         reverse = concern == "shutdown" | ||||
|         if loop is None: | ||||
|             loop = self.loop | ||||
|         await self.dispatch( | ||||
|             event, | ||||
|             fail_not_found=False, | ||||
|             reverse=reverse, | ||||
|             inline=True, | ||||
|             context={ | ||||
|                 "app": self, | ||||
|                 "loop": loop, | ||||
|             }, | ||||
|         ) | ||||
|   | ||||
| @@ -1,6 +1,5 @@ | ||||
| import warnings | ||||
|  | ||||
| from inspect import isawaitable | ||||
| from typing import Optional | ||||
| from urllib.parse import quote | ||||
|  | ||||
| @@ -11,21 +10,27 @@ from sanic.exceptions import ServerError | ||||
| from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport | ||||
| from sanic.request import Request | ||||
| from sanic.server import ConnInfo | ||||
| from sanic.websocket import WebSocketConnection | ||||
| from sanic.server.websockets.connection import WebSocketConnection | ||||
|  | ||||
|  | ||||
| class Lifespan: | ||||
|     def __init__(self, asgi_app: "ASGIApp") -> None: | ||||
|         self.asgi_app = asgi_app | ||||
|  | ||||
|         if "before_server_start" in self.asgi_app.sanic_app.listeners: | ||||
|         if ( | ||||
|             "server.init.before" | ||||
|             in self.asgi_app.sanic_app.signal_router.name_index | ||||
|         ): | ||||
|             warnings.warn( | ||||
|                 'You have set a listener for "before_server_start" ' | ||||
|                 "in ASGI mode. " | ||||
|                 "It will be executed as early as possible, but not before " | ||||
|                 "the ASGI server is started." | ||||
|             ) | ||||
|         if "after_server_stop" in self.asgi_app.sanic_app.listeners: | ||||
|         if ( | ||||
|             "server.shutdown.after" | ||||
|             in self.asgi_app.sanic_app.signal_router.name_index | ||||
|         ): | ||||
|             warnings.warn( | ||||
|                 'You have set a listener for "after_server_stop" ' | ||||
|                 "in ASGI mode. " | ||||
| @@ -42,19 +47,9 @@ class Lifespan: | ||||
|         in sequence since the ASGI lifespan protocol only supports a single | ||||
|         startup event. | ||||
|         """ | ||||
|         self.asgi_app.sanic_app.router.finalize() | ||||
|         if self.asgi_app.sanic_app.signal_router.routes: | ||||
|             self.asgi_app.sanic_app.signal_router.finalize() | ||||
|         listeners = self.asgi_app.sanic_app.listeners.get( | ||||
|             "before_server_start", [] | ||||
|         ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) | ||||
|  | ||||
|         for handler in listeners: | ||||
|             response = handler( | ||||
|                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||
|             ) | ||||
|             if response and isawaitable(response): | ||||
|                 await response | ||||
|         await self.asgi_app.sanic_app._startup() | ||||
|         await self.asgi_app.sanic_app._server_event("init", "before") | ||||
|         await self.asgi_app.sanic_app._server_event("init", "after") | ||||
|  | ||||
|     async def shutdown(self) -> None: | ||||
|         """ | ||||
| @@ -65,16 +60,8 @@ class Lifespan: | ||||
|         in sequence since the ASGI lifespan protocol only supports a single | ||||
|         shutdown event. | ||||
|         """ | ||||
|         listeners = self.asgi_app.sanic_app.listeners.get( | ||||
|             "before_server_stop", [] | ||||
|         ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) | ||||
|  | ||||
|         for handler in listeners: | ||||
|             response = handler( | ||||
|                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop | ||||
|             ) | ||||
|             if response and isawaitable(response): | ||||
|                 await response | ||||
|         await self.asgi_app.sanic_app._server_event("shutdown", "before") | ||||
|         await self.asgi_app.sanic_app._server_event("shutdown", "after") | ||||
|  | ||||
|     async def __call__( | ||||
|         self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||
| @@ -207,4 +194,7 @@ class ASGIApp: | ||||
|         """ | ||||
|         Handle the incoming request. | ||||
|         """ | ||||
|         await self.sanic_app.handle_request(self.request) | ||||
|         try: | ||||
|             await self.sanic_app.handle_request(self.request) | ||||
|         except Exception as e: | ||||
|             await self.sanic_app.handle_exception(self.request, e) | ||||
|   | ||||
| @@ -58,7 +58,7 @@ class BaseSanic( | ||||
|         if name not in self.__fake_slots__: | ||||
|             warn( | ||||
|                 f"Setting variables on {self.__class__.__name__} instances is " | ||||
|                 "deprecated and will be removed in version 21.9. You should " | ||||
|                 "deprecated and will be removed in version 21.12. You should " | ||||
|                 f"change your {self.__class__.__name__} instance to use " | ||||
|                 f"instance.ctx.{name} instead.", | ||||
|                 DeprecationWarning, | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from collections.abc import MutableSequence | ||||
| from functools import partial | ||||
| from typing import TYPE_CHECKING, List, Optional, Union | ||||
|  | ||||
|  | ||||
| @@ -196,6 +197,27 @@ class BlueprintGroup(MutableSequence): | ||||
|         """ | ||||
|         self._blueprints.append(value) | ||||
|  | ||||
|     def exception(self, *exceptions, **kwargs): | ||||
|         """ | ||||
|         A decorator that can be used to implement a global exception handler | ||||
|         for all the Blueprints that belong to this Blueprint Group. | ||||
|  | ||||
|         In case of nested Blueprint Groups, the same handler is applied | ||||
|         across each of the Blueprints recursively. | ||||
|  | ||||
|         :param args: List of Python exceptions to be caught by the handler | ||||
|         :param kwargs: Additional optional arguments to be passed to the | ||||
|             exception handler | ||||
|         :return a decorated method to handle global exceptions for any | ||||
|             blueprint registered under this group. | ||||
|         """ | ||||
|  | ||||
|         def register_exception_handler_for_blueprints(fn): | ||||
|             for blueprint in self.blueprints: | ||||
|                 blueprint.exception(*exceptions, **kwargs)(fn) | ||||
|  | ||||
|         return register_exception_handler_for_blueprints | ||||
|  | ||||
|     def insert(self, index: int, item: Blueprint) -> None: | ||||
|         """ | ||||
|         The Abstract class `MutableSequence` leverages this insert method to | ||||
| @@ -229,3 +251,15 @@ class BlueprintGroup(MutableSequence): | ||||
|             args = list(args)[1:] | ||||
|             return register_middleware_for_blueprints(fn) | ||||
|         return register_middleware_for_blueprints | ||||
|  | ||||
|     def on_request(self, middleware=None): | ||||
|         if callable(middleware): | ||||
|             return self.middleware(middleware, "request") | ||||
|         else: | ||||
|             return partial(self.middleware, attach_to="request") | ||||
|  | ||||
|     def on_response(self, middleware=None): | ||||
|         if callable(middleware): | ||||
|             return self.middleware(middleware, "response") | ||||
|         else: | ||||
|             return partial(self.middleware, attach_to="response") | ||||
|   | ||||
| @@ -3,6 +3,7 @@ from __future__ import annotations | ||||
| import asyncio | ||||
|  | ||||
| from collections import defaultdict | ||||
| from copy import deepcopy | ||||
| from types import SimpleNamespace | ||||
| from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union | ||||
|  | ||||
| @@ -12,6 +13,7 @@ from sanic_routing.route import Route  # type: ignore | ||||
| from sanic.base import BaseSanic | ||||
| from sanic.blueprint_group import BlueprintGroup | ||||
| from sanic.exceptions import SanicException | ||||
| from sanic.helpers import Default, _default | ||||
| from sanic.models.futures import FutureRoute, FutureStatic | ||||
| from sanic.models.handler_types import ( | ||||
|     ListenerType, | ||||
| @@ -40,7 +42,7 @@ class Blueprint(BaseSanic): | ||||
|     :param host: IP Address of FQDN for the sanic server to use. | ||||
|     :param version: Blueprint Version | ||||
|     :param strict_slashes: Enforce the API urls are requested with a | ||||
|         training */* | ||||
|         trailing */* | ||||
|     """ | ||||
|  | ||||
|     __fake_slots__ = ( | ||||
| @@ -76,15 +78,9 @@ class Blueprint(BaseSanic): | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         super().__init__(name=name) | ||||
|  | ||||
|         self._apps: Set[Sanic] = set() | ||||
|         self.reset() | ||||
|         self.ctx = SimpleNamespace() | ||||
|         self.exceptions: List[RouteHandler] = [] | ||||
|         self.host = host | ||||
|         self.listeners: Dict[str, List[ListenerType]] = {} | ||||
|         self.middlewares: List[MiddlewareType] = [] | ||||
|         self.routes: List[Route] = [] | ||||
|         self.statics: List[RouteHandler] = [] | ||||
|         self.strict_slashes = strict_slashes | ||||
|         self.url_prefix = ( | ||||
|             url_prefix[:-1] | ||||
| @@ -93,7 +89,6 @@ class Blueprint(BaseSanic): | ||||
|         ) | ||||
|         self.version = version | ||||
|         self.version_prefix = version_prefix | ||||
|         self.websocket_routes: List[Route] = [] | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         args = ", ".join( | ||||
| @@ -144,12 +139,87 @@ class Blueprint(BaseSanic): | ||||
|         kwargs["apply"] = False | ||||
|         return super().signal(event, *args, **kwargs) | ||||
|  | ||||
|     def reset(self): | ||||
|         self._apps: Set[Sanic] = set() | ||||
|         self.exceptions: List[RouteHandler] = [] | ||||
|         self.listeners: Dict[str, List[ListenerType]] = {} | ||||
|         self.middlewares: List[MiddlewareType] = [] | ||||
|         self.routes: List[Route] = [] | ||||
|         self.statics: List[RouteHandler] = [] | ||||
|         self.websocket_routes: List[Route] = [] | ||||
|  | ||||
|     def copy( | ||||
|         self, | ||||
|         name: str, | ||||
|         url_prefix: Optional[Union[str, Default]] = _default, | ||||
|         version: Optional[Union[int, str, float, Default]] = _default, | ||||
|         version_prefix: Union[str, Default] = _default, | ||||
|         strict_slashes: Optional[Union[bool, Default]] = _default, | ||||
|         with_registration: bool = True, | ||||
|         with_ctx: bool = False, | ||||
|     ): | ||||
|         """ | ||||
|         Copy a blueprint instance with some optional parameters to | ||||
|         override the values of attributes in the old instance. | ||||
|  | ||||
|         :param name: unique name of the blueprint | ||||
|         :param url_prefix: URL to be prefixed before all route URLs | ||||
|         :param version: Blueprint Version | ||||
|         :param version_prefix: the prefix of the version number shown in the | ||||
|             URL. | ||||
|         :param strict_slashes: Enforce the API urls are requested with a | ||||
|             trailing */* | ||||
|         :param with_registration: whether register new blueprint instance with | ||||
|             sanic apps that were registered with the old instance or not. | ||||
|         :param with_ctx: whether ``ctx`` will be copied or not. | ||||
|         """ | ||||
|  | ||||
|         attrs_backup = { | ||||
|             "_apps": self._apps, | ||||
|             "routes": self.routes, | ||||
|             "websocket_routes": self.websocket_routes, | ||||
|             "middlewares": self.middlewares, | ||||
|             "exceptions": self.exceptions, | ||||
|             "listeners": self.listeners, | ||||
|             "statics": self.statics, | ||||
|         } | ||||
|  | ||||
|         self.reset() | ||||
|         new_bp = deepcopy(self) | ||||
|         new_bp.name = name | ||||
|  | ||||
|         if not isinstance(url_prefix, Default): | ||||
|             new_bp.url_prefix = url_prefix | ||||
|         if not isinstance(version, Default): | ||||
|             new_bp.version = version | ||||
|         if not isinstance(strict_slashes, Default): | ||||
|             new_bp.strict_slashes = strict_slashes | ||||
|         if not isinstance(version_prefix, Default): | ||||
|             new_bp.version_prefix = version_prefix | ||||
|  | ||||
|         for key, value in attrs_backup.items(): | ||||
|             setattr(self, key, value) | ||||
|  | ||||
|         if with_registration and self._apps: | ||||
|             if new_bp._future_statics: | ||||
|                 raise SanicException( | ||||
|                     "Static routes registered with the old blueprint instance," | ||||
|                     " cannot be registered again." | ||||
|                 ) | ||||
|             for app in self._apps: | ||||
|                 app.blueprint(new_bp) | ||||
|  | ||||
|         if not with_ctx: | ||||
|             new_bp.ctx = SimpleNamespace() | ||||
|  | ||||
|         return new_bp | ||||
|  | ||||
|     @staticmethod | ||||
|     def group( | ||||
|         *blueprints, | ||||
|         url_prefix="", | ||||
|         version=None, | ||||
|         strict_slashes=None, | ||||
|         *blueprints: Union[Blueprint, BlueprintGroup], | ||||
|         url_prefix: Optional[str] = None, | ||||
|         version: Optional[Union[int, str, float]] = None, | ||||
|         strict_slashes: Optional[bool] = None, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         """ | ||||
| @@ -196,6 +266,9 @@ class Blueprint(BaseSanic): | ||||
|         opt_version = options.get("version", None) | ||||
|         opt_strict_slashes = options.get("strict_slashes", None) | ||||
|         opt_version_prefix = options.get("version_prefix", self.version_prefix) | ||||
|         error_format = options.get( | ||||
|             "error_format", app.config.FALLBACK_ERROR_FORMAT | ||||
|         ) | ||||
|  | ||||
|         routes = [] | ||||
|         middleware = [] | ||||
| @@ -243,6 +316,7 @@ class Blueprint(BaseSanic): | ||||
|                 future.unquote, | ||||
|                 future.static, | ||||
|                 version_prefix, | ||||
|                 error_format, | ||||
|             ) | ||||
|  | ||||
|             route = app._apply_route(apply_route) | ||||
| @@ -261,19 +335,22 @@ class Blueprint(BaseSanic): | ||||
|  | ||||
|         route_names = [route.name for route in routes if route] | ||||
|  | ||||
|         # Middleware | ||||
|         if route_names: | ||||
|             # Middleware | ||||
|             for future in self._future_middleware: | ||||
|                 middleware.append(app._apply_middleware(future, route_names)) | ||||
|  | ||||
|         # Exceptions | ||||
|         for future in self._future_exceptions: | ||||
|             exception_handlers.append(app._apply_exception_handler(future)) | ||||
|             # Exceptions | ||||
|             for future in self._future_exceptions: | ||||
|                 exception_handlers.append( | ||||
|                     app._apply_exception_handler(future, route_names) | ||||
|                 ) | ||||
|  | ||||
|         # Event listeners | ||||
|         for listener in self._future_listeners: | ||||
|             listeners[listener.event].append(app._apply_listener(listener)) | ||||
|  | ||||
|         # Signals | ||||
|         for signal in self._future_signals: | ||||
|             signal.condition.update({"blueprint": self.name}) | ||||
|             app._apply_signal(signal) | ||||
|   | ||||
| @@ -1,14 +1,21 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from inspect import isclass | ||||
| from os import environ | ||||
| from pathlib import Path | ||||
| from typing import Any, Dict, Optional, Union | ||||
| from typing import TYPE_CHECKING, Any, Dict, Optional, Union | ||||
| from warnings import warn | ||||
|  | ||||
| from sanic.errorpages import check_error_format | ||||
| from sanic.http import Http | ||||
|  | ||||
| from .utils import load_module_from_file_location, str_to_bool | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING:  # no cov | ||||
|     from sanic import Sanic | ||||
|  | ||||
|  | ||||
| SANIC_PREFIX = "SANIC_" | ||||
| BASE_LOGO = """ | ||||
|  | ||||
| @@ -20,7 +27,7 @@ BASE_LOGO = """ | ||||
| DEFAULT_CONFIG = { | ||||
|     "ACCESS_LOG": True, | ||||
|     "EVENT_AUTOREGISTER": False, | ||||
|     "FALLBACK_ERROR_FORMAT": "html", | ||||
|     "FALLBACK_ERROR_FORMAT": "auto", | ||||
|     "FORWARDED_FOR_HEADER": "X-Forwarded-For", | ||||
|     "FORWARDED_SECRET": None, | ||||
|     "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0,  # 15 sec | ||||
| @@ -35,12 +42,9 @@ DEFAULT_CONFIG = { | ||||
|     "REQUEST_MAX_SIZE": 100000000,  # 100 megabytes | ||||
|     "REQUEST_TIMEOUT": 60,  # 60 seconds | ||||
|     "RESPONSE_TIMEOUT": 60,  # 60 seconds | ||||
|     "WEBSOCKET_MAX_QUEUE": 32, | ||||
|     "WEBSOCKET_MAX_SIZE": 2 ** 20,  # 1 megabyte | ||||
|     "WEBSOCKET_PING_INTERVAL": 20, | ||||
|     "WEBSOCKET_PING_TIMEOUT": 20, | ||||
|     "WEBSOCKET_READ_LIMIT": 2 ** 16, | ||||
|     "WEBSOCKET_WRITE_LIMIT": 2 ** 16, | ||||
| } | ||||
|  | ||||
|  | ||||
| @@ -62,12 +66,10 @@ class Config(dict): | ||||
|     REQUEST_MAX_SIZE: int | ||||
|     REQUEST_TIMEOUT: int | ||||
|     RESPONSE_TIMEOUT: int | ||||
|     WEBSOCKET_MAX_QUEUE: int | ||||
|     SERVER_NAME: str | ||||
|     WEBSOCKET_MAX_SIZE: int | ||||
|     WEBSOCKET_PING_INTERVAL: int | ||||
|     WEBSOCKET_PING_TIMEOUT: int | ||||
|     WEBSOCKET_READ_LIMIT: int | ||||
|     WEBSOCKET_WRITE_LIMIT: int | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
| @@ -75,11 +77,14 @@ class Config(dict): | ||||
|         load_env: Optional[Union[bool, str]] = True, | ||||
|         env_prefix: Optional[str] = SANIC_PREFIX, | ||||
|         keep_alive: Optional[bool] = None, | ||||
|         *, | ||||
|         app: Optional[Sanic] = None, | ||||
|     ): | ||||
|         defaults = defaults or {} | ||||
|         super().__init__({**DEFAULT_CONFIG, **defaults}) | ||||
|  | ||||
|         self.LOGO = BASE_LOGO | ||||
|         self._app = app | ||||
|         self._LOGO = "" | ||||
|  | ||||
|         if keep_alive is not None: | ||||
|             self.KEEP_ALIVE = keep_alive | ||||
| @@ -100,6 +105,8 @@ class Config(dict): | ||||
|             self.load_environment_vars(SANIC_PREFIX) | ||||
|  | ||||
|         self._configure_header_size() | ||||
|         self._check_error_format() | ||||
|         self._init = True | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|         try: | ||||
| @@ -107,14 +114,51 @@ class Config(dict): | ||||
|         except KeyError as ke: | ||||
|             raise AttributeError(f"Config has no '{ke.args[0]}'") | ||||
|  | ||||
|     def __setattr__(self, attr, value): | ||||
|         self[attr] = value | ||||
|         if attr in ( | ||||
|             "REQUEST_MAX_HEADER_SIZE", | ||||
|             "REQUEST_BUFFER_SIZE", | ||||
|             "REQUEST_MAX_SIZE", | ||||
|         ): | ||||
|             self._configure_header_size() | ||||
|     def __setattr__(self, attr, value) -> None: | ||||
|         self.update({attr: value}) | ||||
|  | ||||
|     def __setitem__(self, attr, value) -> None: | ||||
|         self.update({attr: value}) | ||||
|  | ||||
|     def update(self, *other, **kwargs) -> None: | ||||
|         other_mapping = {k: v for item in other for k, v in dict(item).items()} | ||||
|         super().update(*other, **kwargs) | ||||
|         for attr, value in {**other_mapping, **kwargs}.items(): | ||||
|             self._post_set(attr, value) | ||||
|  | ||||
|     def _post_set(self, attr, value) -> None: | ||||
|         if self.get("_init"): | ||||
|             if attr in ( | ||||
|                 "REQUEST_MAX_HEADER_SIZE", | ||||
|                 "REQUEST_BUFFER_SIZE", | ||||
|                 "REQUEST_MAX_SIZE", | ||||
|             ): | ||||
|                 self._configure_header_size() | ||||
|             elif attr == "FALLBACK_ERROR_FORMAT": | ||||
|                 self._check_error_format() | ||||
|                 if self.app and value != self.app.error_handler.fallback: | ||||
|                     if self.app.error_handler.fallback != "auto": | ||||
|                         warn( | ||||
|                             "Overriding non-default ErrorHandler fallback " | ||||
|                             "value. Changing from " | ||||
|                             f"{self.app.error_handler.fallback} to {value}." | ||||
|                         ) | ||||
|                     self.app.error_handler.fallback = value | ||||
|             elif attr == "LOGO": | ||||
|                 self._LOGO = value | ||||
|                 warn( | ||||
|                     "Setting the config.LOGO is deprecated and will no longer " | ||||
|                     "be supported starting in v22.6.", | ||||
|                     DeprecationWarning, | ||||
|                 ) | ||||
|  | ||||
|     @property | ||||
|     def app(self): | ||||
|         return self._app | ||||
|  | ||||
|     @property | ||||
|     def LOGO(self): | ||||
|         return self._LOGO | ||||
|  | ||||
|     def _configure_header_size(self): | ||||
|         Http.set_header_max_size( | ||||
| @@ -123,6 +167,9 @@ class Config(dict): | ||||
|             self.REQUEST_MAX_SIZE, | ||||
|         ) | ||||
|  | ||||
|     def _check_error_format(self): | ||||
|         check_error_format(self.FALLBACK_ERROR_FORMAT) | ||||
|  | ||||
|     def load_environment_vars(self, prefix=SANIC_PREFIX): | ||||
|         """ | ||||
|         Looks for prefixed environment variables and applies | ||||
|   | ||||
| @@ -340,41 +340,140 @@ RENDERERS_BY_CONFIG = { | ||||
| } | ||||
|  | ||||
| RENDERERS_BY_CONTENT_TYPE = { | ||||
|     "multipart/form-data": HTMLRenderer, | ||||
|     "application/json": JSONRenderer, | ||||
|     "text/plain": TextRenderer, | ||||
|     "application/json": JSONRenderer, | ||||
|     "multipart/form-data": HTMLRenderer, | ||||
|     "text/html": HTMLRenderer, | ||||
| } | ||||
| CONTENT_TYPE_BY_RENDERERS = { | ||||
|     v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items() | ||||
| } | ||||
|  | ||||
| RESPONSE_MAPPING = { | ||||
|     "empty": "html", | ||||
|     "json": "json", | ||||
|     "text": "text", | ||||
|     "raw": "text", | ||||
|     "html": "html", | ||||
|     "file": "html", | ||||
|     "file_stream": "text", | ||||
|     "stream": "text", | ||||
|     "redirect": "html", | ||||
|     "text/plain": "text", | ||||
|     "text/html": "html", | ||||
|     "application/json": "json", | ||||
| } | ||||
|  | ||||
|  | ||||
| def check_error_format(format): | ||||
|     if format not in RENDERERS_BY_CONFIG and format != "auto": | ||||
|         raise SanicException(f"Unknown format: {format}") | ||||
|  | ||||
|  | ||||
| def exception_response( | ||||
|     request: Request, | ||||
|     exception: Exception, | ||||
|     debug: bool, | ||||
|     fallback: str, | ||||
|     base: t.Type[BaseRenderer], | ||||
|     renderer: t.Type[t.Optional[BaseRenderer]] = None, | ||||
| ) -> HTTPResponse: | ||||
|     """ | ||||
|     Render a response for the default FALLBACK exception handler. | ||||
|     """ | ||||
|     content_type = None | ||||
|  | ||||
|     print("exception_response", fallback) | ||||
|     if not renderer: | ||||
|         renderer = HTMLRenderer | ||||
|         # Make sure we have something set | ||||
|         renderer = base | ||||
|         render_format = fallback | ||||
|  | ||||
|         if request: | ||||
|             if request.app.config.FALLBACK_ERROR_FORMAT == "auto": | ||||
|             # If there is a request, try and get the format | ||||
|             # from the route | ||||
|             if request.route: | ||||
|                 try: | ||||
|                     renderer = JSONRenderer if request.json else HTMLRenderer | ||||
|                 except InvalidUsage: | ||||
|                     if request.route.ctx.error_format: | ||||
|                         render_format = request.route.ctx.error_format | ||||
|                 except AttributeError: | ||||
|                     ... | ||||
|  | ||||
|             content_type = request.headers.getone("content-type", "").split( | ||||
|                 ";" | ||||
|             )[0] | ||||
|  | ||||
|             acceptable = request.accept | ||||
|  | ||||
|             # If the format is auto still, make a guess | ||||
|             if render_format == "auto": | ||||
|                 # First, if there is an Accept header, check if text/html | ||||
|                 # is the first option | ||||
|                 # According to MDN Web Docs, all major browsers use text/html | ||||
|                 # as the primary value in Accept (with the exception of IE 8, | ||||
|                 # and, well, if you are supporting IE 8, then you have bigger | ||||
|                 # problems to concern yourself with than what default exception | ||||
|                 # renderer is used) | ||||
|                 # Source: | ||||
|                 # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values | ||||
|  | ||||
|                 if acceptable and acceptable[0].match( | ||||
|                     "text/html", | ||||
|                     allow_type_wildcard=False, | ||||
|                     allow_subtype_wildcard=False, | ||||
|                 ): | ||||
|                     renderer = HTMLRenderer | ||||
|  | ||||
|                 content_type, *_ = request.headers.getone( | ||||
|                     "content-type", "" | ||||
|                 ).split(";") | ||||
|                 renderer = RENDERERS_BY_CONTENT_TYPE.get( | ||||
|                     content_type, renderer | ||||
|                 ) | ||||
|                 # Second, if there is an Accept header, check if | ||||
|                 # application/json is an option, or if the content-type | ||||
|                 # is application/json | ||||
|                 elif ( | ||||
|                     acceptable | ||||
|                     and acceptable.match( | ||||
|                         "application/json", | ||||
|                         allow_type_wildcard=False, | ||||
|                         allow_subtype_wildcard=False, | ||||
|                     ) | ||||
|                     or content_type == "application/json" | ||||
|                 ): | ||||
|                     renderer = JSONRenderer | ||||
|  | ||||
|                 # Third, if there is no Accept header, assume we want text. | ||||
|                 # The likely use case here is a raw socket. | ||||
|                 elif not acceptable: | ||||
|                     renderer = TextRenderer | ||||
|                 else: | ||||
|                     # Fourth, look to see if there was a JSON body | ||||
|                     # When in this situation, the request is probably coming | ||||
|                     # from curl, an API client like Postman or Insomnia, or a | ||||
|                     # package like requests or httpx | ||||
|                     try: | ||||
|                         # Give them the benefit of the doubt if they did: | ||||
|                         # $ curl localhost:8000 -d '{"foo": "bar"}' | ||||
|                         # And provide them with JSONRenderer | ||||
|                         renderer = JSONRenderer if request.json else base | ||||
|                     except InvalidUsage: | ||||
|                         renderer = base | ||||
|             else: | ||||
|                 render_format = request.app.config.FALLBACK_ERROR_FORMAT | ||||
|                 renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) | ||||
|  | ||||
|             # Lastly, if there is an Accept header, make sure | ||||
|             # our choice is okay | ||||
|             if acceptable: | ||||
|                 type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer)  # type: ignore | ||||
|                 if type_ and type_ not in acceptable: | ||||
|                     # If the renderer selected is not in the Accept header | ||||
|                     # look through what is in the Accept header, and select | ||||
|                     # the first option that matches. Otherwise, just drop back | ||||
|                     # to the original default | ||||
|                     for accept in acceptable: | ||||
|                         mtype = f"{accept.type_}/{accept.subtype}" | ||||
|                         maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) | ||||
|                         if maybe: | ||||
|                             renderer = maybe | ||||
|                             break | ||||
|                     else: | ||||
|                         renderer = base | ||||
|  | ||||
|     renderer = t.cast(t.Type[BaseRenderer], renderer) | ||||
|     return renderer(request, exception, debug).render() | ||||
|   | ||||
| @@ -4,16 +4,20 @@ from sanic.helpers import STATUS_CODES | ||||
|  | ||||
|  | ||||
| class SanicException(Exception): | ||||
|     message: str = "" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         message: Optional[Union[str, bytes]] = None, | ||||
|         status_code: Optional[int] = None, | ||||
|         quiet: Optional[bool] = None, | ||||
|     ) -> None: | ||||
|  | ||||
|         if message is None and status_code is not None: | ||||
|             msg: bytes = STATUS_CODES.get(status_code, b"") | ||||
|             message = msg.decode("utf8") | ||||
|         if message is None: | ||||
|             if self.message: | ||||
|                 message = self.message | ||||
|             elif status_code is not None: | ||||
|                 msg: bytes = STATUS_CODES.get(status_code, b"") | ||||
|                 message = msg.decode("utf8") | ||||
|  | ||||
|         super().__init__(message) | ||||
|  | ||||
| @@ -122,8 +126,11 @@ class HeaderNotFound(InvalidUsage): | ||||
|     **Status**: 400 Bad Request | ||||
|     """ | ||||
|  | ||||
|     status_code = 400 | ||||
|     quiet = True | ||||
|  | ||||
| class InvalidHeader(InvalidUsage): | ||||
|     """ | ||||
|     **Status**: 400 Bad Request | ||||
|     """ | ||||
|  | ||||
|  | ||||
| class ContentRangeError(SanicException): | ||||
| @@ -230,6 +237,11 @@ class InvalidSignal(SanicException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class WebsocketClosed(SanicException): | ||||
|     quiet = True | ||||
|     message = "Client has closed the websocket connection" | ||||
|  | ||||
|  | ||||
| def abort(status_code: int, message: Optional[Union[str, bytes]] = None): | ||||
|     """ | ||||
|     Raise an exception based on SanicException. Returns the HTTP response | ||||
|   | ||||
| @@ -1,12 +1,14 @@ | ||||
| from traceback import format_exc | ||||
| from inspect import signature | ||||
| from typing import Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| from sanic.errorpages import exception_response | ||||
| from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response | ||||
| from sanic.exceptions import ( | ||||
|     ContentRangeError, | ||||
|     HeaderNotFound, | ||||
|     InvalidRangeType, | ||||
| ) | ||||
| from sanic.log import error_logger | ||||
| from sanic.models.handler_types import RouteHandler | ||||
| from sanic.response import text | ||||
|  | ||||
|  | ||||
| @@ -23,15 +25,54 @@ class ErrorHandler: | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     handlers = None | ||||
|     cached_handlers = None | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.handlers = [] | ||||
|         self.cached_handlers = {} | ||||
|     # Beginning in v22.3, the base renderer will be TextRenderer | ||||
|     def __init__( | ||||
|         self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer | ||||
|     ): | ||||
|         self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] | ||||
|         self.cached_handlers: Dict[ | ||||
|             Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] | ||||
|         ] = {} | ||||
|         self.debug = False | ||||
|         self.fallback = fallback | ||||
|         self.base = base | ||||
|  | ||||
|     def add(self, exception, handler): | ||||
|     @classmethod | ||||
|     def finalize(cls, error_handler, fallback: Optional[str] = None): | ||||
|         if ( | ||||
|             fallback | ||||
|             and fallback != "auto" | ||||
|             and error_handler.fallback == "auto" | ||||
|         ): | ||||
|             error_handler.fallback = fallback | ||||
|  | ||||
|         if not isinstance(error_handler, cls): | ||||
|             error_logger.warning( | ||||
|                 f"Error handler is non-conforming: {type(error_handler)}" | ||||
|             ) | ||||
|  | ||||
|         sig = signature(error_handler.lookup) | ||||
|         if len(sig.parameters) == 1: | ||||
|             error_logger.warning( | ||||
|                 DeprecationWarning( | ||||
|                     "You are using a deprecated error handler. The lookup " | ||||
|                     "method should accept two positional parameters: " | ||||
|                     "(exception, route_name: Optional[str]). " | ||||
|                     "Until you upgrade your ErrorHandler.lookup, Blueprint " | ||||
|                     "specific exceptions will not work properly. Beginning " | ||||
|                     "in v22.3, the legacy style lookup method will not " | ||||
|                     "work at all." | ||||
|                 ), | ||||
|             ) | ||||
|             error_handler._lookup = error_handler._legacy_lookup | ||||
|  | ||||
|     def _full_lookup(self, exception, route_name: Optional[str] = None): | ||||
|         return self.lookup(exception, route_name) | ||||
|  | ||||
|     def _legacy_lookup(self, exception, route_name: Optional[str] = None): | ||||
|         return self.lookup(exception) | ||||
|  | ||||
|     def add(self, exception, handler, route_names: Optional[List[str]] = None): | ||||
|         """ | ||||
|         Add a new exception handler to an already existing handler object. | ||||
|  | ||||
| @@ -44,11 +85,16 @@ class ErrorHandler: | ||||
|  | ||||
|         :return: None | ||||
|         """ | ||||
|         # self.handlers to be deprecated and removed in version 21.12 | ||||
|         # self.handlers is deprecated and will be removed in version 22.3 | ||||
|         self.handlers.append((exception, handler)) | ||||
|         self.cached_handlers[exception] = handler | ||||
|  | ||||
|     def lookup(self, exception): | ||||
|         if route_names: | ||||
|             for route in route_names: | ||||
|                 self.cached_handlers[(exception, route)] = handler | ||||
|         else: | ||||
|             self.cached_handlers[(exception, None)] = handler | ||||
|  | ||||
|     def lookup(self, exception, route_name: Optional[str] = None): | ||||
|         """ | ||||
|         Lookup the existing instance of :class:`ErrorHandler` and fetch the | ||||
|         registered handler for a specific type of exception. | ||||
| @@ -63,20 +109,31 @@ class ErrorHandler: | ||||
|         :return: Registered function if found ``None`` otherwise | ||||
|         """ | ||||
|         exception_class = type(exception) | ||||
|         if exception_class in self.cached_handlers: | ||||
|             return self.cached_handlers[exception_class] | ||||
|  | ||||
|         for ancestor in type.mro(exception_class): | ||||
|             if ancestor in self.cached_handlers: | ||||
|                 handler = self.cached_handlers[ancestor] | ||||
|                 self.cached_handlers[exception_class] = handler | ||||
|         for name in (route_name, None): | ||||
|             exception_key = (exception_class, name) | ||||
|             handler = self.cached_handlers.get(exception_key) | ||||
|             if handler: | ||||
|                 return handler | ||||
|             if ancestor is BaseException: | ||||
|                 break | ||||
|         self.cached_handlers[exception_class] = None | ||||
|  | ||||
|         for name in (route_name, None): | ||||
|             for ancestor in type.mro(exception_class): | ||||
|                 exception_key = (ancestor, name) | ||||
|                 if exception_key in self.cached_handlers: | ||||
|                     handler = self.cached_handlers[exception_key] | ||||
|                     self.cached_handlers[ | ||||
|                         (exception_class, route_name) | ||||
|                     ] = handler | ||||
|                     return handler | ||||
|  | ||||
|                 if ancestor is BaseException: | ||||
|                     break | ||||
|         self.cached_handlers[(exception_class, route_name)] = None | ||||
|         handler = None | ||||
|         return handler | ||||
|  | ||||
|     _lookup = _full_lookup | ||||
|  | ||||
|     def response(self, request, exception): | ||||
|         """Fetches and executes an exception handler and returns a response | ||||
|         object | ||||
| @@ -91,7 +148,8 @@ class ErrorHandler: | ||||
|         :return: Wrap the return value obtained from :func:`default` | ||||
|             or registered handler for that type of exception. | ||||
|         """ | ||||
|         handler = self.lookup(exception) | ||||
|         route_name = request.name if request else None | ||||
|         handler = self._lookup(exception, route_name) | ||||
|         response = None | ||||
|         try: | ||||
|             if handler: | ||||
| @@ -99,7 +157,6 @@ class ErrorHandler: | ||||
|             if response is None: | ||||
|                 response = self.default(request, exception) | ||||
|         except Exception: | ||||
|             self.log(format_exc()) | ||||
|             try: | ||||
|                 url = repr(request.url) | ||||
|             except AttributeError: | ||||
| @@ -115,11 +172,6 @@ class ErrorHandler: | ||||
|                 return text("An error occurred while handling an error", 500) | ||||
|         return response | ||||
|  | ||||
|     def log(self, message, level="error"): | ||||
|         """ | ||||
|         Deprecated, do not use. | ||||
|         """ | ||||
|  | ||||
|     def default(self, request, exception): | ||||
|         """ | ||||
|         Provide a default behavior for the objects of :class:`ErrorHandler`. | ||||
| @@ -135,6 +187,17 @@ class ErrorHandler: | ||||
|             :class:`Exception` | ||||
|         :return: | ||||
|         """ | ||||
|         self.log(request, exception) | ||||
|         return exception_response( | ||||
|             request, | ||||
|             exception, | ||||
|             debug=self.debug, | ||||
|             base=self.base, | ||||
|             fallback=self.fallback, | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def log(request, exception): | ||||
|         quiet = getattr(exception, "quiet", False) | ||||
|         if quiet is False: | ||||
|             try: | ||||
| @@ -142,13 +205,10 @@ class ErrorHandler: | ||||
|             except AttributeError: | ||||
|                 url = "unknown" | ||||
|  | ||||
|             self.log(format_exc()) | ||||
|             error_logger.exception( | ||||
|                 "Exception occurred while handling uri: %s", url | ||||
|             ) | ||||
|  | ||||
|         return exception_response(request, exception, self.debug) | ||||
|  | ||||
|  | ||||
| class ContentRangeHandler: | ||||
|     """ | ||||
|   | ||||
							
								
								
									
										200
									
								
								sanic/headers.py
									
									
									
									
									
								
							
							
						
						
									
										200
									
								
								sanic/headers.py
									
									
									
									
									
								
							| @@ -1,8 +1,11 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import re | ||||
|  | ||||
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | ||||
| from urllib.parse import unquote | ||||
|  | ||||
| from sanic.exceptions import InvalidHeader | ||||
| from sanic.helpers import STATUS_CODES | ||||
|  | ||||
|  | ||||
| @@ -30,6 +33,175 @@ _host_re = re.compile( | ||||
| # For more information, consult ../tests/test_requests.py | ||||
|  | ||||
|  | ||||
| def parse_arg_as_accept(f): | ||||
|     def func(self, other, *args, **kwargs): | ||||
|         if not isinstance(other, Accept) and other: | ||||
|             other = Accept.parse(other) | ||||
|         return f(self, other, *args, **kwargs) | ||||
|  | ||||
|     return func | ||||
|  | ||||
|  | ||||
| class MediaType(str): | ||||
|     def __new__(cls, value: str): | ||||
|         return str.__new__(cls, value) | ||||
|  | ||||
|     def __init__(self, value: str) -> None: | ||||
|         self.value = value | ||||
|         self.is_wildcard = self.check_if_wildcard(value) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if self.is_wildcard: | ||||
|             return True | ||||
|  | ||||
|         if self.match(other): | ||||
|             return True | ||||
|  | ||||
|         other_is_wildcard = ( | ||||
|             other.is_wildcard | ||||
|             if isinstance(other, MediaType) | ||||
|             else self.check_if_wildcard(other) | ||||
|         ) | ||||
|  | ||||
|         return other_is_wildcard | ||||
|  | ||||
|     def match(self, other): | ||||
|         other_value = other.value if isinstance(other, MediaType) else other | ||||
|         return self.value == other_value | ||||
|  | ||||
|     @staticmethod | ||||
|     def check_if_wildcard(value): | ||||
|         return value == "*" | ||||
|  | ||||
|  | ||||
| class Accept(str): | ||||
|     def __new__(cls, value: str, *args, **kwargs): | ||||
|         return str.__new__(cls, value) | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         value: str, | ||||
|         type_: MediaType, | ||||
|         subtype: MediaType, | ||||
|         *, | ||||
|         q: str = "1.0", | ||||
|         **kwargs: str, | ||||
|     ): | ||||
|         qvalue = float(q) | ||||
|         if qvalue > 1 or qvalue < 0: | ||||
|             raise InvalidHeader( | ||||
|                 f"Accept header qvalue must be between 0 and 1, not: {qvalue}" | ||||
|             ) | ||||
|         self.value = value | ||||
|         self.type_ = type_ | ||||
|         self.subtype = subtype | ||||
|         self.qvalue = qvalue | ||||
|         self.params = kwargs | ||||
|  | ||||
|     def _compare(self, other, method): | ||||
|         try: | ||||
|             return method(self.qvalue, other.qvalue) | ||||
|         except (AttributeError, TypeError): | ||||
|             return NotImplemented | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def __lt__(self, other: Union[str, Accept]): | ||||
|         return self._compare(other, lambda s, o: s < o) | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def __le__(self, other: Union[str, Accept]): | ||||
|         return self._compare(other, lambda s, o: s <= o) | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def __eq__(self, other: Union[str, Accept]):  # type: ignore | ||||
|         return self._compare(other, lambda s, o: s == o) | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def __ge__(self, other: Union[str, Accept]): | ||||
|         return self._compare(other, lambda s, o: s >= o) | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def __gt__(self, other: Union[str, Accept]): | ||||
|         return self._compare(other, lambda s, o: s > o) | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def __ne__(self, other: Union[str, Accept]):  # type: ignore | ||||
|         return self._compare(other, lambda s, o: s != o) | ||||
|  | ||||
|     @parse_arg_as_accept | ||||
|     def match( | ||||
|         self, | ||||
|         other, | ||||
|         *, | ||||
|         allow_type_wildcard: bool = True, | ||||
|         allow_subtype_wildcard: bool = True, | ||||
|     ) -> bool: | ||||
|         type_match = ( | ||||
|             self.type_ == other.type_ | ||||
|             if allow_type_wildcard | ||||
|             else ( | ||||
|                 self.type_.match(other.type_) | ||||
|                 and not self.type_.is_wildcard | ||||
|                 and not other.type_.is_wildcard | ||||
|             ) | ||||
|         ) | ||||
|         subtype_match = ( | ||||
|             self.subtype == other.subtype | ||||
|             if allow_subtype_wildcard | ||||
|             else ( | ||||
|                 self.subtype.match(other.subtype) | ||||
|                 and not self.subtype.is_wildcard | ||||
|                 and not other.subtype.is_wildcard | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         return type_match and subtype_match | ||||
|  | ||||
|     @classmethod | ||||
|     def parse(cls, raw: str) -> Accept: | ||||
|         invalid = False | ||||
|         mtype = raw.strip() | ||||
|  | ||||
|         try: | ||||
|             media, *raw_params = mtype.split(";") | ||||
|             type_, subtype = media.split("/") | ||||
|         except ValueError: | ||||
|             invalid = True | ||||
|  | ||||
|         if invalid or not type_ or not subtype: | ||||
|             raise InvalidHeader(f"Header contains invalid Accept value: {raw}") | ||||
|  | ||||
|         params = dict( | ||||
|             [ | ||||
|                 (key.strip(), value.strip()) | ||||
|                 for key, value in (param.split("=", 1) for param in raw_params) | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         return cls(mtype, MediaType(type_), MediaType(subtype), **params) | ||||
|  | ||||
|  | ||||
| class AcceptContainer(list): | ||||
|     def __contains__(self, o: object) -> bool: | ||||
|         return any(item.match(o) for item in self) | ||||
|  | ||||
|     def match( | ||||
|         self, | ||||
|         o: object, | ||||
|         *, | ||||
|         allow_type_wildcard: bool = True, | ||||
|         allow_subtype_wildcard: bool = True, | ||||
|     ) -> bool: | ||||
|         return any( | ||||
|             item.match( | ||||
|                 o, | ||||
|                 allow_type_wildcard=allow_type_wildcard, | ||||
|                 allow_subtype_wildcard=allow_subtype_wildcard, | ||||
|             ) | ||||
|             for item in self | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def parse_content_header(value: str) -> Tuple[str, Options]: | ||||
|     """Parse content-type and content-disposition header values. | ||||
|  | ||||
| @@ -194,3 +366,31 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: | ||||
|         ret += b"%b: %b\r\n" % h | ||||
|     ret += b"\r\n" | ||||
|     return ret | ||||
|  | ||||
|  | ||||
| def _sort_accept_value(accept: Accept): | ||||
|     return ( | ||||
|         accept.qvalue, | ||||
|         len(accept.params), | ||||
|         accept.subtype != "*", | ||||
|         accept.type_ != "*", | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def parse_accept(accept: str) -> AcceptContainer: | ||||
|     """Parse an Accept header and order the acceptable media types in | ||||
|     accorsing to RFC 7231, s. 5.3.2 | ||||
|     https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 | ||||
|     """ | ||||
|     media_types = accept.split(",") | ||||
|     accept_list: List[Accept] = [] | ||||
|  | ||||
|     for mtype in media_types: | ||||
|         if not mtype: | ||||
|             continue | ||||
|  | ||||
|         accept_list.append(Accept.parse(mtype)) | ||||
|  | ||||
|     return AcceptContainer( | ||||
|         sorted(accept_list, key=_sort_accept_value, reverse=True) | ||||
|     ) | ||||
|   | ||||
| @@ -155,3 +155,17 @@ def import_string(module_name, package=None): | ||||
|     if ismodule(obj): | ||||
|         return obj | ||||
|     return obj() | ||||
|  | ||||
|  | ||||
| class Default: | ||||
|     """ | ||||
|     It is used to replace `None` or `object()` as a sentinel | ||||
|     that represents a default value. Sometimes we want to set | ||||
|     a value to `None` so we cannot use `None` to represent the | ||||
|     default value, and `object()` is hard to be typed. | ||||
|     """ | ||||
|  | ||||
|     pass | ||||
|  | ||||
|  | ||||
| _default = Default() | ||||
|   | ||||
| @@ -21,6 +21,7 @@ from sanic.exceptions import ( | ||||
| from sanic.headers import format_http1_response | ||||
| from sanic.helpers import has_message_body | ||||
| from sanic.log import access_logger, error_logger, logger | ||||
| from sanic.touchup import TouchUpMeta | ||||
|  | ||||
|  | ||||
| class Stage(Enum): | ||||
| @@ -45,7 +46,7 @@ class Stage(Enum): | ||||
| HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" | ||||
|  | ||||
|  | ||||
| class Http: | ||||
| class Http(metaclass=TouchUpMeta): | ||||
|     """ | ||||
|     Internal helper for managing the HTTP request/response cycle | ||||
|  | ||||
| @@ -67,9 +68,15 @@ class Http: | ||||
|     HEADER_CEILING = 16_384 | ||||
|     HEADER_MAX_SIZE = 0 | ||||
|  | ||||
|     __touchup__ = ( | ||||
|         "http1_request_header", | ||||
|         "http1_response_header", | ||||
|         "read", | ||||
|     ) | ||||
|     __slots__ = [ | ||||
|         "_send", | ||||
|         "_receive_more", | ||||
|         "dispatch", | ||||
|         "recv_buffer", | ||||
|         "protocol", | ||||
|         "expecting_continue", | ||||
| @@ -97,7 +104,7 @@ class Http: | ||||
|         self.protocol = protocol | ||||
|         self.keep_alive = True | ||||
|         self.stage: Stage = Stage.IDLE | ||||
|         self.init_for_request() | ||||
|         self.dispatch = self.protocol.app.dispatch | ||||
|  | ||||
|     def init_for_request(self): | ||||
|         """Init/reset all per-request variables.""" | ||||
| @@ -121,14 +128,20 @@ class Http: | ||||
|         """ | ||||
|         HTTP 1.1 connection handler | ||||
|         """ | ||||
|         while True:  # As long as connection stays keep-alive | ||||
|         # Handle requests while the connection stays reusable | ||||
|         while self.keep_alive and self.stage is Stage.IDLE: | ||||
|             self.init_for_request() | ||||
|             # Wait for incoming bytes (in IDLE stage) | ||||
|             if not self.recv_buffer: | ||||
|                 await self._receive_more() | ||||
|             self.stage = Stage.REQUEST | ||||
|             try: | ||||
|                 # Receive and handle a request | ||||
|                 self.stage = Stage.REQUEST | ||||
|                 self.response_func = self.http1_response_header | ||||
|  | ||||
|                 await self.http1_request_header() | ||||
|  | ||||
|                 self.stage = Stage.HANDLER | ||||
|                 self.request.conn_info = self.protocol.conn_info | ||||
|                 await self.protocol.request_handler(self.request) | ||||
|  | ||||
| @@ -140,6 +153,12 @@ class Http: | ||||
|                     await self.response.send(end_stream=True) | ||||
|             except CancelledError: | ||||
|                 # Write an appropriate response before exiting | ||||
|                 if not self.protocol.transport: | ||||
|                     logger.info( | ||||
|                         f"Request: {self.request.method} {self.request.url} " | ||||
|                         "stopped. Transport is closed." | ||||
|                     ) | ||||
|                     return | ||||
|                 e = self.exception or ServiceUnavailable("Cancelled") | ||||
|                 self.exception = None | ||||
|                 self.keep_alive = False | ||||
| @@ -173,17 +192,7 @@ class Http: | ||||
|                 if self.response: | ||||
|                     self.response.stream = None | ||||
|  | ||||
|             self.init_for_request() | ||||
|  | ||||
|             # Exit and disconnect if no more requests can be taken | ||||
|             if self.stage is not Stage.IDLE or not self.keep_alive: | ||||
|                 break | ||||
|  | ||||
|             # Wait for the next request | ||||
|             if not self.recv_buffer: | ||||
|                 await self._receive_more() | ||||
|  | ||||
|     async def http1_request_header(self): | ||||
|     async def http1_request_header(self):  # no cov | ||||
|         """ | ||||
|         Receive and parse request header into self.request. | ||||
|         """ | ||||
| @@ -212,6 +221,12 @@ class Http: | ||||
|             reqline, *split_headers = raw_headers.split("\r\n") | ||||
|             method, self.url, protocol = reqline.split(" ") | ||||
|  | ||||
|             await self.dispatch( | ||||
|                 "http.lifecycle.read_head", | ||||
|                 inline=True, | ||||
|                 context={"head": bytes(head)}, | ||||
|             ) | ||||
|  | ||||
|             if protocol == "HTTP/1.1": | ||||
|                 self.keep_alive = True | ||||
|             elif protocol == "HTTP/1.0": | ||||
| @@ -250,6 +265,11 @@ class Http: | ||||
|             transport=self.protocol.transport, | ||||
|             app=self.protocol.app, | ||||
|         ) | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.request", | ||||
|             inline=True, | ||||
|             context={"request": request}, | ||||
|         ) | ||||
|  | ||||
|         # Prepare for request body | ||||
|         self.request_bytes_left = self.request_bytes = 0 | ||||
| @@ -274,13 +294,12 @@ class Http: | ||||
|  | ||||
|         # Remove header and its trailing CRLF | ||||
|         del buf[: pos + 4] | ||||
|         self.stage = Stage.HANDLER | ||||
|         self.request, request.stream = request, self | ||||
|         self.protocol.state["requests_count"] += 1 | ||||
|  | ||||
|     async def http1_response_header( | ||||
|         self, data: bytes, end_stream: bool | ||||
|     ) -> None: | ||||
|     ) -> None:  # no cov | ||||
|         res = self.response | ||||
|  | ||||
|         # Compatibility with simple response body | ||||
| @@ -452,8 +471,8 @@ class Http: | ||||
|             "request": "nil", | ||||
|         } | ||||
|         if req is not None: | ||||
|             if req.ip: | ||||
|                 extra["host"] = f"{req.ip}:{req.port}" | ||||
|             if req.remote_addr or req.ip: | ||||
|                 extra["host"] = f"{req.remote_addr or req.ip}:{req.port}" | ||||
|             extra["request"] = f"{req.method} {req.url}" | ||||
|         access_logger.info("", extra=extra) | ||||
|  | ||||
| @@ -469,7 +488,7 @@ class Http: | ||||
|             if data: | ||||
|                 yield data | ||||
|  | ||||
|     async def read(self) -> Optional[bytes]: | ||||
|     async def read(self) -> Optional[bytes]:  # no cov | ||||
|         """ | ||||
|         Read some bytes of request body. | ||||
|         """ | ||||
| @@ -543,6 +562,12 @@ class Http: | ||||
|  | ||||
|         self.request_bytes_left -= size | ||||
|  | ||||
|         await self.dispatch( | ||||
|             "http.lifecycle.read_body", | ||||
|             inline=True, | ||||
|             context={"body": data}, | ||||
|         ) | ||||
|  | ||||
|         return data | ||||
|  | ||||
|     # Response methods | ||||
|   | ||||
| @@ -1,18 +1,19 @@ | ||||
| from enum import Enum, auto | ||||
| from functools import partial | ||||
| from typing import Any, Callable, Coroutine, List, Optional, Union | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
| from sanic.models.futures import FutureListener | ||||
| from sanic.models.handler_types import ListenerType | ||||
|  | ||||
|  | ||||
| class ListenerEvent(str, Enum): | ||||
|     def _generate_next_value_(name: str, *args) -> str:  # type: ignore | ||||
|         return name.lower() | ||||
|  | ||||
|     BEFORE_SERVER_START = auto() | ||||
|     AFTER_SERVER_START = auto() | ||||
|     BEFORE_SERVER_STOP = auto() | ||||
|     AFTER_SERVER_STOP = auto() | ||||
|     BEFORE_SERVER_START = "server.init.before" | ||||
|     AFTER_SERVER_START = "server.init.after" | ||||
|     BEFORE_SERVER_STOP = "server.shutdown.before" | ||||
|     AFTER_SERVER_STOP = "server.shutdown.after" | ||||
|     MAIN_PROCESS_START = auto() | ||||
|     MAIN_PROCESS_STOP = auto() | ||||
|  | ||||
| @@ -26,9 +27,7 @@ class ListenerMixin: | ||||
|  | ||||
|     def listener( | ||||
|         self, | ||||
|         listener_or_event: Union[ | ||||
|             Callable[..., Coroutine[Any, Any, None]], str | ||||
|         ], | ||||
|         listener_or_event: Union[ListenerType, str], | ||||
|         event_or_none: Optional[str] = None, | ||||
|         apply: bool = True, | ||||
|     ): | ||||
| @@ -63,20 +62,20 @@ class ListenerMixin: | ||||
|         else: | ||||
|             return partial(register_listener, event=listener_or_event) | ||||
|  | ||||
|     def main_process_start(self, listener): | ||||
|     def main_process_start(self, listener: ListenerType) -> ListenerType: | ||||
|         return self.listener(listener, "main_process_start") | ||||
|  | ||||
|     def main_process_stop(self, listener): | ||||
|     def main_process_stop(self, listener: ListenerType) -> ListenerType: | ||||
|         return self.listener(listener, "main_process_stop") | ||||
|  | ||||
|     def before_server_start(self, listener): | ||||
|     def before_server_start(self, listener: ListenerType) -> ListenerType: | ||||
|         return self.listener(listener, "before_server_start") | ||||
|  | ||||
|     def after_server_start(self, listener): | ||||
|     def after_server_start(self, listener: ListenerType) -> ListenerType: | ||||
|         return self.listener(listener, "after_server_start") | ||||
|  | ||||
|     def before_server_stop(self, listener): | ||||
|     def before_server_stop(self, listener: ListenerType) -> ListenerType: | ||||
|         return self.listener(listener, "before_server_stop") | ||||
|  | ||||
|     def after_server_stop(self, listener): | ||||
|     def after_server_stop(self, listener: ListenerType) -> ListenerType: | ||||
|         return self.listener(listener, "after_server_stop") | ||||
|   | ||||
| @@ -1,17 +1,20 @@ | ||||
| from ast import NodeVisitor, Return, parse | ||||
| from functools import partial, wraps | ||||
| from inspect import signature | ||||
| from inspect import getsource, signature | ||||
| from mimetypes import guess_type | ||||
| from os import path | ||||
| from pathlib import PurePath | ||||
| from re import sub | ||||
| from textwrap import dedent | ||||
| from time import gmtime, strftime | ||||
| from typing import Iterable, List, Optional, Set, Union | ||||
| from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union | ||||
| from urllib.parse import unquote | ||||
|  | ||||
| from sanic_routing.route import Route  # type: ignore | ||||
|  | ||||
| from sanic.compat import stat_async | ||||
| from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS | ||||
| from sanic.errorpages import RESPONSE_MAPPING | ||||
| from sanic.exceptions import ( | ||||
|     ContentRangeError, | ||||
|     FileNotFound, | ||||
| @@ -21,10 +24,16 @@ from sanic.exceptions import ( | ||||
| from sanic.handlers import ContentRangeHandler | ||||
| from sanic.log import error_logger | ||||
| from sanic.models.futures import FutureRoute, FutureStatic | ||||
| from sanic.models.handler_types import RouteHandler | ||||
| from sanic.response import HTTPResponse, file, file_stream | ||||
| from sanic.views import CompositionView | ||||
|  | ||||
|  | ||||
| RouteWrapper = Callable[ | ||||
|     [RouteHandler], Union[RouteHandler, Tuple[Route, RouteHandler]] | ||||
| ] | ||||
|  | ||||
|  | ||||
| class RouteMixin: | ||||
|     name: str | ||||
|  | ||||
| @@ -55,7 +64,8 @@ class RouteMixin: | ||||
|         unquote: bool = False, | ||||
|         static: bool = False, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Decorate a function to be registered as a route | ||||
|  | ||||
| @@ -97,6 +107,7 @@ class RouteMixin: | ||||
|             nonlocal websocket | ||||
|             nonlocal static | ||||
|             nonlocal version_prefix | ||||
|             nonlocal error_format | ||||
|  | ||||
|             if isinstance(handler, tuple): | ||||
|                 # if a handler fn is already wrapped in a route, the handler | ||||
| @@ -115,10 +126,16 @@ class RouteMixin: | ||||
|                         "Expected either string or Iterable of host strings, " | ||||
|                         "not %s" % host | ||||
|                     ) | ||||
|  | ||||
|             if isinstance(subprotocols, (list, tuple, set)): | ||||
|             if isinstance(subprotocols, list): | ||||
|                 # Ordered subprotocols, maintain order | ||||
|                 subprotocols = tuple(subprotocols) | ||||
|             elif isinstance(subprotocols, set): | ||||
|                 # subprotocol is unordered, keep it unordered | ||||
|                 subprotocols = frozenset(subprotocols) | ||||
|  | ||||
|             if not error_format or error_format == "auto": | ||||
|                 error_format = self._determine_error_format(handler) | ||||
|  | ||||
|             route = FutureRoute( | ||||
|                 handler, | ||||
|                 uri, | ||||
| @@ -134,6 +151,7 @@ class RouteMixin: | ||||
|                 unquote, | ||||
|                 static, | ||||
|                 version_prefix, | ||||
|                 error_format, | ||||
|             ) | ||||
|  | ||||
|             self._future_routes.add(route) | ||||
| @@ -168,7 +186,7 @@ class RouteMixin: | ||||
|  | ||||
|     def add_route( | ||||
|         self, | ||||
|         handler, | ||||
|         handler: RouteHandler, | ||||
|         uri: str, | ||||
|         methods: Iterable[str] = frozenset({"GET"}), | ||||
|         host: Optional[str] = None, | ||||
| @@ -177,7 +195,8 @@ class RouteMixin: | ||||
|         name: Optional[str] = None, | ||||
|         stream: bool = False, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteHandler: | ||||
|         """A helper method to register class instance or | ||||
|         functions as a handler to the application url | ||||
|         routes. | ||||
| @@ -200,7 +219,8 @@ class RouteMixin: | ||||
|             methods = set() | ||||
|  | ||||
|             for method in HTTP_METHODS: | ||||
|                 _handler = getattr(handler.view_class, method.lower(), None) | ||||
|                 view_class = getattr(handler, "view_class") | ||||
|                 _handler = getattr(view_class, method.lower(), None) | ||||
|                 if _handler: | ||||
|                     methods.add(method) | ||||
|                     if hasattr(_handler, "is_stream"): | ||||
| @@ -226,6 +246,7 @@ class RouteMixin: | ||||
|             version=version, | ||||
|             name=name, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         )(handler) | ||||
|         return handler | ||||
|  | ||||
| @@ -239,7 +260,8 @@ class RouteMixin: | ||||
|         name: Optional[str] = None, | ||||
|         ignore_body: bool = True, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **GET** *HTTP* method | ||||
|  | ||||
| @@ -262,6 +284,7 @@ class RouteMixin: | ||||
|             name=name, | ||||
|             ignore_body=ignore_body, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def post( | ||||
| @@ -273,7 +296,8 @@ class RouteMixin: | ||||
|         version: Optional[int] = None, | ||||
|         name: Optional[str] = None, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **POST** *HTTP* method | ||||
|  | ||||
| @@ -296,6 +320,7 @@ class RouteMixin: | ||||
|             version=version, | ||||
|             name=name, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def put( | ||||
| @@ -307,7 +332,8 @@ class RouteMixin: | ||||
|         version: Optional[int] = None, | ||||
|         name: Optional[str] = None, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **PUT** *HTTP* method | ||||
|  | ||||
| @@ -330,6 +356,7 @@ class RouteMixin: | ||||
|             version=version, | ||||
|             name=name, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def head( | ||||
| @@ -341,7 +368,8 @@ class RouteMixin: | ||||
|         name: Optional[str] = None, | ||||
|         ignore_body: bool = True, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **HEAD** *HTTP* method | ||||
|  | ||||
| @@ -372,6 +400,7 @@ class RouteMixin: | ||||
|             name=name, | ||||
|             ignore_body=ignore_body, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def options( | ||||
| @@ -383,7 +412,8 @@ class RouteMixin: | ||||
|         name: Optional[str] = None, | ||||
|         ignore_body: bool = True, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **OPTIONS** *HTTP* method | ||||
|  | ||||
| @@ -414,6 +444,7 @@ class RouteMixin: | ||||
|             name=name, | ||||
|             ignore_body=ignore_body, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def patch( | ||||
| @@ -425,7 +456,8 @@ class RouteMixin: | ||||
|         version: Optional[int] = None, | ||||
|         name: Optional[str] = None, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **PATCH** *HTTP* method | ||||
|  | ||||
| @@ -458,6 +490,7 @@ class RouteMixin: | ||||
|             version=version, | ||||
|             name=name, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def delete( | ||||
| @@ -469,7 +502,8 @@ class RouteMixin: | ||||
|         name: Optional[str] = None, | ||||
|         ignore_body: bool = True, | ||||
|         version_prefix: str = "/v", | ||||
|     ): | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> RouteWrapper: | ||||
|         """ | ||||
|         Add an API URL under the **DELETE** *HTTP* method | ||||
|  | ||||
| @@ -492,6 +526,7 @@ class RouteMixin: | ||||
|             name=name, | ||||
|             ignore_body=ignore_body, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def websocket( | ||||
| @@ -504,6 +539,7 @@ class RouteMixin: | ||||
|         name: Optional[str] = None, | ||||
|         apply: bool = True, | ||||
|         version_prefix: str = "/v", | ||||
|         error_format: Optional[str] = None, | ||||
|     ): | ||||
|         """ | ||||
|         Decorate a function to be registered as a websocket route | ||||
| @@ -530,6 +566,7 @@ class RouteMixin: | ||||
|             subprotocols=subprotocols, | ||||
|             websocket=True, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         ) | ||||
|  | ||||
|     def add_websocket_route( | ||||
| @@ -542,6 +579,7 @@ class RouteMixin: | ||||
|         version: Optional[int] = None, | ||||
|         name: Optional[str] = None, | ||||
|         version_prefix: str = "/v", | ||||
|         error_format: Optional[str] = None, | ||||
|     ): | ||||
|         """ | ||||
|         A helper method to register a function as a websocket route. | ||||
| @@ -570,6 +608,7 @@ class RouteMixin: | ||||
|             version=version, | ||||
|             name=name, | ||||
|             version_prefix=version_prefix, | ||||
|             error_format=error_format, | ||||
|         )(handler) | ||||
|  | ||||
|     def static( | ||||
| @@ -585,6 +624,7 @@ class RouteMixin: | ||||
|         strict_slashes=None, | ||||
|         content_type=None, | ||||
|         apply=True, | ||||
|         resource_type=None, | ||||
|     ): | ||||
|         """ | ||||
|         Register a root to serve files from. The input can either be a | ||||
| @@ -634,6 +674,7 @@ class RouteMixin: | ||||
|             host, | ||||
|             strict_slashes, | ||||
|             content_type, | ||||
|             resource_type, | ||||
|         ) | ||||
|         self._future_statics.add(static) | ||||
|  | ||||
| @@ -777,10 +818,11 @@ class RouteMixin: | ||||
|             ) | ||||
|         except Exception: | ||||
|             error_logger.exception( | ||||
|                 f"Exception in static request handler:\ | ||||
|  path={file_or_directory}, " | ||||
|                 f"Exception in static request handler: " | ||||
|                 f"path={file_or_directory}, " | ||||
|                 f"relative_url={__file_uri__}" | ||||
|             ) | ||||
|             raise | ||||
|  | ||||
|     def _register_static( | ||||
|         self, | ||||
| @@ -828,8 +870,27 @@ class RouteMixin: | ||||
|         name = static.name | ||||
|         # If we're not trying to match a file directly, | ||||
|         # serve from the folder | ||||
|         if not path.isfile(file_or_directory): | ||||
|         if not static.resource_type: | ||||
|             if not path.isfile(file_or_directory): | ||||
|                 uri += "/<__file_uri__:path>" | ||||
|         elif static.resource_type == "dir": | ||||
|             if path.isfile(file_or_directory): | ||||
|                 raise TypeError( | ||||
|                     "Resource type improperly identified as directory. " | ||||
|                     f"'{file_or_directory}'" | ||||
|                 ) | ||||
|             uri += "/<__file_uri__:path>" | ||||
|         elif static.resource_type == "file" and not path.isfile( | ||||
|             file_or_directory | ||||
|         ): | ||||
|             raise TypeError( | ||||
|                 "Resource type improperly identified as file. " | ||||
|                 f"'{file_or_directory}'" | ||||
|             ) | ||||
|         elif static.resource_type != "file": | ||||
|             raise ValueError( | ||||
|                 "The resource_type should be set to 'file' or 'dir'" | ||||
|             ) | ||||
|  | ||||
|         # special prefix for static files | ||||
|         # if not static.name.startswith("_static_"): | ||||
| @@ -846,7 +907,7 @@ class RouteMixin: | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         route, _ = self.route( | ||||
|         route, _ = self.route(  # type: ignore | ||||
|             uri=uri, | ||||
|             methods=["GET", "HEAD"], | ||||
|             name=name, | ||||
| @@ -856,3 +917,43 @@ class RouteMixin: | ||||
|         )(_handler) | ||||
|  | ||||
|         return route | ||||
|  | ||||
|     def _determine_error_format(self, handler) -> Optional[str]: | ||||
|         if not isinstance(handler, CompositionView): | ||||
|             try: | ||||
|                 src = dedent(getsource(handler)) | ||||
|                 tree = parse(src) | ||||
|                 http_response_types = self._get_response_types(tree) | ||||
|  | ||||
|                 if len(http_response_types) == 1: | ||||
|                     return next(iter(http_response_types)) | ||||
|             except (OSError, TypeError): | ||||
|                 ... | ||||
|  | ||||
|         return None | ||||
|  | ||||
|     def _get_response_types(self, node): | ||||
|         types = set() | ||||
|  | ||||
|         class HttpResponseVisitor(NodeVisitor): | ||||
|             def visit_Return(self, node: Return) -> Any: | ||||
|                 nonlocal types | ||||
|  | ||||
|                 try: | ||||
|                     checks = [node.value.func.id]  # type: ignore | ||||
|                     if node.value.keywords:  # type: ignore | ||||
|                         checks += [ | ||||
|                             k.value | ||||
|                             for k in node.value.keywords  # type: ignore | ||||
|                             if k.arg == "content_type" | ||||
|                         ] | ||||
|  | ||||
|                     for check in checks: | ||||
|                         if check in RESPONSE_MAPPING: | ||||
|                             types.add(RESPONSE_MAPPING[check]) | ||||
|                 except AttributeError: | ||||
|                     ... | ||||
|  | ||||
|         HttpResponseVisitor().visit(node) | ||||
|  | ||||
|         return types | ||||
|   | ||||
| @@ -23,7 +23,7 @@ class SignalMixin: | ||||
|         *, | ||||
|         apply: bool = True, | ||||
|         condition: Dict[str, Any] = None, | ||||
|     ) -> Callable[[SignalHandler], FutureSignal]: | ||||
|     ) -> Callable[[SignalHandler], SignalHandler]: | ||||
|         """ | ||||
|         For creating a signal handler, used similar to a route handler: | ||||
|  | ||||
| @@ -54,7 +54,7 @@ class SignalMixin: | ||||
|             if apply: | ||||
|                 self._apply_signal(future_signal) | ||||
|  | ||||
|             return future_signal | ||||
|             return handler | ||||
|  | ||||
|         return decorator | ||||
|  | ||||
|   | ||||
| @@ -3,7 +3,7 @@ import asyncio | ||||
| from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union | ||||
|  | ||||
| from sanic.exceptions import InvalidUsage | ||||
| from sanic.websocket import WebSocketConnection | ||||
| from sanic.server.websockets.connection import WebSocketConnection | ||||
|  | ||||
|  | ||||
| ASGIScope = MutableMapping[str, Any] | ||||
|   | ||||
| @@ -24,6 +24,7 @@ class FutureRoute(NamedTuple): | ||||
|     unquote: bool | ||||
|     static: bool | ||||
|     version_prefix: str | ||||
|     error_format: Optional[str] | ||||
|  | ||||
|  | ||||
| class FutureListener(NamedTuple): | ||||
| @@ -52,6 +53,7 @@ class FutureStatic(NamedTuple): | ||||
|     host: Optional[str] | ||||
|     strict_slashes: Optional[bool] | ||||
|     content_type: Optional[bool] | ||||
|     resource_type: Optional[str] | ||||
|  | ||||
|  | ||||
| class FutureSignal(NamedTuple): | ||||
|   | ||||
| @@ -21,5 +21,5 @@ MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType] | ||||
| ListenerType = Callable[ | ||||
|     [Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]] | ||||
| ] | ||||
| RouteHandler = Callable[..., Coroutine[Any, Any, HTTPResponse]] | ||||
| RouteHandler = Callable[..., Coroutine[Any, Any, Optional[HTTPResponse]]] | ||||
| SignalHandler = Callable[..., Coroutine[Any, Any, None]] | ||||
|   | ||||
							
								
								
									
										52
									
								
								sanic/models/server_types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								sanic/models/server_types.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | ||||
| from types import SimpleNamespace | ||||
|  | ||||
| from sanic.models.protocol_types import TransportProtocol | ||||
|  | ||||
|  | ||||
| class Signal: | ||||
|     stopped = False | ||||
|  | ||||
|  | ||||
| class ConnInfo: | ||||
|     """ | ||||
|     Local and remote addresses and SSL status info. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ( | ||||
|         "client_port", | ||||
|         "client", | ||||
|         "client_ip", | ||||
|         "ctx", | ||||
|         "peername", | ||||
|         "server_port", | ||||
|         "server", | ||||
|         "sockname", | ||||
|         "ssl", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, transport: TransportProtocol, unix=None): | ||||
|         self.ctx = SimpleNamespace() | ||||
|         self.peername = None | ||||
|         self.server = self.client = "" | ||||
|         self.server_port = self.client_port = 0 | ||||
|         self.client_ip = "" | ||||
|         self.sockname = addr = transport.get_extra_info("sockname") | ||||
|         self.ssl: bool = bool(transport.get_extra_info("sslcontext")) | ||||
|  | ||||
|         if isinstance(addr, str):  # UNIX socket | ||||
|             self.server = unix or addr | ||||
|             return | ||||
|  | ||||
|         # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) | ||||
|         if isinstance(addr, tuple): | ||||
|             self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" | ||||
|             self.server_port = addr[1] | ||||
|             # self.server gets non-standard port appended | ||||
|             if addr[1] != (443 if self.ssl else 80): | ||||
|                 self.server = f"{self.server}:{addr[1]}" | ||||
|         self.peername = addr = transport.get_extra_info("peername") | ||||
|  | ||||
|         if isinstance(addr, tuple): | ||||
|             self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" | ||||
|             self.client_ip = addr[0] | ||||
|             self.client_port = addr[1] | ||||
| @@ -34,7 +34,9 @@ from sanic.compat import CancelledErrors, Header | ||||
| from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE | ||||
| from sanic.exceptions import InvalidUsage | ||||
| from sanic.headers import ( | ||||
|     AcceptContainer, | ||||
|     Options, | ||||
|     parse_accept, | ||||
|     parse_content_header, | ||||
|     parse_forwarded, | ||||
|     parse_host, | ||||
| @@ -94,6 +96,7 @@ class Request: | ||||
|         "head", | ||||
|         "headers", | ||||
|         "method", | ||||
|         "parsed_accept", | ||||
|         "parsed_args", | ||||
|         "parsed_not_grouped_args", | ||||
|         "parsed_files", | ||||
| @@ -136,6 +139,7 @@ class Request: | ||||
|         self.conn_info: Optional[ConnInfo] = None | ||||
|         self.ctx = SimpleNamespace() | ||||
|         self.parsed_forwarded: Optional[Options] = None | ||||
|         self.parsed_accept: Optional[AcceptContainer] = None | ||||
|         self.parsed_json = None | ||||
|         self.parsed_form = None | ||||
|         self.parsed_files = None | ||||
| @@ -296,6 +300,13 @@ class Request: | ||||
|  | ||||
|         return self.parsed_json | ||||
|  | ||||
|     @property | ||||
|     def accept(self) -> AcceptContainer: | ||||
|         if self.parsed_accept is None: | ||||
|             accept_header = self.headers.getone("accept", "") | ||||
|             self.parsed_accept = parse_accept(accept_header) | ||||
|         return self.parsed_accept | ||||
|  | ||||
|     @property | ||||
|     def token(self): | ||||
|         """Attempt to return the auth header token. | ||||
| @@ -497,6 +508,10 @@ class Request: | ||||
|         """ | ||||
|         return self._match_info | ||||
|  | ||||
|     @match_info.setter | ||||
|     def match_info(self, value): | ||||
|         self._match_info = value | ||||
|  | ||||
|     # Transport properties (obtained from local interface only) | ||||
|  | ||||
|     @property | ||||
|   | ||||
| @@ -1,5 +1,9 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from functools import lru_cache | ||||
| from inspect import signature | ||||
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | ||||
| from uuid import UUID | ||||
|  | ||||
| from sanic_routing import BaseRouter  # type: ignore | ||||
| from sanic_routing.exceptions import NoMethod  # type: ignore | ||||
| @@ -9,6 +13,7 @@ from sanic_routing.exceptions import ( | ||||
| from sanic_routing.route import Route  # type: ignore | ||||
|  | ||||
| from sanic.constants import HTTP_METHODS | ||||
| from sanic.errorpages import check_error_format | ||||
| from sanic.exceptions import MethodNotSupported, NotFound, SanicException | ||||
| from sanic.models.handler_types import RouteHandler | ||||
|  | ||||
| @@ -74,6 +79,7 @@ class Router(BaseRouter): | ||||
|         unquote: bool = False, | ||||
|         static: bool = False, | ||||
|         version_prefix: str = "/v", | ||||
|         error_format: Optional[str] = None, | ||||
|     ) -> Union[Route, List[Route]]: | ||||
|         """ | ||||
|         Add a handler to the router | ||||
| @@ -106,6 +112,8 @@ class Router(BaseRouter): | ||||
|             version = str(version).strip("/").lstrip("v") | ||||
|             uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")]) | ||||
|  | ||||
|         uri = self._normalize(uri, handler) | ||||
|  | ||||
|         params = dict( | ||||
|             path=uri, | ||||
|             handler=handler, | ||||
| @@ -131,6 +139,10 @@ class Router(BaseRouter): | ||||
|             route.ctx.stream = stream | ||||
|             route.ctx.hosts = hosts | ||||
|             route.ctx.static = static | ||||
|             route.ctx.error_format = error_format | ||||
|  | ||||
|             if error_format: | ||||
|                 check_error_format(route.ctx.error_format) | ||||
|  | ||||
|             routes.append(route) | ||||
|  | ||||
| @@ -187,3 +199,24 @@ class Router(BaseRouter): | ||||
|                 raise SanicException( | ||||
|                     f"Invalid route: {route}. Parameter names cannot use '__'." | ||||
|                 ) | ||||
|  | ||||
|     def _normalize(self, uri: str, handler: RouteHandler) -> str: | ||||
|         if "<" not in uri: | ||||
|             return uri | ||||
|  | ||||
|         sig = signature(handler) | ||||
|         mapping = { | ||||
|             param.name: param.annotation.__name__.lower() | ||||
|             for param in sig.parameters.values() | ||||
|             if param.annotation in (str, int, float, UUID) | ||||
|         } | ||||
|  | ||||
|         reconstruction = [] | ||||
|         for part in uri.split("/"): | ||||
|             if part.startswith("<") and ":" not in part: | ||||
|                 name = part[1:-1] | ||||
|                 annotation = mapping.get(name) | ||||
|                 if annotation: | ||||
|                     part = f"<{name}:{annotation}>" | ||||
|             reconstruction.append(part) | ||||
|         return "/".join(reconstruction) | ||||
|   | ||||
							
								
								
									
										793
									
								
								sanic/server.py
									
									
									
									
									
								
							
							
						
						
									
										793
									
								
								sanic/server.py
									
									
									
									
									
								
							| @@ -1,793 +0,0 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from ssl import SSLContext | ||||
| from types import SimpleNamespace | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Any, | ||||
|     Callable, | ||||
|     Dict, | ||||
|     Iterable, | ||||
|     Optional, | ||||
|     Type, | ||||
|     Union, | ||||
| ) | ||||
|  | ||||
| from sanic.models.handler_types import ListenerType | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from sanic.app import Sanic | ||||
|  | ||||
| import asyncio | ||||
| import multiprocessing | ||||
| import os | ||||
| import secrets | ||||
| import socket | ||||
| import stat | ||||
|  | ||||
| from asyncio import CancelledError | ||||
| from asyncio.transports import Transport | ||||
| from functools import partial | ||||
| from inspect import isawaitable | ||||
| from ipaddress import ip_address | ||||
| from signal import SIG_IGN, SIGINT, SIGTERM, Signals | ||||
| from signal import signal as signal_func | ||||
| from time import monotonic as current_time | ||||
|  | ||||
| from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows | ||||
| from sanic.config import Config | ||||
| from sanic.exceptions import RequestTimeout, ServiceUnavailable | ||||
| from sanic.http import Http, Stage | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.models.protocol_types import TransportProtocol | ||||
| from sanic.request import Request | ||||
|  | ||||
|  | ||||
| try: | ||||
|     import uvloop  # type: ignore | ||||
|  | ||||
|     if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): | ||||
|         asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | ||||
| except ImportError: | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Signal: | ||||
|     stopped = False | ||||
|  | ||||
|  | ||||
| class ConnInfo: | ||||
|     """ | ||||
|     Local and remote addresses and SSL status info. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ( | ||||
|         "client_port", | ||||
|         "client", | ||||
|         "client_ip", | ||||
|         "ctx", | ||||
|         "peername", | ||||
|         "server_port", | ||||
|         "server", | ||||
|         "sockname", | ||||
|         "ssl", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, transport: TransportProtocol, unix=None): | ||||
|         self.ctx = SimpleNamespace() | ||||
|         self.peername = None | ||||
|         self.server = self.client = "" | ||||
|         self.server_port = self.client_port = 0 | ||||
|         self.client_ip = "" | ||||
|         self.sockname = addr = transport.get_extra_info("sockname") | ||||
|         self.ssl: bool = bool(transport.get_extra_info("sslcontext")) | ||||
|  | ||||
|         if isinstance(addr, str):  # UNIX socket | ||||
|             self.server = unix or addr | ||||
|             return | ||||
|  | ||||
|         # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) | ||||
|         if isinstance(addr, tuple): | ||||
|             self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" | ||||
|             self.server_port = addr[1] | ||||
|             # self.server gets non-standard port appended | ||||
|             if addr[1] != (443 if self.ssl else 80): | ||||
|                 self.server = f"{self.server}:{addr[1]}" | ||||
|         self.peername = addr = transport.get_extra_info("peername") | ||||
|  | ||||
|         if isinstance(addr, tuple): | ||||
|             self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" | ||||
|             self.client_ip = addr[0] | ||||
|             self.client_port = addr[1] | ||||
|  | ||||
|  | ||||
| class HttpProtocol(asyncio.Protocol): | ||||
|     """ | ||||
|     This class provides a basic HTTP implementation of the sanic framework. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ( | ||||
|         # app | ||||
|         "app", | ||||
|         # event loop, connection | ||||
|         "loop", | ||||
|         "transport", | ||||
|         "connections", | ||||
|         "signal", | ||||
|         "conn_info", | ||||
|         "ctx", | ||||
|         # request params | ||||
|         "request", | ||||
|         # request config | ||||
|         "request_handler", | ||||
|         "request_timeout", | ||||
|         "response_timeout", | ||||
|         "keep_alive_timeout", | ||||
|         "request_max_size", | ||||
|         "request_class", | ||||
|         "error_handler", | ||||
|         # enable or disable access log purpose | ||||
|         "access_log", | ||||
|         # connection management | ||||
|         "state", | ||||
|         "url", | ||||
|         "_handler_task", | ||||
|         "_can_write", | ||||
|         "_data_received", | ||||
|         "_time", | ||||
|         "_task", | ||||
|         "_http", | ||||
|         "_exception", | ||||
|         "recv_buffer", | ||||
|         "_unix", | ||||
|     ) | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         loop, | ||||
|         app: Sanic, | ||||
|         signal=None, | ||||
|         connections=None, | ||||
|         state=None, | ||||
|         unix=None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         asyncio.set_event_loop(loop) | ||||
|         self.loop = loop | ||||
|         self.app: Sanic = app | ||||
|         self.url = None | ||||
|         self.transport: Optional[Transport] = None | ||||
|         self.conn_info: Optional[ConnInfo] = None | ||||
|         self.request: Optional[Request] = None | ||||
|         self.signal = signal or Signal() | ||||
|         self.access_log = self.app.config.ACCESS_LOG | ||||
|         self.connections = connections if connections is not None else set() | ||||
|         self.request_handler = self.app.handle_request | ||||
|         self.error_handler = self.app.error_handler | ||||
|         self.request_timeout = self.app.config.REQUEST_TIMEOUT | ||||
|         self.response_timeout = self.app.config.RESPONSE_TIMEOUT | ||||
|         self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT | ||||
|         self.request_max_size = self.app.config.REQUEST_MAX_SIZE | ||||
|         self.request_class = self.app.request_class or Request | ||||
|         self.state = state if state else {} | ||||
|         if "requests_count" not in self.state: | ||||
|             self.state["requests_count"] = 0 | ||||
|         self._data_received = asyncio.Event() | ||||
|         self._can_write = asyncio.Event() | ||||
|         self._can_write.set() | ||||
|         self._exception = None | ||||
|         self._unix = unix | ||||
|  | ||||
|     def _setup_connection(self): | ||||
|         self._http = Http(self) | ||||
|         self._time = current_time() | ||||
|         self.check_timeouts() | ||||
|  | ||||
|     async def connection_task(self): | ||||
|         """ | ||||
|         Run a HTTP connection. | ||||
|  | ||||
|         Timeouts and some additional error handling occur here, while most of | ||||
|         everything else happens in class Http or in code called from there. | ||||
|         """ | ||||
|         try: | ||||
|             self._setup_connection() | ||||
|             await self._http.http1() | ||||
|         except CancelledError: | ||||
|             pass | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connection_task uncaught") | ||||
|         finally: | ||||
|             if self.app.debug and self._http: | ||||
|                 ip = self.transport.get_extra_info("peername") | ||||
|                 error_logger.error( | ||||
|                     "Connection lost before response written" | ||||
|                     f" @ {ip} {self._http.request}" | ||||
|                 ) | ||||
|             self._http = None | ||||
|             self._task = None | ||||
|             try: | ||||
|                 self.close() | ||||
|             except BaseException: | ||||
|                 error_logger.exception("Closing failed") | ||||
|  | ||||
|     async def receive_more(self): | ||||
|         """ | ||||
|         Wait until more data is received into the Server protocol's buffer | ||||
|         """ | ||||
|         self.transport.resume_reading() | ||||
|         self._data_received.clear() | ||||
|         await self._data_received.wait() | ||||
|  | ||||
|     def check_timeouts(self): | ||||
|         """ | ||||
|         Runs itself periodically to enforce any expired timeouts. | ||||
|         """ | ||||
|         try: | ||||
|             if not self._task: | ||||
|                 return | ||||
|             duration = current_time() - self._time | ||||
|             stage = self._http.stage | ||||
|             if stage is Stage.IDLE and duration > self.keep_alive_timeout: | ||||
|                 logger.debug("KeepAlive Timeout. Closing connection.") | ||||
|             elif stage is Stage.REQUEST and duration > self.request_timeout: | ||||
|                 logger.debug("Request Timeout. Closing connection.") | ||||
|                 self._http.exception = RequestTimeout("Request Timeout") | ||||
|             elif stage is Stage.HANDLER and self._http.upgrade_websocket: | ||||
|                 logger.debug("Handling websocket. Timeouts disabled.") | ||||
|                 return | ||||
|             elif ( | ||||
|                 stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) | ||||
|                 and duration > self.response_timeout | ||||
|             ): | ||||
|                 logger.debug("Response Timeout. Closing connection.") | ||||
|                 self._http.exception = ServiceUnavailable("Response Timeout") | ||||
|             else: | ||||
|                 interval = ( | ||||
|                     min( | ||||
|                         self.keep_alive_timeout, | ||||
|                         self.request_timeout, | ||||
|                         self.response_timeout, | ||||
|                     ) | ||||
|                     / 2 | ||||
|                 ) | ||||
|                 self.loop.call_later(max(0.1, interval), self.check_timeouts) | ||||
|                 return | ||||
|             self._task.cancel() | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.check_timeouts") | ||||
|  | ||||
|     async def send(self, data): | ||||
|         """ | ||||
|         Writes data with backpressure control. | ||||
|         """ | ||||
|         await self._can_write.wait() | ||||
|         if self.transport.is_closing(): | ||||
|             raise CancelledError | ||||
|         self.transport.write(data) | ||||
|         self._time = current_time() | ||||
|  | ||||
|     def close_if_idle(self) -> bool: | ||||
|         """ | ||||
|         Close the connection if a request is not being sent or received | ||||
|  | ||||
|         :return: boolean - True if closed, false if staying open | ||||
|         """ | ||||
|         if self._http is None or self._http.stage is Stage.IDLE: | ||||
|             self.close() | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     def close(self): | ||||
|         """ | ||||
|         Force close the connection. | ||||
|         """ | ||||
|         # Cause a call to connection_lost where further cleanup occurs | ||||
|         if self.transport: | ||||
|             self.transport.close() | ||||
|             self.transport = None | ||||
|  | ||||
|     # -------------------------------------------- # | ||||
|     # Only asyncio.Protocol callbacks below this | ||||
|     # -------------------------------------------- # | ||||
|  | ||||
|     def connection_made(self, transport): | ||||
|         try: | ||||
|             # TODO: Benchmark to find suitable write buffer limits | ||||
|             transport.set_write_buffer_limits(low=16384, high=65536) | ||||
|             self.connections.add(self) | ||||
|             self.transport = transport | ||||
|             self._task = self.loop.create_task(self.connection_task()) | ||||
|             self.recv_buffer = bytearray() | ||||
|             self.conn_info = ConnInfo(self.transport, unix=self._unix) | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connect_made") | ||||
|  | ||||
|     def connection_lost(self, exc): | ||||
|         try: | ||||
|             self.connections.discard(self) | ||||
|             self.resume_writing() | ||||
|             if self._task: | ||||
|                 self._task.cancel() | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connection_lost") | ||||
|  | ||||
|     def pause_writing(self): | ||||
|         self._can_write.clear() | ||||
|  | ||||
|     def resume_writing(self): | ||||
|         self._can_write.set() | ||||
|  | ||||
|     def data_received(self, data: bytes): | ||||
|         try: | ||||
|             self._time = current_time() | ||||
|             if not data: | ||||
|                 return self.close() | ||||
|             self.recv_buffer += data | ||||
|  | ||||
|             if ( | ||||
|                 len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE | ||||
|                 and self.transport | ||||
|             ): | ||||
|                 self.transport.pause_reading() | ||||
|  | ||||
|             if self._data_received: | ||||
|                 self._data_received.set() | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.data_received") | ||||
|  | ||||
|  | ||||
| def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): | ||||
|     """ | ||||
|     Trigger event callbacks (functions or async) | ||||
|  | ||||
|     :param events: one or more sync or async functions to execute | ||||
|     :param loop: event loop | ||||
|     """ | ||||
|     if events: | ||||
|         for event in events: | ||||
|             result = event(loop) | ||||
|             if isawaitable(result): | ||||
|                 loop.run_until_complete(result) | ||||
|  | ||||
|  | ||||
| class AsyncioServer: | ||||
|     """ | ||||
|     Wraps an asyncio server with functionality that might be useful to | ||||
|     a user who needs to manage the server lifecycle manually. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ( | ||||
|         "loop", | ||||
|         "serve_coro", | ||||
|         "_after_start", | ||||
|         "_before_stop", | ||||
|         "_after_stop", | ||||
|         "server", | ||||
|         "connections", | ||||
|     ) | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         loop, | ||||
|         serve_coro, | ||||
|         connections, | ||||
|         after_start: Optional[Iterable[ListenerType]], | ||||
|         before_stop: Optional[Iterable[ListenerType]], | ||||
|         after_stop: Optional[Iterable[ListenerType]], | ||||
|     ): | ||||
|         # Note, Sanic already called "before_server_start" events | ||||
|         # before this helper was even created. So we don't need it here. | ||||
|         self.loop = loop | ||||
|         self.serve_coro = serve_coro | ||||
|         self._after_start = after_start | ||||
|         self._before_stop = before_stop | ||||
|         self._after_stop = after_stop | ||||
|         self.server = None | ||||
|         self.connections = connections | ||||
|  | ||||
|     def after_start(self): | ||||
|         """ | ||||
|         Trigger "after_server_start" events | ||||
|         """ | ||||
|         trigger_events(self._after_start, self.loop) | ||||
|  | ||||
|     def before_stop(self): | ||||
|         """ | ||||
|         Trigger "before_server_stop" events | ||||
|         """ | ||||
|         trigger_events(self._before_stop, self.loop) | ||||
|  | ||||
|     def after_stop(self): | ||||
|         """ | ||||
|         Trigger "after_server_stop" events | ||||
|         """ | ||||
|         trigger_events(self._after_stop, self.loop) | ||||
|  | ||||
|     def is_serving(self) -> bool: | ||||
|         if self.server: | ||||
|             return self.server.is_serving() | ||||
|         return False | ||||
|  | ||||
|     def wait_closed(self): | ||||
|         if self.server: | ||||
|             return self.server.wait_closed() | ||||
|  | ||||
|     def close(self): | ||||
|         if self.server: | ||||
|             self.server.close() | ||||
|             coro = self.wait_closed() | ||||
|             task = asyncio.ensure_future(coro, loop=self.loop) | ||||
|             return task | ||||
|  | ||||
|     def start_serving(self): | ||||
|         if self.server: | ||||
|             try: | ||||
|                 return self.server.start_serving() | ||||
|             except AttributeError: | ||||
|                 raise NotImplementedError( | ||||
|                     "server.start_serving not available in this version " | ||||
|                     "of asyncio or uvloop." | ||||
|                 ) | ||||
|  | ||||
|     def serve_forever(self): | ||||
|         if self.server: | ||||
|             try: | ||||
|                 return self.server.serve_forever() | ||||
|             except AttributeError: | ||||
|                 raise NotImplementedError( | ||||
|                     "server.serve_forever not available in this version " | ||||
|                     "of asyncio or uvloop." | ||||
|                 ) | ||||
|  | ||||
|     def __await__(self): | ||||
|         """ | ||||
|         Starts the asyncio server, returns AsyncServerCoro | ||||
|         """ | ||||
|         task = asyncio.ensure_future(self.serve_coro) | ||||
|         while not task.done(): | ||||
|             yield | ||||
|         self.server = task.result() | ||||
|         return self | ||||
|  | ||||
|  | ||||
| def serve( | ||||
|     host, | ||||
|     port, | ||||
|     app, | ||||
|     before_start: Optional[Iterable[ListenerType]] = None, | ||||
|     after_start: Optional[Iterable[ListenerType]] = None, | ||||
|     before_stop: Optional[Iterable[ListenerType]] = None, | ||||
|     after_stop: Optional[Iterable[ListenerType]] = None, | ||||
|     ssl: Optional[SSLContext] = None, | ||||
|     sock: Optional[socket.socket] = None, | ||||
|     unix: Optional[str] = None, | ||||
|     reuse_port: bool = False, | ||||
|     loop=None, | ||||
|     protocol: Type[asyncio.Protocol] = HttpProtocol, | ||||
|     backlog: int = 100, | ||||
|     register_sys_signals: bool = True, | ||||
|     run_multiple: bool = False, | ||||
|     run_async: bool = False, | ||||
|     connections=None, | ||||
|     signal=Signal(), | ||||
|     state=None, | ||||
|     asyncio_server_kwargs=None, | ||||
| ): | ||||
|     """Start asynchronous HTTP Server on an individual process. | ||||
|  | ||||
|     :param host: Address to host on | ||||
|     :param port: Port to host on | ||||
|     :param before_start: function to be executed before the server starts | ||||
|                          listening. Takes arguments `app` instance and `loop` | ||||
|     :param after_start: function to be executed after the server starts | ||||
|                         listening. Takes  arguments `app` instance and `loop` | ||||
|     :param before_stop: function to be executed when a stop signal is | ||||
|                         received before it is respected. Takes arguments | ||||
|                         `app` instance and `loop` | ||||
|     :param after_stop: function to be executed when a stop signal is | ||||
|                        received after it is respected. Takes arguments | ||||
|                        `app` instance and `loop` | ||||
|     :param ssl: SSLContext | ||||
|     :param sock: Socket for the server to accept connections from | ||||
|     :param unix: Unix socket to listen on instead of TCP port | ||||
|     :param reuse_port: `True` for multiple workers | ||||
|     :param loop: asyncio compatible event loop | ||||
|     :param run_async: bool: Do not create a new event loop for the server, | ||||
|                       and return an AsyncServer object rather than running it | ||||
|     :param asyncio_server_kwargs: key-value args for asyncio/uvloop | ||||
|                                   create_server method | ||||
|     :return: Nothing | ||||
|     """ | ||||
|     if not run_async and not loop: | ||||
|         # create new event_loop after fork | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|  | ||||
|     if app.debug: | ||||
|         loop.set_debug(app.debug) | ||||
|  | ||||
|     app.asgi = False | ||||
|  | ||||
|     connections = connections if connections is not None else set() | ||||
|     protocol_kwargs = _build_protocol_kwargs(protocol, app.config) | ||||
|     server = partial( | ||||
|         protocol, | ||||
|         loop=loop, | ||||
|         connections=connections, | ||||
|         signal=signal, | ||||
|         app=app, | ||||
|         state=state, | ||||
|         unix=unix, | ||||
|         **protocol_kwargs, | ||||
|     ) | ||||
|     asyncio_server_kwargs = ( | ||||
|         asyncio_server_kwargs if asyncio_server_kwargs else {} | ||||
|     ) | ||||
|     # UNIX sockets are always bound by us (to preserve semantics between modes) | ||||
|     if unix: | ||||
|         sock = bind_unix_socket(unix, backlog=backlog) | ||||
|     server_coroutine = loop.create_server( | ||||
|         server, | ||||
|         None if sock else host, | ||||
|         None if sock else port, | ||||
|         ssl=ssl, | ||||
|         reuse_port=reuse_port, | ||||
|         sock=sock, | ||||
|         backlog=backlog, | ||||
|         **asyncio_server_kwargs, | ||||
|     ) | ||||
|  | ||||
|     if run_async: | ||||
|         return AsyncioServer( | ||||
|             loop=loop, | ||||
|             serve_coro=server_coroutine, | ||||
|             connections=connections, | ||||
|             after_start=after_start, | ||||
|             before_stop=before_stop, | ||||
|             after_stop=after_stop, | ||||
|         ) | ||||
|  | ||||
|     trigger_events(before_start, loop) | ||||
|  | ||||
|     try: | ||||
|         http_server = loop.run_until_complete(server_coroutine) | ||||
|     except BaseException: | ||||
|         error_logger.exception("Unable to start server") | ||||
|         return | ||||
|  | ||||
|     trigger_events(after_start, loop) | ||||
|  | ||||
|     # Ignore SIGINT when run_multiple | ||||
|     if run_multiple: | ||||
|         signal_func(SIGINT, SIG_IGN) | ||||
|  | ||||
|     # Register signals for graceful termination | ||||
|     if register_sys_signals: | ||||
|         if OS_IS_WINDOWS: | ||||
|             ctrlc_workaround_for_windows(app) | ||||
|         else: | ||||
|             for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: | ||||
|                 loop.add_signal_handler(_signal, app.stop) | ||||
|     pid = os.getpid() | ||||
|     try: | ||||
|         logger.info("Starting worker [%s]", pid) | ||||
|         loop.run_forever() | ||||
|     finally: | ||||
|         logger.info("Stopping worker [%s]", pid) | ||||
|  | ||||
|         # Run the on_stop function if provided | ||||
|         trigger_events(before_stop, loop) | ||||
|  | ||||
|         # Wait for event loop to finish and all connections to drain | ||||
|         http_server.close() | ||||
|         loop.run_until_complete(http_server.wait_closed()) | ||||
|  | ||||
|         # Complete all tasks on the loop | ||||
|         signal.stopped = True | ||||
|         for connection in connections: | ||||
|             connection.close_if_idle() | ||||
|  | ||||
|         # Gracefully shutdown timeout. | ||||
|         # We should provide graceful_shutdown_timeout, | ||||
|         # instead of letting connection hangs forever. | ||||
|         # Let's roughly calcucate time. | ||||
|         graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT | ||||
|         start_shutdown: float = 0 | ||||
|         while connections and (start_shutdown < graceful): | ||||
|             loop.run_until_complete(asyncio.sleep(0.1)) | ||||
|             start_shutdown = start_shutdown + 0.1 | ||||
|  | ||||
|         # Force close non-idle connection after waiting for | ||||
|         # graceful_shutdown_timeout | ||||
|         coros = [] | ||||
|         for conn in connections: | ||||
|             if hasattr(conn, "websocket") and conn.websocket: | ||||
|                 coros.append(conn.websocket.close_connection()) | ||||
|             else: | ||||
|                 conn.close() | ||||
|  | ||||
|         _shutdown = asyncio.gather(*coros) | ||||
|         loop.run_until_complete(_shutdown) | ||||
|  | ||||
|         trigger_events(after_stop, loop) | ||||
|  | ||||
|         remove_unix_socket(unix) | ||||
|  | ||||
|  | ||||
| def _build_protocol_kwargs( | ||||
|     protocol: Type[asyncio.Protocol], config: Config | ||||
| ) -> Dict[str, Union[int, float]]: | ||||
|     if hasattr(protocol, "websocket_handshake"): | ||||
|         return { | ||||
|             "websocket_max_size": config.WEBSOCKET_MAX_SIZE, | ||||
|             "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, | ||||
|             "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, | ||||
|             "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, | ||||
|             "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, | ||||
|             "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, | ||||
|         } | ||||
|     return {} | ||||
|  | ||||
|  | ||||
| def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: | ||||
|     """Create TCP server socket. | ||||
|     :param host: IPv4, IPv6 or hostname may be specified | ||||
|     :param port: TCP port number | ||||
|     :param backlog: Maximum number of connections to queue | ||||
|     :return: socket.socket object | ||||
|     """ | ||||
|     try:  # IP address: family must be specified for IPv6 at least | ||||
|         ip = ip_address(host) | ||||
|         host = str(ip) | ||||
|         sock = socket.socket( | ||||
|             socket.AF_INET6 if ip.version == 6 else socket.AF_INET | ||||
|         ) | ||||
|     except ValueError:  # Hostname, may become AF_INET or AF_INET6 | ||||
|         sock = socket.socket() | ||||
|     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||||
|     sock.bind((host, port)) | ||||
|     sock.listen(backlog) | ||||
|     return sock | ||||
|  | ||||
|  | ||||
| def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: | ||||
|     """Create unix socket. | ||||
|     :param path: filesystem path | ||||
|     :param backlog: Maximum number of connections to queue | ||||
|     :return: socket.socket object | ||||
|     """ | ||||
|     """Open or atomically replace existing socket with zero downtime.""" | ||||
|     # Sanitise and pre-verify socket path | ||||
|     path = os.path.abspath(path) | ||||
|     folder = os.path.dirname(path) | ||||
|     if not os.path.isdir(folder): | ||||
|         raise FileNotFoundError(f"Socket folder does not exist: {folder}") | ||||
|     try: | ||||
|         if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): | ||||
|             raise FileExistsError(f"Existing file is not a socket: {path}") | ||||
|     except FileNotFoundError: | ||||
|         pass | ||||
|     # Create new socket with a random temporary name | ||||
|     tmp_path = f"{path}.{secrets.token_urlsafe()}" | ||||
|     sock = socket.socket(socket.AF_UNIX) | ||||
|     try: | ||||
|         # Critical section begins (filename races) | ||||
|         sock.bind(tmp_path) | ||||
|         try: | ||||
|             os.chmod(tmp_path, mode) | ||||
|             # Start listening before rename to avoid connection failures | ||||
|             sock.listen(backlog) | ||||
|             os.rename(tmp_path, path) | ||||
|         except:  # noqa: E722 | ||||
|             try: | ||||
|                 os.unlink(tmp_path) | ||||
|             finally: | ||||
|                 raise | ||||
|     except:  # noqa: E722 | ||||
|         try: | ||||
|             sock.close() | ||||
|         finally: | ||||
|             raise | ||||
|     return sock | ||||
|  | ||||
|  | ||||
| def remove_unix_socket(path: Optional[str]) -> None: | ||||
|     """Remove dead unix socket during server exit.""" | ||||
|     if not path: | ||||
|         return | ||||
|     try: | ||||
|         if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): | ||||
|             # Is it actually dead (doesn't belong to a new server instance)? | ||||
|             with socket.socket(socket.AF_UNIX) as testsock: | ||||
|                 try: | ||||
|                     testsock.connect(path) | ||||
|                 except ConnectionRefusedError: | ||||
|                     os.unlink(path) | ||||
|     except FileNotFoundError: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| def serve_single(server_settings): | ||||
|     main_start = server_settings.pop("main_start", None) | ||||
|     main_stop = server_settings.pop("main_stop", None) | ||||
|  | ||||
|     if not server_settings.get("run_async"): | ||||
|         # create new event_loop after fork | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         server_settings["loop"] = loop | ||||
|  | ||||
|     trigger_events(main_start, server_settings["loop"]) | ||||
|     serve(**server_settings) | ||||
|     trigger_events(main_stop, server_settings["loop"]) | ||||
|  | ||||
|     server_settings["loop"].close() | ||||
|  | ||||
|  | ||||
| def serve_multiple(server_settings, workers): | ||||
|     """Start multiple server processes simultaneously.  Stop on interrupt | ||||
|     and terminate signals, and drain connections when complete. | ||||
|  | ||||
|     :param server_settings: kw arguments to be passed to the serve function | ||||
|     :param workers: number of workers to launch | ||||
|     :param stop_event: if provided, is used as a stop signal | ||||
|     :return: | ||||
|     """ | ||||
|     server_settings["reuse_port"] = True | ||||
|     server_settings["run_multiple"] = True | ||||
|  | ||||
|     main_start = server_settings.pop("main_start", None) | ||||
|     main_stop = server_settings.pop("main_stop", None) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|  | ||||
|     trigger_events(main_start, loop) | ||||
|  | ||||
|     # Create a listening socket or use the one in settings | ||||
|     sock = server_settings.get("sock") | ||||
|     unix = server_settings["unix"] | ||||
|     backlog = server_settings["backlog"] | ||||
|     if unix: | ||||
|         sock = bind_unix_socket(unix, backlog=backlog) | ||||
|         server_settings["unix"] = unix | ||||
|     if sock is None: | ||||
|         sock = bind_socket( | ||||
|             server_settings["host"], server_settings["port"], backlog=backlog | ||||
|         ) | ||||
|         sock.set_inheritable(True) | ||||
|         server_settings["sock"] = sock | ||||
|         server_settings["host"] = None | ||||
|         server_settings["port"] = None | ||||
|  | ||||
|     processes = [] | ||||
|  | ||||
|     def sig_handler(signal, frame): | ||||
|         logger.info("Received signal %s. Shutting down.", Signals(signal).name) | ||||
|         for process in processes: | ||||
|             os.kill(process.pid, SIGTERM) | ||||
|  | ||||
|     signal_func(SIGINT, lambda s, f: sig_handler(s, f)) | ||||
|     signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) | ||||
|     mp = multiprocessing.get_context("fork") | ||||
|  | ||||
|     for _ in range(workers): | ||||
|         process = mp.Process(target=serve, kwargs=server_settings) | ||||
|         process.daemon = True | ||||
|         process.start() | ||||
|         processes.append(process) | ||||
|  | ||||
|     for process in processes: | ||||
|         process.join() | ||||
|  | ||||
|     # the above processes will block this until they're stopped | ||||
|     for process in processes: | ||||
|         process.terminate() | ||||
|  | ||||
|     trigger_events(main_stop, loop) | ||||
|  | ||||
|     sock.close() | ||||
|     loop.close() | ||||
|     remove_unix_socket(unix) | ||||
							
								
								
									
										26
									
								
								sanic/server/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								sanic/server/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| import asyncio | ||||
|  | ||||
| from sanic.models.server_types import ConnInfo, Signal | ||||
| from sanic.server.async_server import AsyncioServer | ||||
| from sanic.server.protocols.http_protocol import HttpProtocol | ||||
| from sanic.server.runners import serve, serve_multiple, serve_single | ||||
|  | ||||
|  | ||||
| try: | ||||
|     import uvloop  # type: ignore | ||||
|  | ||||
|     if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): | ||||
|         asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | ||||
| except ImportError: | ||||
|     pass | ||||
|  | ||||
|  | ||||
| __all__ = ( | ||||
|     "AsyncioServer", | ||||
|     "ConnInfo", | ||||
|     "HttpProtocol", | ||||
|     "Signal", | ||||
|     "serve", | ||||
|     "serve_multiple", | ||||
|     "serve_single", | ||||
| ) | ||||
							
								
								
									
										115
									
								
								sanic/server/async_server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								sanic/server/async_server.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,115 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
|  | ||||
| from sanic.exceptions import SanicException | ||||
|  | ||||
|  | ||||
| class AsyncioServer: | ||||
|     """ | ||||
|     Wraps an asyncio server with functionality that might be useful to | ||||
|     a user who needs to manage the server lifecycle manually. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init") | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         app, | ||||
|         loop, | ||||
|         serve_coro, | ||||
|         connections, | ||||
|     ): | ||||
|         # Note, Sanic already called "before_server_start" events | ||||
|         # before this helper was even created. So we don't need it here. | ||||
|         self.app = app | ||||
|         self.connections = connections | ||||
|         self.loop = loop | ||||
|         self.serve_coro = serve_coro | ||||
|         self.server = None | ||||
|         self.init = False | ||||
|  | ||||
|     def startup(self): | ||||
|         """ | ||||
|         Trigger "before_server_start" events | ||||
|         """ | ||||
|         self.init = True | ||||
|         return self.app._startup() | ||||
|  | ||||
|     def before_start(self): | ||||
|         """ | ||||
|         Trigger "before_server_start" events | ||||
|         """ | ||||
|         return self._server_event("init", "before") | ||||
|  | ||||
|     def after_start(self): | ||||
|         """ | ||||
|         Trigger "after_server_start" events | ||||
|         """ | ||||
|         return self._server_event("init", "after") | ||||
|  | ||||
|     def before_stop(self): | ||||
|         """ | ||||
|         Trigger "before_server_stop" events | ||||
|         """ | ||||
|         return self._server_event("shutdown", "before") | ||||
|  | ||||
|     def after_stop(self): | ||||
|         """ | ||||
|         Trigger "after_server_stop" events | ||||
|         """ | ||||
|         return self._server_event("shutdown", "after") | ||||
|  | ||||
|     def is_serving(self) -> bool: | ||||
|         if self.server: | ||||
|             return self.server.is_serving() | ||||
|         return False | ||||
|  | ||||
|     def wait_closed(self): | ||||
|         if self.server: | ||||
|             return self.server.wait_closed() | ||||
|  | ||||
|     def close(self): | ||||
|         if self.server: | ||||
|             self.server.close() | ||||
|             coro = self.wait_closed() | ||||
|             task = asyncio.ensure_future(coro, loop=self.loop) | ||||
|             return task | ||||
|  | ||||
|     def start_serving(self): | ||||
|         if self.server: | ||||
|             try: | ||||
|                 return self.server.start_serving() | ||||
|             except AttributeError: | ||||
|                 raise NotImplementedError( | ||||
|                     "server.start_serving not available in this version " | ||||
|                     "of asyncio or uvloop." | ||||
|                 ) | ||||
|  | ||||
|     def serve_forever(self): | ||||
|         if self.server: | ||||
|             try: | ||||
|                 return self.server.serve_forever() | ||||
|             except AttributeError: | ||||
|                 raise NotImplementedError( | ||||
|                     "server.serve_forever not available in this version " | ||||
|                     "of asyncio or uvloop." | ||||
|                 ) | ||||
|  | ||||
|     def _server_event(self, concern: str, action: str): | ||||
|         if not self.init: | ||||
|             raise SanicException( | ||||
|                 "Cannot dispatch server event without " | ||||
|                 "first running server.startup()" | ||||
|             ) | ||||
|         return self.app._server_event(concern, action, loop=self.loop) | ||||
|  | ||||
|     def __await__(self): | ||||
|         """ | ||||
|         Starts the asyncio server, returns AsyncServerCoro | ||||
|         """ | ||||
|         task = asyncio.ensure_future(self.serve_coro) | ||||
|         while not task.done(): | ||||
|             yield | ||||
|         self.server = task.result() | ||||
|         return self | ||||
							
								
								
									
										16
									
								
								sanic/server/events.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								sanic/server/events.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| from inspect import isawaitable | ||||
| from typing import Any, Callable, Iterable, Optional | ||||
|  | ||||
|  | ||||
| def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): | ||||
|     """ | ||||
|     Trigger event callbacks (functions or async) | ||||
|  | ||||
|     :param events: one or more sync or async functions to execute | ||||
|     :param loop: event loop | ||||
|     """ | ||||
|     if events: | ||||
|         for event in events: | ||||
|             result = event(loop) | ||||
|             if isawaitable(result): | ||||
|                 loop.run_until_complete(result) | ||||
							
								
								
									
										0
									
								
								sanic/server/protocols/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sanic/server/protocols/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										143
									
								
								sanic/server/protocols/base_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								sanic/server/protocols/base_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from sanic.app import Sanic | ||||
|  | ||||
| import asyncio | ||||
|  | ||||
| from asyncio import CancelledError | ||||
| from asyncio.transports import Transport | ||||
| from time import monotonic as current_time | ||||
|  | ||||
| from sanic.log import error_logger | ||||
| from sanic.models.server_types import ConnInfo, Signal | ||||
|  | ||||
|  | ||||
| class SanicProtocol(asyncio.Protocol): | ||||
|     __slots__ = ( | ||||
|         "app", | ||||
|         # event loop, connection | ||||
|         "loop", | ||||
|         "transport", | ||||
|         "connections", | ||||
|         "conn_info", | ||||
|         "signal", | ||||
|         "_can_write", | ||||
|         "_time", | ||||
|         "_task", | ||||
|         "_unix", | ||||
|         "_data_received", | ||||
|     ) | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         loop, | ||||
|         app: Sanic, | ||||
|         signal=None, | ||||
|         connections=None, | ||||
|         unix=None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         asyncio.set_event_loop(loop) | ||||
|         self.loop = loop | ||||
|         self.app: Sanic = app | ||||
|         self.signal = signal or Signal() | ||||
|         self.transport: Optional[Transport] = None | ||||
|         self.connections = connections if connections is not None else set() | ||||
|         self.conn_info: Optional[ConnInfo] = None | ||||
|         self._can_write = asyncio.Event() | ||||
|         self._can_write.set() | ||||
|         self._unix = unix | ||||
|         self._time = 0.0  # type: float | ||||
|         self._task = None  # type: Optional[asyncio.Task] | ||||
|         self._data_received = asyncio.Event() | ||||
|  | ||||
|     @property | ||||
|     def ctx(self): | ||||
|         if self.conn_info is not None: | ||||
|             return self.conn_info.ctx | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     async def send(self, data): | ||||
|         """ | ||||
|         Generic data write implementation with backpressure control. | ||||
|         """ | ||||
|         await self._can_write.wait() | ||||
|         if self.transport.is_closing(): | ||||
|             raise CancelledError | ||||
|         self.transport.write(data) | ||||
|         self._time = current_time() | ||||
|  | ||||
|     async def receive_more(self): | ||||
|         """ | ||||
|         Wait until more data is received into the Server protocol's buffer | ||||
|         """ | ||||
|         self.transport.resume_reading() | ||||
|         self._data_received.clear() | ||||
|         await self._data_received.wait() | ||||
|  | ||||
|     def close(self, timeout: Optional[float] = None): | ||||
|         """ | ||||
|         Attempt close the connection. | ||||
|         """ | ||||
|         # Cause a call to connection_lost where further cleanup occurs | ||||
|         if self.transport: | ||||
|             self.transport.close() | ||||
|             if timeout is None: | ||||
|                 timeout = self.app.config.GRACEFUL_SHUTDOWN_TIMEOUT | ||||
|             self.loop.call_later(timeout, self.abort) | ||||
|  | ||||
|     def abort(self): | ||||
|         """ | ||||
|         Force close the connection. | ||||
|         """ | ||||
|         # Cause a call to connection_lost where further cleanup occurs | ||||
|         if self.transport: | ||||
|             self.transport.abort() | ||||
|             self.transport = None | ||||
|  | ||||
|     # asyncio.Protocol API Callbacks # | ||||
|     # ------------------------------ # | ||||
|     def connection_made(self, transport): | ||||
|         """ | ||||
|         Generic connection-made, with no connection_task, and no recv_buffer. | ||||
|         Override this for protocol-specific connection implementations. | ||||
|         """ | ||||
|         try: | ||||
|             transport.set_write_buffer_limits(low=16384, high=65536) | ||||
|             self.connections.add(self) | ||||
|             self.transport = transport | ||||
|             self.conn_info = ConnInfo(self.transport, unix=self._unix) | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connect_made") | ||||
|  | ||||
|     def connection_lost(self, exc): | ||||
|         try: | ||||
|             self.connections.discard(self) | ||||
|             self.resume_writing() | ||||
|             if self._task: | ||||
|                 self._task.cancel() | ||||
|         except BaseException: | ||||
|             error_logger.exception("protocol.connection_lost") | ||||
|  | ||||
|     def pause_writing(self): | ||||
|         self._can_write.clear() | ||||
|  | ||||
|     def resume_writing(self): | ||||
|         self._can_write.set() | ||||
|  | ||||
|     def data_received(self, data: bytes): | ||||
|         try: | ||||
|             self._time = current_time() | ||||
|             if not data: | ||||
|                 return self.close() | ||||
|  | ||||
|             if self._data_received: | ||||
|                 self._data_received.set() | ||||
|         except BaseException: | ||||
|             error_logger.exception("protocol.data_received") | ||||
							
								
								
									
										238
									
								
								sanic/server/protocols/http_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								sanic/server/protocols/http_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,238 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| from sanic.touchup.meta import TouchUpMeta | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from sanic.app import Sanic | ||||
|  | ||||
| from asyncio import CancelledError | ||||
| from time import monotonic as current_time | ||||
|  | ||||
| from sanic.exceptions import RequestTimeout, ServiceUnavailable | ||||
| from sanic.http import Http, Stage | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.models.server_types import ConnInfo | ||||
| from sanic.request import Request | ||||
| from sanic.server.protocols.base_protocol import SanicProtocol | ||||
|  | ||||
|  | ||||
| class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): | ||||
|     """ | ||||
|     This class provides implements the HTTP 1.1 protocol on top of our | ||||
|     Sanic Server transport | ||||
|     """ | ||||
|  | ||||
|     __touchup__ = ( | ||||
|         "send", | ||||
|         "connection_task", | ||||
|     ) | ||||
|     __slots__ = ( | ||||
|         # request params | ||||
|         "request", | ||||
|         # request config | ||||
|         "request_handler", | ||||
|         "request_timeout", | ||||
|         "response_timeout", | ||||
|         "keep_alive_timeout", | ||||
|         "request_max_size", | ||||
|         "request_class", | ||||
|         "error_handler", | ||||
|         # enable or disable access log purpose | ||||
|         "access_log", | ||||
|         # connection management | ||||
|         "state", | ||||
|         "url", | ||||
|         "_handler_task", | ||||
|         "_http", | ||||
|         "_exception", | ||||
|         "recv_buffer", | ||||
|     ) | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         loop, | ||||
|         app: Sanic, | ||||
|         signal=None, | ||||
|         connections=None, | ||||
|         state=None, | ||||
|         unix=None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         super().__init__( | ||||
|             loop=loop, | ||||
|             app=app, | ||||
|             signal=signal, | ||||
|             connections=connections, | ||||
|             unix=unix, | ||||
|         ) | ||||
|         self.url = None | ||||
|         self.request: Optional[Request] = None | ||||
|         self.access_log = self.app.config.ACCESS_LOG | ||||
|         self.request_handler = self.app.handle_request | ||||
|         self.error_handler = self.app.error_handler | ||||
|         self.request_timeout = self.app.config.REQUEST_TIMEOUT | ||||
|         self.response_timeout = self.app.config.RESPONSE_TIMEOUT | ||||
|         self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT | ||||
|         self.request_max_size = self.app.config.REQUEST_MAX_SIZE | ||||
|         self.request_class = self.app.request_class or Request | ||||
|         self.state = state if state else {} | ||||
|         if "requests_count" not in self.state: | ||||
|             self.state["requests_count"] = 0 | ||||
|         self._exception = None | ||||
|  | ||||
|     def _setup_connection(self): | ||||
|         self._http = Http(self) | ||||
|         self._time = current_time() | ||||
|         self.check_timeouts() | ||||
|  | ||||
|     async def connection_task(self):  # no cov | ||||
|         """ | ||||
|         Run a HTTP connection. | ||||
|  | ||||
|         Timeouts and some additional error handling occur here, while most of | ||||
|         everything else happens in class Http or in code called from there. | ||||
|         """ | ||||
|         try: | ||||
|             self._setup_connection() | ||||
|             await self.app.dispatch( | ||||
|                 "http.lifecycle.begin", | ||||
|                 inline=True, | ||||
|                 context={"conn_info": self.conn_info}, | ||||
|             ) | ||||
|             await self._http.http1() | ||||
|         except CancelledError: | ||||
|             pass | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connection_task uncaught") | ||||
|         finally: | ||||
|             if ( | ||||
|                 self.app.debug | ||||
|                 and self._http | ||||
|                 and self.transport | ||||
|                 and not self._http.upgrade_websocket | ||||
|             ): | ||||
|                 ip = self.transport.get_extra_info("peername") | ||||
|                 error_logger.error( | ||||
|                     "Connection lost before response written" | ||||
|                     f" @ {ip} {self._http.request}" | ||||
|                 ) | ||||
|             self._http = None | ||||
|             self._task = None | ||||
|             try: | ||||
|                 self.close() | ||||
|             except BaseException: | ||||
|                 error_logger.exception("Closing failed") | ||||
|             finally: | ||||
|                 await self.app.dispatch( | ||||
|                     "http.lifecycle.complete", | ||||
|                     inline=True, | ||||
|                     context={"conn_info": self.conn_info}, | ||||
|                 ) | ||||
|                 # Important to keep this Ellipsis here for the TouchUp module | ||||
|                 ... | ||||
|  | ||||
|     def check_timeouts(self): | ||||
|         """ | ||||
|         Runs itself periodically to enforce any expired timeouts. | ||||
|         """ | ||||
|         try: | ||||
|             if not self._task: | ||||
|                 return | ||||
|             duration = current_time() - self._time | ||||
|             stage = self._http.stage | ||||
|             if stage is Stage.IDLE and duration > self.keep_alive_timeout: | ||||
|                 logger.debug("KeepAlive Timeout. Closing connection.") | ||||
|             elif stage is Stage.REQUEST and duration > self.request_timeout: | ||||
|                 logger.debug("Request Timeout. Closing connection.") | ||||
|                 self._http.exception = RequestTimeout("Request Timeout") | ||||
|             elif stage is Stage.HANDLER and self._http.upgrade_websocket: | ||||
|                 logger.debug("Handling websocket. Timeouts disabled.") | ||||
|                 return | ||||
|             elif ( | ||||
|                 stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) | ||||
|                 and duration > self.response_timeout | ||||
|             ): | ||||
|                 logger.debug("Response Timeout. Closing connection.") | ||||
|                 self._http.exception = ServiceUnavailable("Response Timeout") | ||||
|             else: | ||||
|                 interval = ( | ||||
|                     min( | ||||
|                         self.keep_alive_timeout, | ||||
|                         self.request_timeout, | ||||
|                         self.response_timeout, | ||||
|                     ) | ||||
|                     / 2 | ||||
|                 ) | ||||
|                 self.loop.call_later(max(0.1, interval), self.check_timeouts) | ||||
|                 return | ||||
|             self._task.cancel() | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.check_timeouts") | ||||
|  | ||||
|     async def send(self, data):  # no cov | ||||
|         """ | ||||
|         Writes HTTP data with backpressure control. | ||||
|         """ | ||||
|         await self._can_write.wait() | ||||
|         if self.transport.is_closing(): | ||||
|             raise CancelledError | ||||
|         await self.app.dispatch( | ||||
|             "http.lifecycle.send", | ||||
|             inline=True, | ||||
|             context={"data": data}, | ||||
|         ) | ||||
|         self.transport.write(data) | ||||
|         self._time = current_time() | ||||
|  | ||||
|     def close_if_idle(self) -> bool: | ||||
|         """ | ||||
|         Close the connection if a request is not being sent or received | ||||
|  | ||||
|         :return: boolean - True if closed, false if staying open | ||||
|         """ | ||||
|         if self._http is None or self._http.stage is Stage.IDLE: | ||||
|             self.close() | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     # -------------------------------------------- # | ||||
|     # Only asyncio.Protocol callbacks below this | ||||
|     # -------------------------------------------- # | ||||
|  | ||||
|     def connection_made(self, transport): | ||||
|         """ | ||||
|         HTTP-protocol-specific new connection handler | ||||
|         """ | ||||
|         try: | ||||
|             # TODO: Benchmark to find suitable write buffer limits | ||||
|             transport.set_write_buffer_limits(low=16384, high=65536) | ||||
|             self.connections.add(self) | ||||
|             self.transport = transport | ||||
|             self._task = self.loop.create_task(self.connection_task()) | ||||
|             self.recv_buffer = bytearray() | ||||
|             self.conn_info = ConnInfo(self.transport, unix=self._unix) | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connect_made") | ||||
|  | ||||
|     def data_received(self, data: bytes): | ||||
|  | ||||
|         try: | ||||
|             self._time = current_time() | ||||
|             if not data: | ||||
|                 return self.close() | ||||
|             self.recv_buffer += data | ||||
|  | ||||
|             if ( | ||||
|                 len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE | ||||
|                 and self.transport | ||||
|             ): | ||||
|                 self.transport.pause_reading() | ||||
|  | ||||
|             if self._data_received: | ||||
|                 self._data_received.set() | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.data_received") | ||||
							
								
								
									
										164
									
								
								sanic/server/protocols/websocket_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								sanic/server/protocols/websocket_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,164 @@ | ||||
| from typing import TYPE_CHECKING, Optional, Sequence | ||||
|  | ||||
| from websockets.connection import CLOSED, CLOSING, OPEN | ||||
| from websockets.server import ServerConnection | ||||
|  | ||||
| from sanic.exceptions import ServerError | ||||
| from sanic.log import error_logger | ||||
| from sanic.server import HttpProtocol | ||||
|  | ||||
| from ..websockets.impl import WebsocketImplProtocol | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from websockets import http11 | ||||
|  | ||||
|  | ||||
| class WebSocketProtocol(HttpProtocol): | ||||
|  | ||||
|     websocket: Optional[WebsocketImplProtocol] | ||||
|     websocket_timeout: float | ||||
|     websocket_max_size = Optional[int] | ||||
|     websocket_ping_interval = Optional[float] | ||||
|     websocket_ping_timeout = Optional[float] | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         *args, | ||||
|         websocket_timeout: float = 10.0, | ||||
|         websocket_max_size: Optional[int] = None, | ||||
|         websocket_max_queue: Optional[int] = None,  # max_queue is deprecated | ||||
|         websocket_read_limit: Optional[int] = None,  # read_limit is deprecated | ||||
|         websocket_write_limit: Optional[int] = None,  # write_limit deprecated | ||||
|         websocket_ping_interval: Optional[float] = 20.0, | ||||
|         websocket_ping_timeout: Optional[float] = 20.0, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.websocket = None | ||||
|         self.websocket_timeout = websocket_timeout | ||||
|         self.websocket_max_size = websocket_max_size | ||||
|         if websocket_max_queue is not None and websocket_max_queue > 0: | ||||
|             # TODO: Reminder remove this warning in v22.3 | ||||
|             error_logger.warning( | ||||
|                 DeprecationWarning( | ||||
|                     "Websocket no longer uses queueing, so websocket_max_queue" | ||||
|                     " is no longer required." | ||||
|                 ) | ||||
|             ) | ||||
|         if websocket_read_limit is not None and websocket_read_limit > 0: | ||||
|             # TODO: Reminder remove this warning in v22.3 | ||||
|             error_logger.warning( | ||||
|                 DeprecationWarning( | ||||
|                     "Websocket no longer uses read buffers, so " | ||||
|                     "websocket_read_limit is not required." | ||||
|                 ) | ||||
|             ) | ||||
|         if websocket_write_limit is not None and websocket_write_limit > 0: | ||||
|             # TODO: Reminder remove this warning in v22.3 | ||||
|             error_logger.warning( | ||||
|                 DeprecationWarning( | ||||
|                     "Websocket no longer uses write buffers, so " | ||||
|                     "websocket_write_limit is not required." | ||||
|                 ) | ||||
|             ) | ||||
|         self.websocket_ping_interval = websocket_ping_interval | ||||
|         self.websocket_ping_timeout = websocket_ping_timeout | ||||
|  | ||||
|     def connection_lost(self, exc): | ||||
|         if self.websocket is not None: | ||||
|             self.websocket.connection_lost(exc) | ||||
|         super().connection_lost(exc) | ||||
|  | ||||
|     def data_received(self, data): | ||||
|         if self.websocket is not None: | ||||
|             self.websocket.data_received(data) | ||||
|         else: | ||||
|             # Pass it to HttpProtocol handler first | ||||
|             # That will (hopefully) upgrade it to a websocket. | ||||
|             super().data_received(data) | ||||
|  | ||||
|     def eof_received(self) -> Optional[bool]: | ||||
|         if self.websocket is not None: | ||||
|             return self.websocket.eof_received() | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|     def close(self, timeout: Optional[float] = None): | ||||
|         # Called by HttpProtocol at the end of connection_task | ||||
|         # If we've upgraded to websocket, we do our own closing | ||||
|         if self.websocket is not None: | ||||
|             # Note, we don't want to use websocket.close() | ||||
|             # That is used for user's application code to send a | ||||
|             # websocket close packet. This is different. | ||||
|             self.websocket.end_connection(1001) | ||||
|         else: | ||||
|             super().close() | ||||
|  | ||||
|     def close_if_idle(self): | ||||
|         # Called by Sanic Server when shutting down | ||||
|         # If we've upgraded to websocket, shut it down | ||||
|         if self.websocket is not None: | ||||
|             if self.websocket.connection.state in (CLOSING, CLOSED): | ||||
|                 return True | ||||
|             elif self.websocket.loop is not None: | ||||
|                 self.websocket.loop.create_task(self.websocket.close(1001)) | ||||
|             else: | ||||
|                 self.websocket.end_connection(1001) | ||||
|         else: | ||||
|             return super().close_if_idle() | ||||
|  | ||||
|     async def websocket_handshake( | ||||
|         self, request, subprotocols=Optional[Sequence[str]] | ||||
|     ): | ||||
|         # let the websockets package do the handshake with the client | ||||
|         try: | ||||
|             if subprotocols is not None: | ||||
|                 # subprotocols can be a set or frozenset, | ||||
|                 # but ServerConnection needs a list | ||||
|                 subprotocols = list(subprotocols) | ||||
|             ws_conn = ServerConnection( | ||||
|                 max_size=self.websocket_max_size, | ||||
|                 subprotocols=subprotocols, | ||||
|                 state=OPEN, | ||||
|                 logger=error_logger, | ||||
|             ) | ||||
|             resp: "http11.Response" = ws_conn.accept(request) | ||||
|         except Exception: | ||||
|             msg = ( | ||||
|                 "Failed to open a WebSocket connection.\n" | ||||
|                 "See server log for more information.\n" | ||||
|             ) | ||||
|             raise ServerError(msg, status_code=500) | ||||
|         if 100 <= resp.status_code <= 299: | ||||
|             rbody = "".join( | ||||
|                 [ | ||||
|                     "HTTP/1.1 ", | ||||
|                     str(resp.status_code), | ||||
|                     " ", | ||||
|                     resp.reason_phrase, | ||||
|                     "\r\n", | ||||
|                 ] | ||||
|             ) | ||||
|             rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) | ||||
|             if resp.body is not None: | ||||
|                 rbody += f"\r\n{resp.body}\r\n\r\n" | ||||
|             else: | ||||
|                 rbody += "\r\n" | ||||
|             await super().send(rbody.encode()) | ||||
|         else: | ||||
|             raise ServerError(resp.body, resp.status_code) | ||||
|         self.websocket = WebsocketImplProtocol( | ||||
|             ws_conn, | ||||
|             ping_interval=self.websocket_ping_interval, | ||||
|             ping_timeout=self.websocket_ping_timeout, | ||||
|             close_timeout=self.websocket_timeout, | ||||
|         ) | ||||
|         loop = ( | ||||
|             request.transport.loop | ||||
|             if hasattr(request, "transport") | ||||
|             and hasattr(request.transport, "loop") | ||||
|             else None | ||||
|         ) | ||||
|         await self.websocket.connection_made(self, loop=loop) | ||||
|         return self.websocket | ||||
							
								
								
									
										280
									
								
								sanic/server/runners.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										280
									
								
								sanic/server/runners.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,280 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from ssl import SSLContext | ||||
| from typing import TYPE_CHECKING, Dict, Optional, Type, Union | ||||
|  | ||||
| from sanic.config import Config | ||||
| from sanic.server.events import trigger_events | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from sanic.app import Sanic | ||||
|  | ||||
| import asyncio | ||||
| import multiprocessing | ||||
| import os | ||||
| import socket | ||||
|  | ||||
| from functools import partial | ||||
| from signal import SIG_IGN, SIGINT, SIGTERM, Signals | ||||
| from signal import signal as signal_func | ||||
|  | ||||
| from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.models.server_types import Signal | ||||
| from sanic.server.async_server import AsyncioServer | ||||
| from sanic.server.protocols.http_protocol import HttpProtocol | ||||
| from sanic.server.socket import ( | ||||
|     bind_socket, | ||||
|     bind_unix_socket, | ||||
|     remove_unix_socket, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def serve( | ||||
|     host, | ||||
|     port, | ||||
|     app: Sanic, | ||||
|     ssl: Optional[SSLContext] = None, | ||||
|     sock: Optional[socket.socket] = None, | ||||
|     unix: Optional[str] = None, | ||||
|     reuse_port: bool = False, | ||||
|     loop=None, | ||||
|     protocol: Type[asyncio.Protocol] = HttpProtocol, | ||||
|     backlog: int = 100, | ||||
|     register_sys_signals: bool = True, | ||||
|     run_multiple: bool = False, | ||||
|     run_async: bool = False, | ||||
|     connections=None, | ||||
|     signal=Signal(), | ||||
|     state=None, | ||||
|     asyncio_server_kwargs=None, | ||||
| ): | ||||
|     """Start asynchronous HTTP Server on an individual process. | ||||
|  | ||||
|     :param host: Address to host on | ||||
|     :param port: Port to host on | ||||
|     :param before_start: function to be executed before the server starts | ||||
|                          listening. Takes arguments `app` instance and `loop` | ||||
|     :param after_start: function to be executed after the server starts | ||||
|                         listening. Takes  arguments `app` instance and `loop` | ||||
|     :param before_stop: function to be executed when a stop signal is | ||||
|                         received before it is respected. Takes arguments | ||||
|                         `app` instance and `loop` | ||||
|     :param after_stop: function to be executed when a stop signal is | ||||
|                        received after it is respected. Takes arguments | ||||
|                        `app` instance and `loop` | ||||
|     :param ssl: SSLContext | ||||
|     :param sock: Socket for the server to accept connections from | ||||
|     :param unix: Unix socket to listen on instead of TCP port | ||||
|     :param reuse_port: `True` for multiple workers | ||||
|     :param loop: asyncio compatible event loop | ||||
|     :param run_async: bool: Do not create a new event loop for the server, | ||||
|                       and return an AsyncServer object rather than running it | ||||
|     :param asyncio_server_kwargs: key-value args for asyncio/uvloop | ||||
|                                   create_server method | ||||
|     :return: Nothing | ||||
|     """ | ||||
|     if not run_async and not loop: | ||||
|         # create new event_loop after fork | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|  | ||||
|     if app.debug: | ||||
|         loop.set_debug(app.debug) | ||||
|  | ||||
|     app.asgi = False | ||||
|  | ||||
|     connections = connections if connections is not None else set() | ||||
|     protocol_kwargs = _build_protocol_kwargs(protocol, app.config) | ||||
|     server = partial( | ||||
|         protocol, | ||||
|         loop=loop, | ||||
|         connections=connections, | ||||
|         signal=signal, | ||||
|         app=app, | ||||
|         state=state, | ||||
|         unix=unix, | ||||
|         **protocol_kwargs, | ||||
|     ) | ||||
|     asyncio_server_kwargs = ( | ||||
|         asyncio_server_kwargs if asyncio_server_kwargs else {} | ||||
|     ) | ||||
|     # UNIX sockets are always bound by us (to preserve semantics between modes) | ||||
|     if unix: | ||||
|         sock = bind_unix_socket(unix, backlog=backlog) | ||||
|     server_coroutine = loop.create_server( | ||||
|         server, | ||||
|         None if sock else host, | ||||
|         None if sock else port, | ||||
|         ssl=ssl, | ||||
|         reuse_port=reuse_port, | ||||
|         sock=sock, | ||||
|         backlog=backlog, | ||||
|         **asyncio_server_kwargs, | ||||
|     ) | ||||
|  | ||||
|     if run_async: | ||||
|         return AsyncioServer( | ||||
|             app=app, | ||||
|             loop=loop, | ||||
|             serve_coro=server_coroutine, | ||||
|             connections=connections, | ||||
|         ) | ||||
|  | ||||
|     loop.run_until_complete(app._startup()) | ||||
|     loop.run_until_complete(app._server_event("init", "before")) | ||||
|  | ||||
|     try: | ||||
|         http_server = loop.run_until_complete(server_coroutine) | ||||
|     except BaseException: | ||||
|         error_logger.exception("Unable to start server") | ||||
|         return | ||||
|  | ||||
|     # Ignore SIGINT when run_multiple | ||||
|     if run_multiple: | ||||
|         signal_func(SIGINT, SIG_IGN) | ||||
|  | ||||
|     # Register signals for graceful termination | ||||
|     if register_sys_signals: | ||||
|         if OS_IS_WINDOWS: | ||||
|             ctrlc_workaround_for_windows(app) | ||||
|         else: | ||||
|             for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: | ||||
|                 loop.add_signal_handler(_signal, app.stop) | ||||
|  | ||||
|     loop.run_until_complete(app._server_event("init", "after")) | ||||
|     pid = os.getpid() | ||||
|     try: | ||||
|         logger.info("Starting worker [%s]", pid) | ||||
|         loop.run_forever() | ||||
|     finally: | ||||
|         logger.info("Stopping worker [%s]", pid) | ||||
|  | ||||
|         # Run the on_stop function if provided | ||||
|         loop.run_until_complete(app._server_event("shutdown", "before")) | ||||
|  | ||||
|         # Wait for event loop to finish and all connections to drain | ||||
|         http_server.close() | ||||
|         loop.run_until_complete(http_server.wait_closed()) | ||||
|  | ||||
|         # Complete all tasks on the loop | ||||
|         signal.stopped = True | ||||
|         for connection in connections: | ||||
|             connection.close_if_idle() | ||||
|  | ||||
|         # Gracefully shutdown timeout. | ||||
|         # We should provide graceful_shutdown_timeout, | ||||
|         # instead of letting connection hangs forever. | ||||
|         # Let's roughly calcucate time. | ||||
|         graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT | ||||
|         start_shutdown: float = 0 | ||||
|         while connections and (start_shutdown < graceful): | ||||
|             loop.run_until_complete(asyncio.sleep(0.1)) | ||||
|             start_shutdown = start_shutdown + 0.1 | ||||
|  | ||||
|         # Force close non-idle connection after waiting for | ||||
|         # graceful_shutdown_timeout | ||||
|         for conn in connections: | ||||
|             if hasattr(conn, "websocket") and conn.websocket: | ||||
|                 conn.websocket.fail_connection(code=1001) | ||||
|             else: | ||||
|                 conn.abort() | ||||
|         loop.run_until_complete(app._server_event("shutdown", "after")) | ||||
|  | ||||
|         remove_unix_socket(unix) | ||||
|  | ||||
|  | ||||
| def serve_single(server_settings): | ||||
|     main_start = server_settings.pop("main_start", None) | ||||
|     main_stop = server_settings.pop("main_stop", None) | ||||
|  | ||||
|     if not server_settings.get("run_async"): | ||||
|         # create new event_loop after fork | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         server_settings["loop"] = loop | ||||
|  | ||||
|     trigger_events(main_start, server_settings["loop"]) | ||||
|     serve(**server_settings) | ||||
|     trigger_events(main_stop, server_settings["loop"]) | ||||
|  | ||||
|     server_settings["loop"].close() | ||||
|  | ||||
|  | ||||
| def serve_multiple(server_settings, workers): | ||||
|     """Start multiple server processes simultaneously.  Stop on interrupt | ||||
|     and terminate signals, and drain connections when complete. | ||||
|  | ||||
|     :param server_settings: kw arguments to be passed to the serve function | ||||
|     :param workers: number of workers to launch | ||||
|     :param stop_event: if provided, is used as a stop signal | ||||
|     :return: | ||||
|     """ | ||||
|     server_settings["reuse_port"] = True | ||||
|     server_settings["run_multiple"] = True | ||||
|  | ||||
|     main_start = server_settings.pop("main_start", None) | ||||
|     main_stop = server_settings.pop("main_stop", None) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|  | ||||
|     trigger_events(main_start, loop) | ||||
|  | ||||
|     # Create a listening socket or use the one in settings | ||||
|     sock = server_settings.get("sock") | ||||
|     unix = server_settings["unix"] | ||||
|     backlog = server_settings["backlog"] | ||||
|     if unix: | ||||
|         sock = bind_unix_socket(unix, backlog=backlog) | ||||
|         server_settings["unix"] = unix | ||||
|     if sock is None: | ||||
|         sock = bind_socket( | ||||
|             server_settings["host"], server_settings["port"], backlog=backlog | ||||
|         ) | ||||
|         sock.set_inheritable(True) | ||||
|         server_settings["sock"] = sock | ||||
|         server_settings["host"] = None | ||||
|         server_settings["port"] = None | ||||
|  | ||||
|     processes = [] | ||||
|  | ||||
|     def sig_handler(signal, frame): | ||||
|         logger.info("Received signal %s. Shutting down.", Signals(signal).name) | ||||
|         for process in processes: | ||||
|             os.kill(process.pid, SIGTERM) | ||||
|  | ||||
|     signal_func(SIGINT, lambda s, f: sig_handler(s, f)) | ||||
|     signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) | ||||
|     mp = multiprocessing.get_context("fork") | ||||
|  | ||||
|     for _ in range(workers): | ||||
|         process = mp.Process(target=serve, kwargs=server_settings) | ||||
|         process.daemon = True | ||||
|         process.start() | ||||
|         processes.append(process) | ||||
|  | ||||
|     for process in processes: | ||||
|         process.join() | ||||
|  | ||||
|     # the above processes will block this until they're stopped | ||||
|     for process in processes: | ||||
|         process.terminate() | ||||
|  | ||||
|     trigger_events(main_stop, loop) | ||||
|  | ||||
|     sock.close() | ||||
|     loop.close() | ||||
|     remove_unix_socket(unix) | ||||
|  | ||||
|  | ||||
| def _build_protocol_kwargs( | ||||
|     protocol: Type[asyncio.Protocol], config: Config | ||||
| ) -> Dict[str, Union[int, float]]: | ||||
|     if hasattr(protocol, "websocket_handshake"): | ||||
|         return { | ||||
|             "websocket_max_size": config.WEBSOCKET_MAX_SIZE, | ||||
|             "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, | ||||
|             "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, | ||||
|         } | ||||
|     return {} | ||||
							
								
								
									
										87
									
								
								sanic/server/socket.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								sanic/server/socket.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import os | ||||
| import secrets | ||||
| import socket | ||||
| import stat | ||||
|  | ||||
| from ipaddress import ip_address | ||||
| from typing import Optional | ||||
|  | ||||
|  | ||||
| def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: | ||||
|     """Create TCP server socket. | ||||
|     :param host: IPv4, IPv6 or hostname may be specified | ||||
|     :param port: TCP port number | ||||
|     :param backlog: Maximum number of connections to queue | ||||
|     :return: socket.socket object | ||||
|     """ | ||||
|     try:  # IP address: family must be specified for IPv6 at least | ||||
|         ip = ip_address(host) | ||||
|         host = str(ip) | ||||
|         sock = socket.socket( | ||||
|             socket.AF_INET6 if ip.version == 6 else socket.AF_INET | ||||
|         ) | ||||
|     except ValueError:  # Hostname, may become AF_INET or AF_INET6 | ||||
|         sock = socket.socket() | ||||
|     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||||
|     sock.bind((host, port)) | ||||
|     sock.listen(backlog) | ||||
|     return sock | ||||
|  | ||||
|  | ||||
| def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: | ||||
|     """Create unix socket. | ||||
|     :param path: filesystem path | ||||
|     :param backlog: Maximum number of connections to queue | ||||
|     :return: socket.socket object | ||||
|     """ | ||||
|     """Open or atomically replace existing socket with zero downtime.""" | ||||
|     # Sanitise and pre-verify socket path | ||||
|     path = os.path.abspath(path) | ||||
|     folder = os.path.dirname(path) | ||||
|     if not os.path.isdir(folder): | ||||
|         raise FileNotFoundError(f"Socket folder does not exist: {folder}") | ||||
|     try: | ||||
|         if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): | ||||
|             raise FileExistsError(f"Existing file is not a socket: {path}") | ||||
|     except FileNotFoundError: | ||||
|         pass | ||||
|     # Create new socket with a random temporary name | ||||
|     tmp_path = f"{path}.{secrets.token_urlsafe()}" | ||||
|     sock = socket.socket(socket.AF_UNIX) | ||||
|     try: | ||||
|         # Critical section begins (filename races) | ||||
|         sock.bind(tmp_path) | ||||
|         try: | ||||
|             os.chmod(tmp_path, mode) | ||||
|             # Start listening before rename to avoid connection failures | ||||
|             sock.listen(backlog) | ||||
|             os.rename(tmp_path, path) | ||||
|         except:  # noqa: E722 | ||||
|             try: | ||||
|                 os.unlink(tmp_path) | ||||
|             finally: | ||||
|                 raise | ||||
|     except:  # noqa: E722 | ||||
|         try: | ||||
|             sock.close() | ||||
|         finally: | ||||
|             raise | ||||
|     return sock | ||||
|  | ||||
|  | ||||
| def remove_unix_socket(path: Optional[str]) -> None: | ||||
|     """Remove dead unix socket during server exit.""" | ||||
|     if not path: | ||||
|         return | ||||
|     try: | ||||
|         if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): | ||||
|             # Is it actually dead (doesn't belong to a new server instance)? | ||||
|             with socket.socket(socket.AF_UNIX) as testsock: | ||||
|                 try: | ||||
|                     testsock.connect(path) | ||||
|                 except ConnectionRefusedError: | ||||
|                     os.unlink(path) | ||||
|     except FileNotFoundError: | ||||
|         pass | ||||
							
								
								
									
										0
									
								
								sanic/server/websockets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sanic/server/websockets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										82
									
								
								sanic/server/websockets/connection.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								sanic/server/websockets/connection.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Awaitable, | ||||
|     Callable, | ||||
|     Dict, | ||||
|     List, | ||||
|     MutableMapping, | ||||
|     Optional, | ||||
|     Union, | ||||
| ) | ||||
|  | ||||
|  | ||||
| ASIMessage = MutableMapping[str, Any] | ||||
|  | ||||
|  | ||||
| class WebSocketConnection: | ||||
|     """ | ||||
|     This is for ASGI Connections. | ||||
|     It provides an interface similar to WebsocketProtocol, but | ||||
|     sends/receives over an ASGI connection. | ||||
|     """ | ||||
|  | ||||
|     # TODO | ||||
|     # - Implement ping/pong | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         send: Callable[[ASIMessage], Awaitable[None]], | ||||
|         receive: Callable[[], Awaitable[ASIMessage]], | ||||
|         subprotocols: Optional[List[str]] = None, | ||||
|     ) -> None: | ||||
|         self._send = send | ||||
|         self._receive = receive | ||||
|         self._subprotocols = subprotocols or [] | ||||
|  | ||||
|     async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: | ||||
|         message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} | ||||
|  | ||||
|         if isinstance(data, bytes): | ||||
|             message.update({"bytes": data}) | ||||
|         else: | ||||
|             message.update({"text": str(data)}) | ||||
|  | ||||
|         await self._send(message) | ||||
|  | ||||
|     async def recv(self, *args, **kwargs) -> Optional[str]: | ||||
|         message = await self._receive() | ||||
|  | ||||
|         if message["type"] == "websocket.receive": | ||||
|             return message["text"] | ||||
|         elif message["type"] == "websocket.disconnect": | ||||
|             pass | ||||
|  | ||||
|         return None | ||||
|  | ||||
|     receive = recv | ||||
|  | ||||
|     async def accept(self, subprotocols: Optional[List[str]] = None) -> None: | ||||
|         subprotocol = None | ||||
|         if subprotocols: | ||||
|             for subp in subprotocols: | ||||
|                 if subp in self.subprotocols: | ||||
|                     subprotocol = subp | ||||
|                     break | ||||
|  | ||||
|         await self._send( | ||||
|             { | ||||
|                 "type": "websocket.accept", | ||||
|                 "subprotocol": subprotocol, | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     async def close(self, code: int = 1000, reason: str = "") -> None: | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     def subprotocols(self): | ||||
|         return self._subprotocols | ||||
|  | ||||
|     @subprotocols.setter | ||||
|     def subprotocols(self, subprotocols: Optional[List[str]] = None): | ||||
|         self._subprotocols = subprotocols or [] | ||||
							
								
								
									
										294
									
								
								sanic/server/websockets/frame.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										294
									
								
								sanic/server/websockets/frame.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,294 @@ | ||||
| import asyncio | ||||
| import codecs | ||||
|  | ||||
| from typing import TYPE_CHECKING, AsyncIterator, List, Optional | ||||
|  | ||||
| from websockets.frames import Frame, Opcode | ||||
| from websockets.typing import Data | ||||
|  | ||||
| from sanic.exceptions import ServerError | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from .impl import WebsocketImplProtocol | ||||
|  | ||||
| UTF8Decoder = codecs.getincrementaldecoder("utf-8") | ||||
|  | ||||
|  | ||||
| class WebsocketFrameAssembler: | ||||
|     """ | ||||
|     Assemble a message from frames. | ||||
|     Code borrowed from aaugustin/websockets project: | ||||
|     https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ( | ||||
|         "protocol", | ||||
|         "read_mutex", | ||||
|         "write_mutex", | ||||
|         "message_complete", | ||||
|         "message_fetched", | ||||
|         "get_in_progress", | ||||
|         "decoder", | ||||
|         "completed_queue", | ||||
|         "chunks", | ||||
|         "chunks_queue", | ||||
|         "paused", | ||||
|         "get_id", | ||||
|         "put_id", | ||||
|     ) | ||||
|     if TYPE_CHECKING: | ||||
|         protocol: "WebsocketImplProtocol" | ||||
|         read_mutex: asyncio.Lock | ||||
|         write_mutex: asyncio.Lock | ||||
|         message_complete: asyncio.Event | ||||
|         message_fetched: asyncio.Event | ||||
|         completed_queue: asyncio.Queue | ||||
|         get_in_progress: bool | ||||
|         decoder: Optional[codecs.IncrementalDecoder] | ||||
|         # For streaming chunks rather than messages: | ||||
|         chunks: List[Data] | ||||
|         chunks_queue: Optional[asyncio.Queue[Optional[Data]]] | ||||
|         paused: bool | ||||
|  | ||||
|     def __init__(self, protocol) -> None: | ||||
|  | ||||
|         self.protocol = protocol | ||||
|  | ||||
|         self.read_mutex = asyncio.Lock() | ||||
|         self.write_mutex = asyncio.Lock() | ||||
|  | ||||
|         self.completed_queue = asyncio.Queue( | ||||
|             maxsize=1 | ||||
|         )  # type: asyncio.Queue[Data] | ||||
|  | ||||
|         # put() sets this event to tell get() that a message can be fetched. | ||||
|         self.message_complete = asyncio.Event() | ||||
|         # get() sets this event to let put() | ||||
|         self.message_fetched = asyncio.Event() | ||||
|  | ||||
|         # This flag prevents concurrent calls to get() by user code. | ||||
|         self.get_in_progress = False | ||||
|  | ||||
|         # Decoder for text frames, None for binary frames. | ||||
|         self.decoder = None | ||||
|  | ||||
|         # Buffer data from frames belonging to the same message. | ||||
|         self.chunks = [] | ||||
|  | ||||
|         # When switching from "buffering" to "streaming", we use a thread-safe | ||||
|         # queue for transferring frames from the writing thread (library code) | ||||
|         # to the reading thread (user code). We're buffering when chunks_queue | ||||
|         # is None and streaming when it's a Queue. None is a sentinel | ||||
|         # value marking the end of the stream, superseding message_complete. | ||||
|  | ||||
|         # Stream data from frames belonging to the same message. | ||||
|         self.chunks_queue = None | ||||
|  | ||||
|         # Flag to indicate we've paused the protocol | ||||
|         self.paused = False | ||||
|  | ||||
|     async def get(self, timeout: Optional[float] = None) -> Optional[Data]: | ||||
|         """ | ||||
|         Read the next message. | ||||
|         :meth:`get` returns a single :class:`str` or :class:`bytes`. | ||||
|         If the :message was fragmented, :meth:`get` waits until the last frame | ||||
|         is received, then it reassembles the message. | ||||
|         If ``timeout`` is set and elapses before a complete message is | ||||
|         received, :meth:`get` returns ``None``. | ||||
|         """ | ||||
|         async with self.read_mutex: | ||||
|             if timeout is not None and timeout <= 0: | ||||
|                 if not self.message_complete.is_set(): | ||||
|                     return None | ||||
|             if self.get_in_progress: | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # exception is only here as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "Called get() on Websocket frame assembler " | ||||
|                     "while asynchronous get is already in progress." | ||||
|                 ) | ||||
|             self.get_in_progress = True | ||||
|  | ||||
|             # If the message_complete event isn't set yet, release the lock to | ||||
|             # allow put() to run and eventually set it. | ||||
|             # Locking with get_in_progress ensures only one task can get here. | ||||
|             if timeout is None: | ||||
|                 completed = await self.message_complete.wait() | ||||
|             elif timeout <= 0: | ||||
|                 completed = self.message_complete.is_set() | ||||
|             else: | ||||
|                 try: | ||||
|                     await asyncio.wait_for( | ||||
|                         self.message_complete.wait(), timeout=timeout | ||||
|                     ) | ||||
|                 except asyncio.TimeoutError: | ||||
|                     ... | ||||
|                 finally: | ||||
|                     completed = self.message_complete.is_set() | ||||
|  | ||||
|             # Unpause the transport, if its paused | ||||
|             if self.paused: | ||||
|                 self.protocol.resume_frames() | ||||
|                 self.paused = False | ||||
|             if not self.get_in_progress: | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # exception is here as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "State of Websocket frame assembler was modified while an " | ||||
|                     "asynchronous get was in progress." | ||||
|                 ) | ||||
|             self.get_in_progress = False | ||||
|  | ||||
|             # Waiting for a complete message timed out. | ||||
|             if not completed: | ||||
|                 return None | ||||
|             if not self.message_complete.is_set(): | ||||
|                 return None | ||||
|  | ||||
|             self.message_complete.clear() | ||||
|  | ||||
|             joiner: Data = b"" if self.decoder is None else "" | ||||
|             # mypy cannot figure out that chunks have the proper type. | ||||
|             message: Data = joiner.join(self.chunks)  # type: ignore | ||||
|             if self.message_fetched.is_set(): | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # and get_in_progress check, this exception is here | ||||
|                 # as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "Websocket get() found a message when " | ||||
|                     "state was already fetched." | ||||
|                 ) | ||||
|             self.message_fetched.set() | ||||
|             self.chunks = [] | ||||
|             # this should already be None, but set it here for safety | ||||
|             self.chunks_queue = None | ||||
|             return message | ||||
|  | ||||
|     async def get_iter(self) -> AsyncIterator[Data]: | ||||
|         """ | ||||
|         Stream the next message. | ||||
|         Iterating the return value of :meth:`get_iter` yields a :class:`str` | ||||
|         or :class:`bytes` for each frame in the message. | ||||
|         """ | ||||
|         async with self.read_mutex: | ||||
|             if self.get_in_progress: | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # exception is only here as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "Called get_iter on Websocket frame assembler " | ||||
|                     "while asynchronous get is already in progress." | ||||
|                 ) | ||||
|             self.get_in_progress = True | ||||
|  | ||||
|             chunks = self.chunks | ||||
|             self.chunks = [] | ||||
|             self.chunks_queue = asyncio.Queue() | ||||
|  | ||||
|             # Sending None in chunk_queue supersedes setting message_complete | ||||
|             # when switching to "streaming". If message is already complete | ||||
|             # when the switch happens, put() didn't send None, so we have to. | ||||
|             if self.message_complete.is_set(): | ||||
|                 await self.chunks_queue.put(None) | ||||
|  | ||||
|             # Locking with get_in_progress ensures only one task can get here | ||||
|             for c in chunks: | ||||
|                 yield c | ||||
|             while True: | ||||
|                 chunk = await self.chunks_queue.get() | ||||
|                 if chunk is None: | ||||
|                     break | ||||
|                 yield chunk | ||||
|  | ||||
|             # Unpause the transport, if its paused | ||||
|             if self.paused: | ||||
|                 self.protocol.resume_frames() | ||||
|                 self.paused = False | ||||
|             if not self.get_in_progress: | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # exception is here as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "State of Websocket frame assembler was modified while an " | ||||
|                     "asynchronous get was in progress." | ||||
|                 ) | ||||
|             self.get_in_progress = False | ||||
|             if not self.message_complete.is_set(): | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # exception is here as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "Websocket frame assembler chunks queue ended before " | ||||
|                     "message was complete." | ||||
|                 ) | ||||
|             self.message_complete.clear() | ||||
|             if self.message_fetched.is_set(): | ||||
|                 # This should be guarded against with the read_mutex, | ||||
|                 # and get_in_progress check, this exception is | ||||
|                 # here as a failsafe | ||||
|                 raise ServerError( | ||||
|                     "Websocket get_iter() found a message when state was " | ||||
|                     "already fetched." | ||||
|                 ) | ||||
|  | ||||
|             self.message_fetched.set() | ||||
|             # this should already be empty, but set it here for safety | ||||
|             self.chunks = [] | ||||
|             self.chunks_queue = None | ||||
|  | ||||
|     async def put(self, frame: Frame) -> None: | ||||
|         """ | ||||
|         Add ``frame`` to the next message. | ||||
|         When ``frame`` is the final frame in a message, :meth:`put` waits | ||||
|         until the message is fetched, either by calling :meth:`get` or by | ||||
|         iterating the return value of :meth:`get_iter`. | ||||
|         :meth:`put` assumes that the stream of frames respects the protocol. | ||||
|         If it doesn't, the behavior is undefined. | ||||
|         """ | ||||
|  | ||||
|         async with self.write_mutex: | ||||
|             if frame.opcode is Opcode.TEXT: | ||||
|                 self.decoder = UTF8Decoder(errors="strict") | ||||
|             elif frame.opcode is Opcode.BINARY: | ||||
|                 self.decoder = None | ||||
|             elif frame.opcode is Opcode.CONT: | ||||
|                 pass | ||||
|             else: | ||||
|                 # Ignore control frames. | ||||
|                 return | ||||
|             data: Data | ||||
|             if self.decoder is not None: | ||||
|                 data = self.decoder.decode(frame.data, frame.fin) | ||||
|             else: | ||||
|                 data = frame.data | ||||
|             if self.chunks_queue is None: | ||||
|                 self.chunks.append(data) | ||||
|             else: | ||||
|                 await self.chunks_queue.put(data) | ||||
|  | ||||
|             if not frame.fin: | ||||
|                 return | ||||
|             if not self.get_in_progress: | ||||
|                 # nobody is waiting for this frame, so try to pause subsequent | ||||
|                 # frames at the protocol level | ||||
|                 self.paused = self.protocol.pause_frames() | ||||
|             # Message is complete. Wait until it's fetched to return. | ||||
|  | ||||
|             if self.chunks_queue is not None: | ||||
|                 await self.chunks_queue.put(None) | ||||
|             if self.message_complete.is_set(): | ||||
|                 # This should be guarded against with the write_mutex | ||||
|                 raise ServerError( | ||||
|                     "Websocket put() got a new message when a message was " | ||||
|                     "already in its chamber." | ||||
|                 ) | ||||
|             self.message_complete.set()  # Signal to get() it can serve the | ||||
|             if self.message_fetched.is_set(): | ||||
|                 # This should be guarded against with the write_mutex | ||||
|                 raise ServerError( | ||||
|                     "Websocket put() got a new message when the previous " | ||||
|                     "message was not yet fetched." | ||||
|                 ) | ||||
|  | ||||
|             # Allow get() to run and eventually set the event. | ||||
|             await self.message_fetched.wait() | ||||
|             self.message_fetched.clear() | ||||
|             self.decoder = None | ||||
							
								
								
									
										834
									
								
								sanic/server/websockets/impl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										834
									
								
								sanic/server/websockets/impl.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,834 @@ | ||||
| import asyncio | ||||
| import random | ||||
| import struct | ||||
|  | ||||
| from typing import ( | ||||
|     AsyncIterator, | ||||
|     Dict, | ||||
|     Iterable, | ||||
|     Mapping, | ||||
|     Optional, | ||||
|     Sequence, | ||||
|     Union, | ||||
| ) | ||||
|  | ||||
| from websockets.connection import CLOSED, CLOSING, OPEN, Event | ||||
| from websockets.exceptions import ConnectionClosed, ConnectionClosedError | ||||
| from websockets.frames import Frame, Opcode | ||||
| from websockets.server import ServerConnection | ||||
| from websockets.typing import Data | ||||
|  | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.server.protocols.base_protocol import SanicProtocol | ||||
|  | ||||
| from ...exceptions import ServerError, WebsocketClosed | ||||
| from .frame import WebsocketFrameAssembler | ||||
|  | ||||
|  | ||||
| class WebsocketImplProtocol: | ||||
|     connection: ServerConnection | ||||
|     io_proto: Optional[SanicProtocol] | ||||
|     loop: Optional[asyncio.AbstractEventLoop] | ||||
|     max_queue: int | ||||
|     close_timeout: float | ||||
|     ping_interval: Optional[float] | ||||
|     ping_timeout: Optional[float] | ||||
|     assembler: WebsocketFrameAssembler | ||||
|     # Dict[bytes, asyncio.Future[None]] | ||||
|     pings: Dict[bytes, asyncio.Future] | ||||
|     conn_mutex: asyncio.Lock | ||||
|     recv_lock: asyncio.Lock | ||||
|     recv_cancel: Optional[asyncio.Future] | ||||
|     process_event_mutex: asyncio.Lock | ||||
|     can_pause: bool | ||||
|     # Optional[asyncio.Future[None]] | ||||
|     data_finished_fut: Optional[asyncio.Future] | ||||
|     # Optional[asyncio.Future[None]] | ||||
|     pause_frame_fut: Optional[asyncio.Future] | ||||
|     # Optional[asyncio.Future[None]] | ||||
|     connection_lost_waiter: Optional[asyncio.Future] | ||||
|     keepalive_ping_task: Optional[asyncio.Task] | ||||
|     auto_closer_task: Optional[asyncio.Task] | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         connection, | ||||
|         max_queue=None, | ||||
|         ping_interval: Optional[float] = 20, | ||||
|         ping_timeout: Optional[float] = 20, | ||||
|         close_timeout: float = 10, | ||||
|         loop=None, | ||||
|     ): | ||||
|         self.connection = connection | ||||
|         self.io_proto = None | ||||
|         self.loop = None | ||||
|         self.max_queue = max_queue | ||||
|         self.close_timeout = close_timeout | ||||
|         self.ping_interval = ping_interval | ||||
|         self.ping_timeout = ping_timeout | ||||
|         self.assembler = WebsocketFrameAssembler(self) | ||||
|         self.pings = {} | ||||
|         self.conn_mutex = asyncio.Lock() | ||||
|         self.recv_lock = asyncio.Lock() | ||||
|         self.recv_cancel = None | ||||
|         self.process_event_mutex = asyncio.Lock() | ||||
|         self.data_finished_fut = None | ||||
|         self.can_pause = True | ||||
|         self.pause_frame_fut = None | ||||
|         self.keepalive_ping_task = None | ||||
|         self.auto_closer_task = None | ||||
|         self.connection_lost_waiter = None | ||||
|  | ||||
|     @property | ||||
|     def subprotocol(self): | ||||
|         return self.connection.subprotocol | ||||
|  | ||||
|     def pause_frames(self): | ||||
|         if not self.can_pause: | ||||
|             return False | ||||
|         if self.pause_frame_fut: | ||||
|             logger.debug("Websocket connection already paused.") | ||||
|             return False | ||||
|         if (not self.loop) or (not self.io_proto): | ||||
|             return False | ||||
|         if self.io_proto.transport: | ||||
|             self.io_proto.transport.pause_reading() | ||||
|         self.pause_frame_fut = self.loop.create_future() | ||||
|         logger.debug("Websocket connection paused.") | ||||
|         return True | ||||
|  | ||||
|     def resume_frames(self): | ||||
|         if not self.pause_frame_fut: | ||||
|             logger.debug("Websocket connection not paused.") | ||||
|             return False | ||||
|         if (not self.loop) or (not self.io_proto): | ||||
|             logger.debug( | ||||
|                 "Websocket attempting to resume reading frames, " | ||||
|                 "but connection is gone." | ||||
|             ) | ||||
|             return False | ||||
|         if self.io_proto.transport: | ||||
|             self.io_proto.transport.resume_reading() | ||||
|         self.pause_frame_fut.set_result(None) | ||||
|         self.pause_frame_fut = None | ||||
|         logger.debug("Websocket connection unpaused.") | ||||
|         return True | ||||
|  | ||||
|     async def connection_made( | ||||
|         self, | ||||
|         io_proto: SanicProtocol, | ||||
|         loop: Optional[asyncio.AbstractEventLoop] = None, | ||||
|     ): | ||||
|         if not loop: | ||||
|             try: | ||||
|                 loop = getattr(io_proto, "loop") | ||||
|             except AttributeError: | ||||
|                 loop = asyncio.get_event_loop() | ||||
|         if not loop: | ||||
|             # This catch is for mypy type checker | ||||
|             # to assert loop is not None here. | ||||
|             raise ServerError("Connection received with no asyncio loop.") | ||||
|         if self.auto_closer_task: | ||||
|             raise ServerError( | ||||
|                 "Cannot call connection_made more than once " | ||||
|                 "on a websocket connection." | ||||
|             ) | ||||
|         self.loop = loop | ||||
|         self.io_proto = io_proto | ||||
|         self.connection_lost_waiter = self.loop.create_future() | ||||
|         self.data_finished_fut = asyncio.shield(self.loop.create_future()) | ||||
|  | ||||
|         if self.ping_interval: | ||||
|             self.keepalive_ping_task = asyncio.create_task( | ||||
|                 self.keepalive_ping() | ||||
|             ) | ||||
|         self.auto_closer_task = asyncio.create_task( | ||||
|             self.auto_close_connection() | ||||
|         ) | ||||
|  | ||||
|     async def wait_for_connection_lost(self, timeout=None) -> bool: | ||||
|         """ | ||||
|         Wait until the TCP connection is closed or ``timeout`` elapses. | ||||
|         If timeout is None, wait forever. | ||||
|         Recommend you should pass in self.close_timeout as timeout | ||||
|  | ||||
|         Return ``True`` if the connection is closed and ``False`` otherwise. | ||||
|  | ||||
|         """ | ||||
|         if not self.connection_lost_waiter: | ||||
|             return False | ||||
|         if self.connection_lost_waiter.done(): | ||||
|             return True | ||||
|         else: | ||||
|             try: | ||||
|                 await asyncio.wait_for( | ||||
|                     asyncio.shield(self.connection_lost_waiter), timeout | ||||
|                 ) | ||||
|                 return True | ||||
|             except asyncio.TimeoutError: | ||||
|                 # Re-check self.connection_lost_waiter.done() synchronously | ||||
|                 # because connection_lost() could run between the moment the | ||||
|                 # timeout occurs and the moment this coroutine resumes running | ||||
|                 return self.connection_lost_waiter.done() | ||||
|  | ||||
|     async def process_events(self, events: Sequence[Event]) -> None: | ||||
|         """ | ||||
|         Process a list of incoming events. | ||||
|         """ | ||||
|         # Wrapped in a mutex lock, to prevent other incoming events | ||||
|         # from processing at the same time | ||||
|         async with self.process_event_mutex: | ||||
|             for event in events: | ||||
|                 if not isinstance(event, Frame): | ||||
|                     # Event is not a frame. Ignore it. | ||||
|                     continue | ||||
|                 if event.opcode == Opcode.PONG: | ||||
|                     await self.process_pong(event) | ||||
|                 elif event.opcode == Opcode.CLOSE: | ||||
|                     if self.recv_cancel: | ||||
|                         self.recv_cancel.cancel() | ||||
|                 else: | ||||
|                     await self.assembler.put(event) | ||||
|  | ||||
|     async def process_pong(self, frame: Frame) -> None: | ||||
|         if frame.data in self.pings: | ||||
|             # Acknowledge all pings up to the one matching this pong. | ||||
|             ping_ids = [] | ||||
|             for ping_id, ping in self.pings.items(): | ||||
|                 ping_ids.append(ping_id) | ||||
|                 if not ping.done(): | ||||
|                     ping.set_result(None) | ||||
|                 if ping_id == frame.data: | ||||
|                     break | ||||
|             else:  # noqa | ||||
|                 raise ServerError("ping_id is not in self.pings") | ||||
|             # Remove acknowledged pings from self.pings. | ||||
|             for ping_id in ping_ids: | ||||
|                 del self.pings[ping_id] | ||||
|  | ||||
|     async def keepalive_ping(self) -> None: | ||||
|         """ | ||||
|         Send a Ping frame and wait for a Pong frame at regular intervals. | ||||
|         This coroutine exits when the connection terminates and one of the | ||||
|         following happens: | ||||
|         - :meth:`ping` raises :exc:`ConnectionClosed`, or | ||||
|         - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. | ||||
|         """ | ||||
|         if self.ping_interval is None: | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             while True: | ||||
|                 await asyncio.sleep(self.ping_interval) | ||||
|  | ||||
|                 # ping() raises CancelledError if the connection is closed, | ||||
|                 # when auto_close_connection() cancels keepalive_ping_task. | ||||
|  | ||||
|                 # ping() raises ConnectionClosed if the connection is lost, | ||||
|                 # when connection_lost() calls abort_pings(). | ||||
|  | ||||
|                 ping_waiter = await self.ping() | ||||
|  | ||||
|                 if self.ping_timeout is not None: | ||||
|                     try: | ||||
|                         await asyncio.wait_for(ping_waiter, self.ping_timeout) | ||||
|                     except asyncio.TimeoutError: | ||||
|                         error_logger.warning( | ||||
|                             "Websocket timed out waiting for pong" | ||||
|                         ) | ||||
|                         self.fail_connection(1011) | ||||
|                         break | ||||
|         except asyncio.CancelledError: | ||||
|             # It is expected for this task to be cancelled during during | ||||
|             # normal operation, when the connection is closed. | ||||
|             logger.debug("Websocket keepalive ping task was cancelled.") | ||||
|         except (ConnectionClosed, WebsocketClosed): | ||||
|             logger.debug("Websocket closed. Keepalive ping task exiting.") | ||||
|         except Exception as e: | ||||
|             error_logger.warning( | ||||
|                 "Unexpected exception in websocket keepalive ping task." | ||||
|             ) | ||||
|             logger.debug(str(e)) | ||||
|  | ||||
|     def _force_disconnect(self) -> bool: | ||||
|         """ | ||||
|         Internal methdod used by end_connection and fail_connection | ||||
|         only when the graceful auto-closer cannot be used | ||||
|         """ | ||||
|         if self.auto_closer_task and not self.auto_closer_task.done(): | ||||
|             self.auto_closer_task.cancel() | ||||
|         if self.data_finished_fut and not self.data_finished_fut.done(): | ||||
|             self.data_finished_fut.cancel() | ||||
|             self.data_finished_fut = None | ||||
|         if self.keepalive_ping_task and not self.keepalive_ping_task.done(): | ||||
|             self.keepalive_ping_task.cancel() | ||||
|             self.keepalive_ping_task = None | ||||
|         if self.loop and self.io_proto and self.io_proto.transport: | ||||
|             self.io_proto.transport.close() | ||||
|             self.loop.call_later( | ||||
|                 self.close_timeout, self.io_proto.transport.abort | ||||
|             ) | ||||
|         # We were never open, or already closed | ||||
|         return True | ||||
|  | ||||
|     def fail_connection(self, code: int = 1006, reason: str = "") -> bool: | ||||
|         """ | ||||
|         Fail the WebSocket Connection | ||||
|         This requires: | ||||
|         1. Stopping all processing of incoming data, which means cancelling | ||||
|            pausing the underlying io protocol. The close code will be 1006 | ||||
|            unless a close frame was received earlier. | ||||
|         2. Sending a close frame with an appropriate code if the opening | ||||
|            handshake succeeded and the other side is likely to process it. | ||||
|         3. Closing the connection. :meth:`auto_close_connection` takes care | ||||
|            of this. | ||||
|         (The specification describes these steps in the opposite order.) | ||||
|         """ | ||||
|         if self.io_proto and self.io_proto.transport: | ||||
|             # Stop new data coming in | ||||
|             # In Python Version 3.7: pause_reading is idempotent | ||||
|             # ut can be called when the transport is already paused or closed | ||||
|             self.io_proto.transport.pause_reading() | ||||
|  | ||||
|             # Keeping fail_connection() synchronous guarantees it can't | ||||
|             # get stuck and simplifies the implementation of the callers. | ||||
|             # Not draining the write buffer is acceptable in this context. | ||||
|  | ||||
|             # clear the send buffer | ||||
|             _ = self.connection.data_to_send() | ||||
|             # If we're not already CLOSED or CLOSING, then send the close. | ||||
|             if self.connection.state is OPEN: | ||||
|                 if code in (1000, 1001): | ||||
|                     self.connection.send_close(code, reason) | ||||
|                 else: | ||||
|                     self.connection.fail(code, reason) | ||||
|                 try: | ||||
|                     data_to_send = self.connection.data_to_send() | ||||
|                     while ( | ||||
|                         len(data_to_send) | ||||
|                         and self.io_proto | ||||
|                         and self.io_proto.transport | ||||
|                     ): | ||||
|                         frame_data = data_to_send.pop(0) | ||||
|                         self.io_proto.transport.write(frame_data) | ||||
|                 except Exception: | ||||
|                     # sending close frames may fail if the | ||||
|                     # transport closes during this period | ||||
|                     ... | ||||
|         if code == 1006: | ||||
|             # Special case: 1006 consider the transport already closed | ||||
|             self.connection.state = CLOSED | ||||
|         if self.data_finished_fut and not self.data_finished_fut.done(): | ||||
|             # We have a graceful auto-closer. Use it to close the connection. | ||||
|             self.data_finished_fut.cancel() | ||||
|             self.data_finished_fut = None | ||||
|         if (not self.auto_closer_task) or self.auto_closer_task.done(): | ||||
|             return self._force_disconnect() | ||||
|         return False | ||||
|  | ||||
|     def end_connection(self, code=1000, reason=""): | ||||
|         # This is like slightly more graceful form of fail_connection | ||||
|         # Use this instead of close() when you need an immediate | ||||
|         # close and cannot await websocket.close() handshake. | ||||
|  | ||||
|         if code == 1006 or not self.io_proto or not self.io_proto.transport: | ||||
|             return self.fail_connection(code, reason) | ||||
|  | ||||
|         # Stop new data coming in | ||||
|         # In Python Version 3.7: pause_reading is idempotent | ||||
|         # i.e. it can be called when the transport is already paused or closed. | ||||
|         self.io_proto.transport.pause_reading() | ||||
|         if self.connection.state == OPEN: | ||||
|             data_to_send = self.connection.data_to_send() | ||||
|             self.connection.send_close(code, reason) | ||||
|             data_to_send.extend(self.connection.data_to_send()) | ||||
|             try: | ||||
|                 while ( | ||||
|                     len(data_to_send) | ||||
|                     and self.io_proto | ||||
|                     and self.io_proto.transport | ||||
|                 ): | ||||
|                     frame_data = data_to_send.pop(0) | ||||
|                     self.io_proto.transport.write(frame_data) | ||||
|             except Exception: | ||||
|                 # sending close frames may fail if the | ||||
|                 # transport closes during this period | ||||
|                 # But that doesn't matter at this point | ||||
|                 ... | ||||
|         if self.data_finished_fut and not self.data_finished_fut.done(): | ||||
|             # We have the ability to signal the auto-closer | ||||
|             # try to trigger it to auto-close the connection | ||||
|             self.data_finished_fut.cancel() | ||||
|             self.data_finished_fut = None | ||||
|         if (not self.auto_closer_task) or self.auto_closer_task.done(): | ||||
|             # Auto-closer is not running, do force disconnect | ||||
|             return self._force_disconnect() | ||||
|         return False | ||||
|  | ||||
|     async def auto_close_connection(self) -> None: | ||||
|         """ | ||||
|         Close the WebSocket Connection | ||||
|         When the opening handshake succeeds, :meth:`connection_open` starts | ||||
|         this coroutine in a task. It waits for the data transfer phase to | ||||
|         complete then it closes the TCP connection cleanly. | ||||
|         When the opening handshake fails, :meth:`fail_connection` does the | ||||
|         same. There's no data transfer phase in that case. | ||||
|         """ | ||||
|         try: | ||||
|             # Wait for the data transfer phase to complete. | ||||
|             if self.data_finished_fut: | ||||
|                 try: | ||||
|                     await self.data_finished_fut | ||||
|                     logger.debug( | ||||
|                         "Websocket task finished. Closing the connection." | ||||
|                     ) | ||||
|                 except asyncio.CancelledError: | ||||
|                     # Cancelled error is called when data phase is cancelled | ||||
|                     # if an error occurred or the client closed the connection | ||||
|                     logger.debug( | ||||
|                         "Websocket handler cancelled. Closing the connection." | ||||
|                     ) | ||||
|  | ||||
|             # Cancel the keepalive ping task. | ||||
|             if self.keepalive_ping_task: | ||||
|                 self.keepalive_ping_task.cancel() | ||||
|                 self.keepalive_ping_task = None | ||||
|  | ||||
|             # Half-close the TCP connection if possible (when there's no TLS). | ||||
|             if ( | ||||
|                 self.io_proto | ||||
|                 and self.io_proto.transport | ||||
|                 and self.io_proto.transport.can_write_eof() | ||||
|             ): | ||||
|                 logger.debug("Websocket half-closing TCP connection") | ||||
|                 self.io_proto.transport.write_eof() | ||||
|                 if self.connection_lost_waiter: | ||||
|                     if await self.wait_for_connection_lost(timeout=0): | ||||
|                         return | ||||
|         except asyncio.CancelledError: | ||||
|             ... | ||||
|         finally: | ||||
|             # The try/finally ensures that the transport never remains open, | ||||
|             # even if this coroutine is cancelled (for example). | ||||
|             if (not self.io_proto) or (not self.io_proto.transport): | ||||
|                 # we were never open, or done. Can't do any finalization. | ||||
|                 return | ||||
|             elif ( | ||||
|                 self.connection_lost_waiter | ||||
|                 and self.connection_lost_waiter.done() | ||||
|             ): | ||||
|                 # connection confirmed closed already, proceed to abort waiter | ||||
|                 ... | ||||
|             elif self.io_proto.transport.is_closing(): | ||||
|                 # Connection is already closing (due to half-close above) | ||||
|                 # proceed to abort waiter | ||||
|                 ... | ||||
|             else: | ||||
|                 self.io_proto.transport.close() | ||||
|             if not self.connection_lost_waiter: | ||||
|                 # Our connection monitor task isn't running. | ||||
|                 try: | ||||
|                     await asyncio.sleep(self.close_timeout) | ||||
|                 except asyncio.CancelledError: | ||||
|                     ... | ||||
|                 if self.io_proto and self.io_proto.transport: | ||||
|                     self.io_proto.transport.abort() | ||||
|             else: | ||||
|                 if await self.wait_for_connection_lost( | ||||
|                     timeout=self.close_timeout | ||||
|                 ): | ||||
|                     # Connection aborted before the timeout expired. | ||||
|                     return | ||||
|                 error_logger.warning( | ||||
|                     "Timeout waiting for TCP connection to close. Aborting" | ||||
|                 ) | ||||
|                 if self.io_proto and self.io_proto.transport: | ||||
|                     self.io_proto.transport.abort() | ||||
|  | ||||
|     def abort_pings(self) -> None: | ||||
|         """ | ||||
|         Raise ConnectionClosed in pending keepalive pings. | ||||
|         They'll never receive a pong once the connection is closed. | ||||
|         """ | ||||
|         if self.connection.state is not CLOSED: | ||||
|             raise ServerError( | ||||
|                 "Webscoket about_pings should only be called " | ||||
|                 "after connection state is changed to CLOSED" | ||||
|             ) | ||||
|  | ||||
|         for ping in self.pings.values(): | ||||
|             ping.set_exception(ConnectionClosedError(None, None)) | ||||
|             # If the exception is never retrieved, it will be logged when ping | ||||
|             # is garbage-collected. This is confusing for users. | ||||
|             # Given that ping is done (with an exception), canceling it does | ||||
|             # nothing, but it prevents logging the exception. | ||||
|             ping.cancel() | ||||
|  | ||||
|     async def close(self, code: int = 1000, reason: str = "") -> None: | ||||
|         """ | ||||
|         Perform the closing handshake. | ||||
|         This is a websocket-protocol level close. | ||||
|         :meth:`close` waits for the other end to complete the handshake and | ||||
|         for the TCP connection to terminate. | ||||
|         :meth:`close` is idempotent: it doesn't do anything once the | ||||
|         connection is closed. | ||||
|         :param code: WebSocket close code | ||||
|         :param reason: WebSocket close reason | ||||
|         """ | ||||
|         if code == 1006: | ||||
|             self.fail_connection(code, reason) | ||||
|             return | ||||
|         async with self.conn_mutex: | ||||
|             if self.connection.state is OPEN: | ||||
|                 self.connection.send_close(code, reason) | ||||
|                 data_to_send = self.connection.data_to_send() | ||||
|                 await self.send_data(data_to_send) | ||||
|  | ||||
|     async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: | ||||
|         """ | ||||
|         Receive the next message. | ||||
|         Return a :class:`str` for a text frame and :class:`bytes` for a binary | ||||
|         frame. | ||||
|         When the end of the message stream is reached, :meth:`recv` raises | ||||
|         :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it | ||||
|         raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal | ||||
|         connection closure and | ||||
|         :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol | ||||
|         error or a network failure. | ||||
|         If ``timeout`` is ``None``, block until a message is received. Else, | ||||
|         if no message is received within ``timeout`` seconds, return ``None``. | ||||
|         Set ``timeout`` to ``0`` to check if a message was already received. | ||||
|         :raises ~websockets.exceptions.ConnectionClosed: when the | ||||
|             connection is closed | ||||
|         :raises asyncio.CancelledError: if the websocket closes while waiting | ||||
|         :raises ServerError: if two tasks call :meth:`recv` or | ||||
|             :meth:`recv_streaming` concurrently | ||||
|         """ | ||||
|  | ||||
|         if self.recv_lock.locked(): | ||||
|             raise ServerError( | ||||
|                 "cannot call recv while another task is " | ||||
|                 "already waiting for the next message" | ||||
|             ) | ||||
|         await self.recv_lock.acquire() | ||||
|         if self.connection.state is CLOSED: | ||||
|             self.recv_lock.release() | ||||
|             raise WebsocketClosed( | ||||
|                 "Cannot receive from websocket interface after it is closed." | ||||
|             ) | ||||
|         try: | ||||
|             self.recv_cancel = asyncio.Future() | ||||
|             done, pending = await asyncio.wait( | ||||
|                 (self.recv_cancel, self.assembler.get(timeout)), | ||||
|                 return_when=asyncio.FIRST_COMPLETED, | ||||
|             ) | ||||
|             done_task = next(iter(done)) | ||||
|             if done_task is self.recv_cancel: | ||||
|                 # recv was cancelled | ||||
|                 for p in pending: | ||||
|                     p.cancel() | ||||
|                 raise asyncio.CancelledError() | ||||
|             else: | ||||
|                 self.recv_cancel.cancel() | ||||
|                 return done_task.result() | ||||
|         finally: | ||||
|             self.recv_cancel = None | ||||
|             self.recv_lock.release() | ||||
|  | ||||
|     async def recv_burst(self, max_recv=256) -> Sequence[Data]: | ||||
|         """ | ||||
|         Receive the messages which have arrived since last checking. | ||||
|         Return a :class:`list` containing :class:`str` for a text frame | ||||
|         and :class:`bytes` for a binary frame. | ||||
|         When the end of the message stream is reached, :meth:`recv_burst` | ||||
|         raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, | ||||
|         it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a | ||||
|         normal connection closure and | ||||
|         :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol | ||||
|         error or a network failure. | ||||
|         :raises ~websockets.exceptions.ConnectionClosed: when the | ||||
|             connection is closed | ||||
|         :raises ServerError: if two tasks call :meth:`recv_burst` or | ||||
|             :meth:`recv_streaming` concurrently | ||||
|         """ | ||||
|  | ||||
|         if self.recv_lock.locked(): | ||||
|             raise ServerError( | ||||
|                 "cannot call recv_burst while another task is already waiting " | ||||
|                 "for the next message" | ||||
|             ) | ||||
|         await self.recv_lock.acquire() | ||||
|         if self.connection.state is CLOSED: | ||||
|             self.recv_lock.release() | ||||
|             raise WebsocketClosed( | ||||
|                 "Cannot receive from websocket interface after it is closed." | ||||
|             ) | ||||
|         messages = [] | ||||
|         try: | ||||
|             # Prevent pausing the transport when we're | ||||
|             # receiving a burst of messages | ||||
|             self.can_pause = False | ||||
|             self.recv_cancel = asyncio.Future() | ||||
|             while True: | ||||
|                 done, pending = await asyncio.wait( | ||||
|                     (self.recv_cancel, self.assembler.get(timeout=0)), | ||||
|                     return_when=asyncio.FIRST_COMPLETED, | ||||
|                 ) | ||||
|                 done_task = next(iter(done)) | ||||
|                 if done_task is self.recv_cancel: | ||||
|                     # recv_burst was cancelled | ||||
|                     for p in pending: | ||||
|                         p.cancel() | ||||
|                     raise asyncio.CancelledError() | ||||
|                 m = done_task.result() | ||||
|                 if m is None: | ||||
|                     # None left in the burst. This is good! | ||||
|                     break | ||||
|                 messages.append(m) | ||||
|                 if len(messages) >= max_recv: | ||||
|                     # Too much data in the pipe. Hit our burst limit. | ||||
|                     break | ||||
|                 # Allow an eventloop iteration for the | ||||
|                 # next message to pass into the Assembler | ||||
|                 await asyncio.sleep(0) | ||||
|             self.recv_cancel.cancel() | ||||
|         finally: | ||||
|             self.recv_cancel = None | ||||
|             self.can_pause = True | ||||
|             self.recv_lock.release() | ||||
|         return messages | ||||
|  | ||||
|     async def recv_streaming(self) -> AsyncIterator[Data]: | ||||
|         """ | ||||
|         Receive the next message frame by frame. | ||||
|         Return an iterator of :class:`str` for a text frame and :class:`bytes` | ||||
|         for a binary frame. The iterator should be exhausted, or else the | ||||
|         connection will become unusable. | ||||
|         With the exception of the return value, :meth:`recv_streaming` behaves | ||||
|         like :meth:`recv`. | ||||
|         """ | ||||
|         if self.recv_lock.locked(): | ||||
|             raise ServerError( | ||||
|                 "Cannot call recv_streaming while another task " | ||||
|                 "is already waiting for the next message" | ||||
|             ) | ||||
|         await self.recv_lock.acquire() | ||||
|         if self.connection.state is CLOSED: | ||||
|             self.recv_lock.release() | ||||
|             raise WebsocketClosed( | ||||
|                 "Cannot receive from websocket interface after it is closed." | ||||
|             ) | ||||
|         try: | ||||
|             cancelled = False | ||||
|             self.recv_cancel = asyncio.Future() | ||||
|             self.can_pause = False | ||||
|             async for m in self.assembler.get_iter(): | ||||
|                 if self.recv_cancel.done(): | ||||
|                     cancelled = True | ||||
|                     break | ||||
|                 yield m | ||||
|             if cancelled: | ||||
|                 raise asyncio.CancelledError() | ||||
|         finally: | ||||
|             self.can_pause = True | ||||
|             self.recv_cancel = None | ||||
|             self.recv_lock.release() | ||||
|  | ||||
|     async def send(self, message: Union[Data, Iterable[Data]]) -> None: | ||||
|         """ | ||||
|         Send a message. | ||||
|         A string (:class:`str`) is sent as a `Text frame`_. A bytestring or | ||||
|         bytes-like object (:class:`bytes`, :class:`bytearray`, or | ||||
|         :class:`memoryview`) is sent as a `Binary frame`_. | ||||
|         .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 | ||||
|         .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 | ||||
|         :meth:`send` also accepts an iterable of strings, bytestrings, or | ||||
|         bytes-like objects. In that case the message is fragmented. Each item | ||||
|         is treated as a message fragment and sent in its own frame. All items | ||||
|         must be of the same type, or else :meth:`send` will raise a | ||||
|         :exc:`TypeError` and the connection will be closed. | ||||
|         :meth:`send` rejects dict-like objects because this is often an error. | ||||
|         If you wish to send the keys of a dict-like object as fragments, call | ||||
|         its :meth:`~dict.keys` method and pass the result to :meth:`send`. | ||||
|         :raises TypeError: for unsupported inputs | ||||
|         """ | ||||
|         async with self.conn_mutex: | ||||
|  | ||||
|             if self.connection.state in (CLOSED, CLOSING): | ||||
|                 raise WebsocketClosed( | ||||
|                     "Cannot write to websocket interface after it is closed." | ||||
|                 ) | ||||
|             if (not self.data_finished_fut) or self.data_finished_fut.done(): | ||||
|                 raise ServerError( | ||||
|                     "Cannot write to websocket interface after it is finished." | ||||
|                 ) | ||||
|  | ||||
|             # Unfragmented message -- this case must be handled first because | ||||
|             # strings and bytes-like objects are iterable. | ||||
|  | ||||
|             if isinstance(message, str): | ||||
|                 self.connection.send_text(message.encode("utf-8")) | ||||
|                 await self.send_data(self.connection.data_to_send()) | ||||
|  | ||||
|             elif isinstance(message, (bytes, bytearray, memoryview)): | ||||
|                 self.connection.send_binary(message) | ||||
|                 await self.send_data(self.connection.data_to_send()) | ||||
|  | ||||
|             elif isinstance(message, Mapping): | ||||
|                 # Catch a common mistake -- passing a dict to send(). | ||||
|                 raise TypeError("data is a dict-like object") | ||||
|  | ||||
|             elif isinstance(message, Iterable): | ||||
|                 # Fragmented message -- regular iterator. | ||||
|                 raise NotImplementedError( | ||||
|                     "Fragmented websocket messages are not supported." | ||||
|                 ) | ||||
|             else: | ||||
|                 raise TypeError("Websocket data must be bytes, str.") | ||||
|  | ||||
|     async def ping(self, data: Optional[Data] = None) -> asyncio.Future: | ||||
|         """ | ||||
|         Send a ping. | ||||
|         Return an :class:`~asyncio.Future` that will be resolved when the | ||||
|         corresponding pong is received. You can ignore it if you don't intend | ||||
|         to wait. | ||||
|         A ping may serve as a keepalive or as a check that the remote endpoint | ||||
|         received all messages up to this point:: | ||||
|             await pong_event = ws.ping() | ||||
|             await pong_event # only if you want to wait for the pong | ||||
|         By default, the ping contains four random bytes. This payload may be | ||||
|         overridden with the optional ``data`` argument which must be a string | ||||
|         (which will be encoded to UTF-8) or a bytes-like object. | ||||
|         """ | ||||
|         async with self.conn_mutex: | ||||
|             if self.connection.state in (CLOSED, CLOSING): | ||||
|                 raise WebsocketClosed( | ||||
|                     "Cannot send a ping when the websocket interface " | ||||
|                     "is closed." | ||||
|                 ) | ||||
|             if (not self.io_proto) or (not self.io_proto.loop): | ||||
|                 raise ServerError( | ||||
|                     "Cannot send a ping when the websocket has no I/O " | ||||
|                     "protocol attached." | ||||
|                 ) | ||||
|             if data is not None: | ||||
|                 if isinstance(data, str): | ||||
|                     data = data.encode("utf-8") | ||||
|                 elif isinstance(data, (bytearray, memoryview)): | ||||
|                     data = bytes(data) | ||||
|  | ||||
|             # Protect against duplicates if a payload is explicitly set. | ||||
|             if data in self.pings: | ||||
|                 raise ValueError( | ||||
|                     "already waiting for a pong with the same data" | ||||
|                 ) | ||||
|  | ||||
|             # Generate a unique random payload otherwise. | ||||
|             while data is None or data in self.pings: | ||||
|                 data = struct.pack("!I", random.getrandbits(32)) | ||||
|  | ||||
|             self.pings[data] = self.io_proto.loop.create_future() | ||||
|  | ||||
|             self.connection.send_ping(data) | ||||
|             await self.send_data(self.connection.data_to_send()) | ||||
|  | ||||
|             return asyncio.shield(self.pings[data]) | ||||
|  | ||||
|     async def pong(self, data: Data = b"") -> None: | ||||
|         """ | ||||
|         Send a pong. | ||||
|         An unsolicited pong may serve as a unidirectional heartbeat. | ||||
|         The payload may be set with the optional ``data`` argument which must | ||||
|         be a string (which will be encoded to UTF-8) or a bytes-like object. | ||||
|         """ | ||||
|         async with self.conn_mutex: | ||||
|             if self.connection.state in (CLOSED, CLOSING): | ||||
|                 # Cannot send pong after transport is shutting down | ||||
|                 return | ||||
|             if isinstance(data, str): | ||||
|                 data = data.encode("utf-8") | ||||
|             elif isinstance(data, (bytearray, memoryview)): | ||||
|                 data = bytes(data) | ||||
|             self.connection.send_pong(data) | ||||
|             await self.send_data(self.connection.data_to_send()) | ||||
|  | ||||
|     async def send_data(self, data_to_send): | ||||
|         for data in data_to_send: | ||||
|             if data: | ||||
|                 await self.io_proto.send(data) | ||||
|             else: | ||||
|                 # Send an EOF - We don't actually send it, | ||||
|                 # just trigger to autoclose the connection | ||||
|                 if ( | ||||
|                     self.auto_closer_task | ||||
|                     and not self.auto_closer_task.done() | ||||
|                     and self.data_finished_fut | ||||
|                     and not self.data_finished_fut.done() | ||||
|                 ): | ||||
|                     # Auto-close the connection | ||||
|                     self.data_finished_fut.set_result(None) | ||||
|                 else: | ||||
|                     # This will fail the connection appropriately | ||||
|                     SanicProtocol.close(self.io_proto, timeout=1.0) | ||||
|  | ||||
|     async def async_data_received(self, data_to_send, events_to_process): | ||||
|         if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: | ||||
|             # receiving data can generate data to send (eg, pong for a ping) | ||||
|             # send connection.data_to_send() | ||||
|             await self.send_data(data_to_send) | ||||
|         if len(events_to_process) > 0: | ||||
|             await self.process_events(events_to_process) | ||||
|  | ||||
|     def data_received(self, data): | ||||
|         self.connection.receive_data(data) | ||||
|         data_to_send = self.connection.data_to_send() | ||||
|         events_to_process = self.connection.events_received() | ||||
|         if len(data_to_send) > 0 or len(events_to_process) > 0: | ||||
|             asyncio.create_task( | ||||
|                 self.async_data_received(data_to_send, events_to_process) | ||||
|             ) | ||||
|  | ||||
|     async def async_eof_received(self, data_to_send, events_to_process): | ||||
|         # receiving EOF can generate data to send | ||||
|         # send connection.data_to_send() | ||||
|         if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: | ||||
|             await self.send_data(data_to_send) | ||||
|         if len(events_to_process) > 0: | ||||
|             await self.process_events(events_to_process) | ||||
|         if self.recv_cancel: | ||||
|             self.recv_cancel.cancel() | ||||
|         if ( | ||||
|             self.auto_closer_task | ||||
|             and not self.auto_closer_task.done() | ||||
|             and self.data_finished_fut | ||||
|             and not self.data_finished_fut.done() | ||||
|         ): | ||||
|             # Auto-close the connection | ||||
|             self.data_finished_fut.set_result(None) | ||||
|             # Cancel the running handler if its waiting | ||||
|         else: | ||||
|             # This will fail the connection appropriately | ||||
|             SanicProtocol.close(self.io_proto, timeout=1.0) | ||||
|  | ||||
|     def eof_received(self) -> Optional[bool]: | ||||
|         self.connection.receive_eof() | ||||
|         data_to_send = self.connection.data_to_send() | ||||
|         events_to_process = self.connection.events_received() | ||||
|         asyncio.create_task( | ||||
|             self.async_eof_received(data_to_send, events_to_process) | ||||
|         ) | ||||
|         return False | ||||
|  | ||||
|     def connection_lost(self, exc): | ||||
|         """ | ||||
|         The WebSocket Connection is Closed. | ||||
|         """ | ||||
|         if not self.connection.state == CLOSED: | ||||
|             # signal to the websocket connection handler | ||||
|             # we've lost the connection | ||||
|             self.connection.fail(code=1006) | ||||
|             self.connection.state = CLOSED | ||||
|  | ||||
|         self.abort_pings() | ||||
|         if self.connection_lost_waiter: | ||||
|             self.connection_lost_waiter.set_result(None) | ||||
| @@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound  # type: ignore | ||||
| from sanic_routing.utils import path_to_parts  # type: ignore | ||||
|  | ||||
| from sanic.exceptions import InvalidSignal | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.models.handler_types import SignalHandler | ||||
|  | ||||
|  | ||||
| RESERVED_NAMESPACES = ( | ||||
|     "server", | ||||
|     "http", | ||||
| ) | ||||
| RESERVED_NAMESPACES = { | ||||
|     "server": ( | ||||
|         # "server.main.start", | ||||
|         # "server.main.stop", | ||||
|         "server.init.before", | ||||
|         "server.init.after", | ||||
|         "server.shutdown.before", | ||||
|         "server.shutdown.after", | ||||
|     ), | ||||
|     "http": ( | ||||
|         "http.lifecycle.begin", | ||||
|         "http.lifecycle.complete", | ||||
|         "http.lifecycle.exception", | ||||
|         "http.lifecycle.handle", | ||||
|         "http.lifecycle.read_body", | ||||
|         "http.lifecycle.read_head", | ||||
|         "http.lifecycle.request", | ||||
|         "http.lifecycle.response", | ||||
|         "http.routing.after", | ||||
|         "http.routing.before", | ||||
|         "http.lifecycle.send", | ||||
|         "http.middleware.after", | ||||
|         "http.middleware.before", | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
| def _blank(): | ||||
|     ... | ||||
|  | ||||
|  | ||||
| class Signal(Route): | ||||
| @@ -59,8 +85,13 @@ class SignalRouter(BaseRouter): | ||||
|                 terms.append(extra) | ||||
|             raise NotFound(message % tuple(terms)) | ||||
|  | ||||
|         # Regex routes evaluate and can extract params directly. They are set | ||||
|         # on param_basket["__params__"] | ||||
|         params = param_basket["__params__"] | ||||
|         if not params: | ||||
|             # If param_basket["__params__"] does not exist, we might have | ||||
|             # param_basket["__matches__"], which are indexed based matches | ||||
|             # on path segments. They should already be cast types. | ||||
|             params = { | ||||
|                 param.name: param_basket["__matches__"][idx] | ||||
|                 for idx, param in group.params.items() | ||||
| @@ -73,8 +104,18 @@ class SignalRouter(BaseRouter): | ||||
|         event: str, | ||||
|         context: Optional[Dict[str, Any]] = None, | ||||
|         condition: Optional[Dict[str, str]] = None, | ||||
|     ) -> None: | ||||
|         group, handlers, params = self.get(event, condition=condition) | ||||
|         fail_not_found: bool = True, | ||||
|         reverse: bool = False, | ||||
|     ) -> Any: | ||||
|         try: | ||||
|             group, handlers, params = self.get(event, condition=condition) | ||||
|         except NotFound as e: | ||||
|             if fail_not_found: | ||||
|                 raise e | ||||
|             else: | ||||
|                 if self.ctx.app.debug: | ||||
|                     error_logger.warning(str(e)) | ||||
|                 return None | ||||
|  | ||||
|         events = [signal.ctx.event for signal in group] | ||||
|         for signal_event in events: | ||||
| @@ -82,12 +123,19 @@ class SignalRouter(BaseRouter): | ||||
|         if context: | ||||
|             params.update(context) | ||||
|  | ||||
|         if not reverse: | ||||
|             handlers = handlers[::-1] | ||||
|         try: | ||||
|             for handler in handlers: | ||||
|                 if condition is None or condition == handler.__requirements__: | ||||
|                     maybe_coroutine = handler(**params) | ||||
|                     if isawaitable(maybe_coroutine): | ||||
|                         await maybe_coroutine | ||||
|                         retval = await maybe_coroutine | ||||
|                         if retval: | ||||
|                             return retval | ||||
|                     elif maybe_coroutine: | ||||
|                         return maybe_coroutine | ||||
|             return None | ||||
|         finally: | ||||
|             for signal_event in events: | ||||
|                 signal_event.clear() | ||||
| @@ -98,14 +146,23 @@ class SignalRouter(BaseRouter): | ||||
|         *, | ||||
|         context: Optional[Dict[str, Any]] = None, | ||||
|         condition: Optional[Dict[str, str]] = None, | ||||
|     ) -> asyncio.Task: | ||||
|         task = self.ctx.loop.create_task( | ||||
|             self._dispatch( | ||||
|                 event, | ||||
|                 context=context, | ||||
|                 condition=condition, | ||||
|             ) | ||||
|         fail_not_found: bool = True, | ||||
|         inline: bool = False, | ||||
|         reverse: bool = False, | ||||
|     ) -> Union[asyncio.Task, Any]: | ||||
|         dispatch = self._dispatch( | ||||
|             event, | ||||
|             context=context, | ||||
|             condition=condition, | ||||
|             fail_not_found=fail_not_found and inline, | ||||
|             reverse=reverse, | ||||
|         ) | ||||
|         logger.debug(f"Dispatching signal: {event}") | ||||
|  | ||||
|         if inline: | ||||
|             return await dispatch | ||||
|  | ||||
|         task = asyncio.get_running_loop().create_task(dispatch) | ||||
|         await asyncio.sleep(0) | ||||
|         return task | ||||
|  | ||||
| @@ -131,7 +188,9 @@ class SignalRouter(BaseRouter): | ||||
|             append=True, | ||||
|         )  # type: ignore | ||||
|  | ||||
|     def finalize(self, do_compile: bool = True): | ||||
|     def finalize(self, do_compile: bool = True, do_optimize: bool = False): | ||||
|         self.add(_blank, "sanic.__signal__.__init__") | ||||
|  | ||||
|         try: | ||||
|             self.ctx.loop = asyncio.get_running_loop() | ||||
|         except RuntimeError: | ||||
| @@ -140,7 +199,7 @@ class SignalRouter(BaseRouter): | ||||
|         for signal in self.routes: | ||||
|             signal.ctx.event = asyncio.Event() | ||||
|  | ||||
|         return super().finalize(do_compile=do_compile) | ||||
|         return super().finalize(do_compile=do_compile, do_optimize=do_optimize) | ||||
|  | ||||
|     def _build_event_parts(self, event: str) -> Tuple[str, str, str]: | ||||
|         parts = path_to_parts(event, self.delimiter) | ||||
| @@ -151,7 +210,11 @@ class SignalRouter(BaseRouter): | ||||
|         ): | ||||
|             raise InvalidSignal("Invalid signal event: %s" % event) | ||||
|  | ||||
|         if parts[0] in RESERVED_NAMESPACES: | ||||
|         if ( | ||||
|             parts[0] in RESERVED_NAMESPACES | ||||
|             and event not in RESERVED_NAMESPACES[parts[0]] | ||||
|             and not (parts[2].startswith("<") and parts[2].endswith(">")) | ||||
|         ): | ||||
|             raise InvalidSignal( | ||||
|                 "Cannot declare reserved signal event: %s" % event | ||||
|             ) | ||||
|   | ||||
							
								
								
									
										8
									
								
								sanic/touchup/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								sanic/touchup/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| from .meta import TouchUpMeta | ||||
| from .service import TouchUp | ||||
|  | ||||
|  | ||||
| __all__ = ( | ||||
|     "TouchUp", | ||||
|     "TouchUpMeta", | ||||
| ) | ||||
							
								
								
									
										22
									
								
								sanic/touchup/meta.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								sanic/touchup/meta.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| from sanic.exceptions import SanicException | ||||
|  | ||||
| from .service import TouchUp | ||||
|  | ||||
|  | ||||
| class TouchUpMeta(type): | ||||
|     def __new__(cls, name, bases, attrs, **kwargs): | ||||
|         gen_class = super().__new__(cls, name, bases, attrs, **kwargs) | ||||
|  | ||||
|         methods = attrs.get("__touchup__") | ||||
|         attrs["__touched__"] = False | ||||
|         if methods: | ||||
|  | ||||
|             for method in methods: | ||||
|                 if method not in attrs: | ||||
|                     raise SanicException( | ||||
|                         "Cannot perform touchup on non-existent method: " | ||||
|                         f"{name}.{method}" | ||||
|                     ) | ||||
|                 TouchUp.register(gen_class, method) | ||||
|  | ||||
|         return gen_class | ||||
							
								
								
									
										5
									
								
								sanic/touchup/schemes/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								sanic/touchup/schemes/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| from .base import BaseScheme | ||||
| from .ode import OptionalDispatchEvent  # noqa | ||||
|  | ||||
|  | ||||
| __all__ = ("BaseScheme",) | ||||
							
								
								
									
										20
									
								
								sanic/touchup/schemes/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								sanic/touchup/schemes/base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Set, Type | ||||
|  | ||||
|  | ||||
| class BaseScheme(ABC): | ||||
|     ident: str | ||||
|     _registry: Set[Type] = set() | ||||
|  | ||||
|     def __init__(self, app) -> None: | ||||
|         self.app = app | ||||
|  | ||||
|     @abstractmethod | ||||
|     def run(self, method, module_globals) -> None: | ||||
|         ... | ||||
|  | ||||
|     def __init_subclass__(cls): | ||||
|         BaseScheme._registry.add(cls) | ||||
|  | ||||
|     def __call__(self, method, module_globals): | ||||
|         return self.run(method, module_globals) | ||||
							
								
								
									
										67
									
								
								sanic/touchup/schemes/ode.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								sanic/touchup/schemes/ode.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | ||||
| from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse | ||||
| from inspect import getsource | ||||
| from textwrap import dedent | ||||
| from typing import Any | ||||
|  | ||||
| from sanic.log import logger | ||||
|  | ||||
| from .base import BaseScheme | ||||
|  | ||||
|  | ||||
| class OptionalDispatchEvent(BaseScheme): | ||||
|     ident = "ODE" | ||||
|  | ||||
|     def __init__(self, app) -> None: | ||||
|         super().__init__(app) | ||||
|  | ||||
|         self._registered_events = [ | ||||
|             signal.path for signal in app.signal_router.routes | ||||
|         ] | ||||
|  | ||||
|     def run(self, method, module_globals): | ||||
|         raw_source = getsource(method) | ||||
|         src = dedent(raw_source) | ||||
|         tree = parse(src) | ||||
|         node = RemoveDispatch(self._registered_events).visit(tree) | ||||
|         compiled_src = compile(node, method.__name__, "exec") | ||||
|         exec_locals: Dict[str, Any] = {} | ||||
|         exec(compiled_src, module_globals, exec_locals)  # nosec | ||||
|  | ||||
|         return exec_locals[method.__name__] | ||||
|  | ||||
|  | ||||
| class RemoveDispatch(NodeTransformer): | ||||
|     def __init__(self, registered_events) -> None: | ||||
|         self._registered_events = registered_events | ||||
|  | ||||
|     def visit_Expr(self, node: Expr) -> Any: | ||||
|         call = node.value | ||||
|         if isinstance(call, Await): | ||||
|             call = call.value | ||||
|  | ||||
|         func = getattr(call, "func", None) | ||||
|         args = getattr(call, "args", None) | ||||
|         if not func or not args: | ||||
|             return node | ||||
|  | ||||
|         if isinstance(func, Attribute) and func.attr == "dispatch": | ||||
|             event = args[0] | ||||
|             if hasattr(event, "s"): | ||||
|                 event_name = getattr(event, "value", event.s) | ||||
|                 if self._not_registered(event_name): | ||||
|                     logger.debug(f"Disabling event: {event_name}") | ||||
|                     return None | ||||
|         return node | ||||
|  | ||||
|     def _not_registered(self, event_name): | ||||
|         dynamic = [] | ||||
|         for event in self._registered_events: | ||||
|             if event.endswith(">"): | ||||
|                 namespace_concern, _ = event.rsplit(".", 1) | ||||
|                 dynamic.append(namespace_concern) | ||||
|  | ||||
|         namespace_concern, _ = event_name.rsplit(".", 1) | ||||
|         return ( | ||||
|             event_name not in self._registered_events | ||||
|             and namespace_concern not in dynamic | ||||
|         ) | ||||
							
								
								
									
										33
									
								
								sanic/touchup/service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								sanic/touchup/service.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| from inspect import getmembers, getmodule | ||||
| from typing import Set, Tuple, Type | ||||
|  | ||||
| from .schemes import BaseScheme | ||||
|  | ||||
|  | ||||
| class TouchUp: | ||||
|     _registry: Set[Tuple[Type, str]] = set() | ||||
|  | ||||
|     @classmethod | ||||
|     def run(cls, app): | ||||
|         for target, method_name in cls._registry: | ||||
|             method = getattr(target, method_name) | ||||
|  | ||||
|             if app.test_mode: | ||||
|                 placeholder = f"_{method_name}" | ||||
|                 if hasattr(target, placeholder): | ||||
|                     method = getattr(target, placeholder) | ||||
|                 else: | ||||
|                     setattr(target, placeholder, method) | ||||
|  | ||||
|             module = getmodule(target) | ||||
|             module_globals = dict(getmembers(module)) | ||||
|  | ||||
|             for scheme in BaseScheme._registry: | ||||
|                 modified = scheme(app)(method, module_globals) | ||||
|                 setattr(target, method_name, modified) | ||||
|  | ||||
|             target.__touched__ = True | ||||
|  | ||||
|     @classmethod | ||||
|     def register(cls, target, method_name): | ||||
|         cls._registry.add((target, method_name)) | ||||
| @@ -13,6 +13,7 @@ from warnings import warn | ||||
|  | ||||
| from sanic.constants import HTTP_METHODS | ||||
| from sanic.exceptions import InvalidUsage | ||||
| from sanic.models.handler_types import RouteHandler | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| @@ -86,7 +87,7 @@ class HTTPMethodView: | ||||
|         return handler(request, *args, **kwargs) | ||||
|  | ||||
|     @classmethod | ||||
|     def as_view(cls, *class_args, **class_kwargs): | ||||
|     def as_view(cls, *class_args: Any, **class_kwargs: Any) -> RouteHandler: | ||||
|         """Return view function for use with the routing system, that | ||||
|         dispatches request to appropriate handler method. | ||||
|         """ | ||||
| @@ -100,7 +101,7 @@ class HTTPMethodView: | ||||
|             for decorator in cls.decorators: | ||||
|                 view = decorator(view) | ||||
|  | ||||
|         view.view_class = cls | ||||
|         view.view_class = cls  # type: ignore | ||||
|         view.__doc__ = cls.__doc__ | ||||
|         view.__module__ = cls.__module__ | ||||
|         view.__name__ = cls.__name__ | ||||
|   | ||||
| @@ -1,205 +0,0 @@ | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Awaitable, | ||||
|     Callable, | ||||
|     Dict, | ||||
|     List, | ||||
|     MutableMapping, | ||||
|     Optional, | ||||
|     Union, | ||||
| ) | ||||
|  | ||||
| from httptools import HttpParserUpgrade  # type: ignore | ||||
| from websockets import (  # type: ignore | ||||
|     ConnectionClosed, | ||||
|     InvalidHandshake, | ||||
|     WebSocketCommonProtocol, | ||||
| ) | ||||
|  | ||||
| # Despite the "legacy" namespace, the primary maintainer of websockets | ||||
| # committed to maintaining backwards-compatibility until 2026 and will | ||||
| # consider extending it if sanic continues depending on this module. | ||||
| from websockets.legacy import handshake | ||||
|  | ||||
| from sanic.exceptions import InvalidUsage | ||||
| from sanic.server import HttpProtocol | ||||
|  | ||||
|  | ||||
| __all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"] | ||||
|  | ||||
| ASIMessage = MutableMapping[str, Any] | ||||
|  | ||||
|  | ||||
| class WebSocketProtocol(HttpProtocol): | ||||
|     def __init__( | ||||
|         self, | ||||
|         *args, | ||||
|         websocket_timeout=10, | ||||
|         websocket_max_size=None, | ||||
|         websocket_max_queue=None, | ||||
|         websocket_read_limit=2 ** 16, | ||||
|         websocket_write_limit=2 ** 16, | ||||
|         websocket_ping_interval=20, | ||||
|         websocket_ping_timeout=20, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.websocket = None | ||||
|         # self.app = None | ||||
|         self.websocket_timeout = websocket_timeout | ||||
|         self.websocket_max_size = websocket_max_size | ||||
|         self.websocket_max_queue = websocket_max_queue | ||||
|         self.websocket_read_limit = websocket_read_limit | ||||
|         self.websocket_write_limit = websocket_write_limit | ||||
|         self.websocket_ping_interval = websocket_ping_interval | ||||
|         self.websocket_ping_timeout = websocket_ping_timeout | ||||
|  | ||||
|     # timeouts make no sense for websocket routes | ||||
|     def request_timeout_callback(self): | ||||
|         if self.websocket is None: | ||||
|             super().request_timeout_callback() | ||||
|  | ||||
|     def response_timeout_callback(self): | ||||
|         if self.websocket is None: | ||||
|             super().response_timeout_callback() | ||||
|  | ||||
|     def keep_alive_timeout_callback(self): | ||||
|         if self.websocket is None: | ||||
|             super().keep_alive_timeout_callback() | ||||
|  | ||||
|     def connection_lost(self, exc): | ||||
|         if self.websocket is not None: | ||||
|             self.websocket.connection_lost(exc) | ||||
|         super().connection_lost(exc) | ||||
|  | ||||
|     def data_received(self, data): | ||||
|         if self.websocket is not None: | ||||
|             # pass the data to the websocket protocol | ||||
|             self.websocket.data_received(data) | ||||
|         else: | ||||
|             try: | ||||
|                 super().data_received(data) | ||||
|             except HttpParserUpgrade: | ||||
|                 # this is okay, it just indicates we've got an upgrade request | ||||
|                 pass | ||||
|  | ||||
|     def write_response(self, response): | ||||
|         if self.websocket is not None: | ||||
|             # websocket requests do not write a response | ||||
|             self.transport.close() | ||||
|         else: | ||||
|             super().write_response(response) | ||||
|  | ||||
|     async def websocket_handshake(self, request, subprotocols=None): | ||||
|         # let the websockets package do the handshake with the client | ||||
|         headers = {} | ||||
|  | ||||
|         try: | ||||
|             key = handshake.check_request(request.headers) | ||||
|             handshake.build_response(headers, key) | ||||
|         except InvalidHandshake: | ||||
|             raise InvalidUsage("Invalid websocket request") | ||||
|  | ||||
|         subprotocol = None | ||||
|         if subprotocols and "Sec-Websocket-Protocol" in request.headers: | ||||
|             # select a subprotocol | ||||
|             client_subprotocols = [ | ||||
|                 p.strip() | ||||
|                 for p in request.headers["Sec-Websocket-Protocol"].split(",") | ||||
|             ] | ||||
|             for p in client_subprotocols: | ||||
|                 if p in subprotocols: | ||||
|                     subprotocol = p | ||||
|                     headers["Sec-Websocket-Protocol"] = subprotocol | ||||
|                     break | ||||
|  | ||||
|         # write the 101 response back to the client | ||||
|         rv = b"HTTP/1.1 101 Switching Protocols\r\n" | ||||
|         for k, v in headers.items(): | ||||
|             rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" | ||||
|         rv += b"\r\n" | ||||
|         request.transport.write(rv) | ||||
|  | ||||
|         # hook up the websocket protocol | ||||
|         self.websocket = WebSocketCommonProtocol( | ||||
|             close_timeout=self.websocket_timeout, | ||||
|             max_size=self.websocket_max_size, | ||||
|             max_queue=self.websocket_max_queue, | ||||
|             read_limit=self.websocket_read_limit, | ||||
|             write_limit=self.websocket_write_limit, | ||||
|             ping_interval=self.websocket_ping_interval, | ||||
|             ping_timeout=self.websocket_ping_timeout, | ||||
|         ) | ||||
|         # we use WebSocketCommonProtocol because we don't want the handshake | ||||
|         # logic from WebSocketServerProtocol; however, we must tell it that | ||||
|         # we're running on the server side | ||||
|         self.websocket.is_client = False | ||||
|         self.websocket.side = "server" | ||||
|         self.websocket.subprotocol = subprotocol | ||||
|         self.websocket.connection_made(request.transport) | ||||
|         self.websocket.connection_open() | ||||
|         return self.websocket | ||||
|  | ||||
|  | ||||
| class WebSocketConnection: | ||||
|  | ||||
|     # TODO | ||||
|     # - Implement ping/pong | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         send: Callable[[ASIMessage], Awaitable[None]], | ||||
|         receive: Callable[[], Awaitable[ASIMessage]], | ||||
|         subprotocols: Optional[List[str]] = None, | ||||
|     ) -> None: | ||||
|         self._send = send | ||||
|         self._receive = receive | ||||
|         self._subprotocols = subprotocols or [] | ||||
|  | ||||
|     async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: | ||||
|         message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} | ||||
|  | ||||
|         if isinstance(data, bytes): | ||||
|             message.update({"bytes": data}) | ||||
|         else: | ||||
|             message.update({"text": str(data)}) | ||||
|  | ||||
|         await self._send(message) | ||||
|  | ||||
|     async def recv(self, *args, **kwargs) -> Optional[str]: | ||||
|         message = await self._receive() | ||||
|  | ||||
|         if message["type"] == "websocket.receive": | ||||
|             return message["text"] | ||||
|         elif message["type"] == "websocket.disconnect": | ||||
|             pass | ||||
|  | ||||
|         return None | ||||
|  | ||||
|     receive = recv | ||||
|  | ||||
|     async def accept(self, subprotocols: Optional[List[str]] = None) -> None: | ||||
|         subprotocol = None | ||||
|         if subprotocols: | ||||
|             for subp in subprotocols: | ||||
|                 if subp in self.subprotocols: | ||||
|                     subprotocol = subp | ||||
|                     break | ||||
|  | ||||
|         await self._send( | ||||
|             { | ||||
|                 "type": "websocket.accept", | ||||
|                 "subprotocol": subprotocol, | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     def subprotocols(self): | ||||
|         return self._subprotocols | ||||
|  | ||||
|     @subprotocols.setter | ||||
|     def subprotocols(self, subprotocols: Optional[List[str]] = None): | ||||
|         self._subprotocols = subprotocols or [] | ||||
| @@ -8,8 +8,8 @@ import traceback | ||||
| from gunicorn.workers import base  # type: ignore | ||||
|  | ||||
| from sanic.log import logger | ||||
| from sanic.server import HttpProtocol, Signal, serve, trigger_events | ||||
| from sanic.websocket import WebSocketProtocol | ||||
| from sanic.server import HttpProtocol, Signal, serve | ||||
| from sanic.server.protocols.websocket_protocol import WebSocketProtocol | ||||
|  | ||||
|  | ||||
| try: | ||||
| @@ -68,10 +68,10 @@ class GunicornWorker(base.Worker): | ||||
|         ) | ||||
|         self._server_settings["signal"] = self.signal | ||||
|         self._server_settings.pop("sock") | ||||
|         trigger_events( | ||||
|             self._server_settings.get("before_start", []), self.loop | ||||
|         self._await(self.app.callable._startup()) | ||||
|         self._await( | ||||
|             self.app.callable._server_event("init", "before", loop=self.loop) | ||||
|         ) | ||||
|         self._server_settings["before_start"] = () | ||||
|  | ||||
|         main_start = self._server_settings.pop("main_start", None) | ||||
|         main_stop = self._server_settings.pop("main_stop", None) | ||||
| @@ -82,24 +82,29 @@ class GunicornWorker(base.Worker): | ||||
|                 "with GunicornWorker" | ||||
|             ) | ||||
|  | ||||
|         self._runner = asyncio.ensure_future(self._run(), loop=self.loop) | ||||
|         try: | ||||
|             self.loop.run_until_complete(self._runner) | ||||
|             self._await(self._run()) | ||||
|             self.app.callable.is_running = True | ||||
|             trigger_events( | ||||
|                 self._server_settings.get("after_start", []), self.loop | ||||
|             self._await( | ||||
|                 self.app.callable._server_event( | ||||
|                     "init", "after", loop=self.loop | ||||
|                 ) | ||||
|             ) | ||||
|             self.loop.run_until_complete(self._check_alive()) | ||||
|             trigger_events( | ||||
|                 self._server_settings.get("before_stop", []), self.loop | ||||
|             self._await( | ||||
|                 self.app.callable._server_event( | ||||
|                     "shutdown", "before", loop=self.loop | ||||
|                 ) | ||||
|             ) | ||||
|             self.loop.run_until_complete(self.close()) | ||||
|         except BaseException: | ||||
|             traceback.print_exc() | ||||
|         finally: | ||||
|             try: | ||||
|                 trigger_events( | ||||
|                     self._server_settings.get("after_stop", []), self.loop | ||||
|                 self._await( | ||||
|                     self.app.callable._server_event( | ||||
|                         "shutdown", "after", loop=self.loop | ||||
|                     ) | ||||
|                 ) | ||||
|             except BaseException: | ||||
|                 traceback.print_exc() | ||||
| @@ -137,14 +142,11 @@ class GunicornWorker(base.Worker): | ||||
|  | ||||
|             # Force close non-idle connection after waiting for | ||||
|             # graceful_shutdown_timeout | ||||
|             coros = [] | ||||
|             for conn in self.connections: | ||||
|                 if hasattr(conn, "websocket") and conn.websocket: | ||||
|                     coros.append(conn.websocket.close_connection()) | ||||
|                     conn.websocket.fail_connection(code=1001) | ||||
|                 else: | ||||
|                     conn.close() | ||||
|             _shutdown = asyncio.gather(*coros, loop=self.loop) | ||||
|             await _shutdown | ||||
|                     conn.abort() | ||||
|  | ||||
|     async def _run(self): | ||||
|         for sock in self.sockets: | ||||
| @@ -238,3 +240,7 @@ class GunicornWorker(base.Worker): | ||||
|         self.exit_code = 1 | ||||
|         self.cfg.worker_abort(self) | ||||
|         sys.exit(1) | ||||
|  | ||||
|     def _await(self, coro): | ||||
|         fut = asyncio.ensure_future(coro, loop=self.loop) | ||||
|         self.loop.run_until_complete(fut) | ||||
|   | ||||
							
								
								
									
										33
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								setup.py
									
									
									
									
									
								
							| @@ -81,60 +81,63 @@ env_dependency = ( | ||||
| ) | ||||
| ujson = "ujson>=1.35" + env_dependency | ||||
| uvloop = "uvloop>=0.5.3" + env_dependency | ||||
|  | ||||
| types_ujson = "types-ujson" + env_dependency | ||||
| requirements = [ | ||||
|     "sanic-routing~=0.7", | ||||
|     "httptools>=0.0.10", | ||||
|     uvloop, | ||||
|     ujson, | ||||
|     "aiofiles>=0.6.0", | ||||
|     "websockets>=9.0", | ||||
|     "websockets>=10.0", | ||||
|     "multidict>=5.0,<6.0", | ||||
| ] | ||||
|  | ||||
| tests_require = [ | ||||
|     "sanic-testing>=0.7.0b1", | ||||
|     "sanic-testing>=0.7.0", | ||||
|     "pytest==5.2.1", | ||||
|     "multidict>=5.0,<6.0", | ||||
|     "coverage==5.3", | ||||
|     "gunicorn==20.0.4", | ||||
|     "pytest-cov", | ||||
|     "beautifulsoup4", | ||||
|     uvloop, | ||||
|     ujson, | ||||
|     "pytest-sanic", | ||||
|     "pytest-sugar", | ||||
|     "pytest-benchmark", | ||||
|     "chardet==3.*", | ||||
|     "flake8", | ||||
|     "black", | ||||
|     "isort>=5.0.0", | ||||
|     "bandit", | ||||
|     "mypy>=0.901", | ||||
|     "docutils", | ||||
|     "pygments", | ||||
|     "uvicorn<0.15.0", | ||||
|     types_ujson, | ||||
| ] | ||||
|  | ||||
| docs_require = [ | ||||
|     "sphinx>=2.1.2", | ||||
|     "sphinx_rtd_theme", | ||||
|     "recommonmark>=0.5.0", | ||||
|     "sphinx_rtd_theme>=0.4.3", | ||||
|     "docutils", | ||||
|     "pygments", | ||||
|     "m2r2", | ||||
| ] | ||||
|  | ||||
| dev_require = tests_require + [ | ||||
|     "aiofiles", | ||||
|     "tox", | ||||
|     "black", | ||||
|     "flake8", | ||||
|     "bandit", | ||||
|     "towncrier", | ||||
| ] | ||||
|  | ||||
| all_require = dev_require + docs_require | ||||
| all_require = list(set(dev_require + docs_require)) | ||||
|  | ||||
| if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): | ||||
|     print("Installing without uJSON") | ||||
|     requirements.remove(ujson) | ||||
|     tests_require.remove(ujson) | ||||
|     tests_require.remove(types_ujson) | ||||
|  | ||||
| # 'nt' means windows OS | ||||
| if strtobool(os.environ.get("SANIC_NO_UVLOOP", "no")): | ||||
|     print("Installing without uvLoop") | ||||
|     requirements.remove(uvloop) | ||||
|     tests_require.remove(uvloop) | ||||
|  | ||||
| extras_require = { | ||||
|     "test": tests_require, | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| import random | ||||
| import re | ||||
| import string | ||||
| @@ -9,10 +11,12 @@ from typing import Tuple | ||||
| import pytest | ||||
|  | ||||
| from sanic_routing.exceptions import RouteExists | ||||
| from sanic_testing.testing import PORT | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.constants import HTTP_METHODS | ||||
| from sanic.router import Router | ||||
| from sanic.touchup.service import TouchUp | ||||
|  | ||||
|  | ||||
| slugify = re.compile(r"[^a-zA-Z0-9_\-]") | ||||
| @@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]: | ||||
|     collect_ignore = ["test_worker.py"] | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def caplog(caplog): | ||||
|     yield caplog | ||||
|  | ||||
|  | ||||
| async def _handler(request): | ||||
|     """ | ||||
|     Dummy placeholder method used for route resolver when creating a new | ||||
| @@ -41,33 +40,32 @@ async def _handler(request): | ||||
|  | ||||
|  | ||||
| TYPE_TO_GENERATOR_MAP = { | ||||
|     "string": lambda: "".join( | ||||
|     "str": lambda: "".join( | ||||
|         [random.choice(string.ascii_lowercase) for _ in range(4)] | ||||
|     ), | ||||
|     "int": lambda: random.choice(range(1000000)), | ||||
|     "number": lambda: random.random(), | ||||
|     "float": lambda: random.random(), | ||||
|     "alpha": lambda: "".join( | ||||
|         [random.choice(string.ascii_lowercase) for _ in range(4)] | ||||
|     ), | ||||
|     "uuid": lambda: str(uuid.uuid1()), | ||||
| } | ||||
|  | ||||
| CACHE = {} | ||||
|  | ||||
|  | ||||
| class RouteStringGenerator: | ||||
|  | ||||
|     ROUTE_COUNT_PER_DEPTH = 100 | ||||
|     HTTP_METHODS = HTTP_METHODS | ||||
|     ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"] | ||||
|     ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"] | ||||
|  | ||||
|     def generate_random_direct_route(self, max_route_depth=4): | ||||
|         routes = [] | ||||
|         for depth in range(1, max_route_depth + 1): | ||||
|             for _ in range(self.ROUTE_COUNT_PER_DEPTH): | ||||
|                 route = "/".join( | ||||
|                     [ | ||||
|                         TYPE_TO_GENERATOR_MAP.get("string")() | ||||
|                         for _ in range(depth) | ||||
|                     ] | ||||
|                     [TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)] | ||||
|                 ) | ||||
|                 route = route.replace(".", "", -1) | ||||
|                 route_detail = (random.choice(self.HTTP_METHODS), route) | ||||
| @@ -83,7 +81,7 @@ class RouteStringGenerator: | ||||
|             new_route_part = "/".join( | ||||
|                 [ | ||||
|                     "<{}:{}>".format( | ||||
|                         TYPE_TO_GENERATOR_MAP.get("string")(), | ||||
|                         TYPE_TO_GENERATOR_MAP.get("str")(), | ||||
|                         random.choice(self.ROUTE_PARAM_TYPES), | ||||
|                     ) | ||||
|                     for _ in range(max_route_depth - current_length) | ||||
| @@ -98,7 +96,7 @@ class RouteStringGenerator: | ||||
|     def generate_url_for_template(template): | ||||
|         url = template | ||||
|         for pattern, param_type in re.findall( | ||||
|             re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"), | ||||
|             re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"), | ||||
|             template, | ||||
|         ): | ||||
|             value = TYPE_TO_GENERATOR_MAP.get(param_type)() | ||||
| @@ -111,6 +109,7 @@ def sanic_router(app): | ||||
|     # noinspection PyProtectedMember | ||||
|     def _setup(route_details: tuple) -> Tuple[Router, tuple]: | ||||
|         router = Router() | ||||
|         router.ctx.app = app | ||||
|         added_router = [] | ||||
|         for method, route in route_details: | ||||
|             try: | ||||
| @@ -141,5 +140,33 @@ def url_param_generator(): | ||||
|  | ||||
| @pytest.fixture(scope="function") | ||||
| def app(request): | ||||
|     if not CACHE: | ||||
|         for target, method_name in TouchUp._registry: | ||||
|             CACHE[method_name] = getattr(target, method_name) | ||||
|     app = Sanic(slugify.sub("-", request.node.name)) | ||||
|     return app | ||||
|     yield app | ||||
|     for target, method_name in TouchUp._registry: | ||||
|         setattr(target, method_name, CACHE[method_name]) | ||||
|  | ||||
|  | ||||
| @pytest.fixture(scope="function") | ||||
| def run_startup(caplog): | ||||
|     def run(app): | ||||
|         nonlocal caplog | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         with caplog.at_level(logging.DEBUG): | ||||
|             server = app.create_server( | ||||
|                 debug=True, return_asyncio_server=True, port=PORT | ||||
|             ) | ||||
|             loop._stopping = False | ||||
|  | ||||
|             _server = loop.run_until_complete(server) | ||||
|  | ||||
|             _server.close() | ||||
|             loop.run_until_complete(_server.wait_closed()) | ||||
|             app.stop() | ||||
|  | ||||
|         return caplog.record_tuples | ||||
|  | ||||
|     return run | ||||
|   | ||||
| @@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable): | ||||
| @patch("sanic.app.WebSocketProtocol") | ||||
| def test_app_websocket_parameters(websocket_protocol_mock, app): | ||||
|     app.config.WEBSOCKET_MAX_SIZE = 44 | ||||
|     app.config.WEBSOCKET_MAX_QUEUE = 45 | ||||
|     app.config.WEBSOCKET_READ_LIMIT = 46 | ||||
|     app.config.WEBSOCKET_WRITE_LIMIT = 47 | ||||
|     app.config.WEBSOCKET_PING_TIMEOUT = 48 | ||||
|     app.config.WEBSOCKET_PING_INTERVAL = 50 | ||||
|  | ||||
| @@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app): | ||||
|     websocket_protocol_call_args = websocket_protocol_mock.call_args | ||||
|     ws_kwargs = websocket_protocol_call_args[1] | ||||
|     assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE | ||||
|     assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE | ||||
|     assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT | ||||
|     assert ( | ||||
|         ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT | ||||
|     ) | ||||
|     assert ( | ||||
|         ws_kwargs["websocket_ping_timeout"] | ||||
|         == app.config.WEBSOCKET_PING_TIMEOUT | ||||
| @@ -396,7 +388,7 @@ def test_app_set_attribute_warning(app): | ||||
|     assert len(record) == 1 | ||||
|     assert record[0].message.args[0] == ( | ||||
|         "Setting variables on Sanic instances is deprecated " | ||||
|         "and will be removed in version 21.9. You should change your " | ||||
|         "and will be removed in version 21.12. You should change your " | ||||
|         "Sanic instance to use instance.ctx.foo instead." | ||||
|     ) | ||||
|  | ||||
|   | ||||
| @@ -7,10 +7,10 @@ import uvicorn | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.asgi import MockTransport | ||||
| from sanic.exceptions import InvalidUsage | ||||
| from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable | ||||
| from sanic.request import Request | ||||
| from sanic.response import json, text | ||||
| from sanic.websocket import WebSocketConnection | ||||
| from sanic.server.websockets.connection import WebSocketConnection | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @@ -346,3 +346,33 @@ async def test_content_type(app): | ||||
|  | ||||
|     _, response = await app.asgi_client.get("/custom") | ||||
|     assert response.headers.get("content-type") == "somethingelse" | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_request_handle_exception(app): | ||||
|     @app.get("/error-prone") | ||||
|     def _request(request): | ||||
|         raise ServiceUnavailable(message="Service unavailable") | ||||
|  | ||||
|     _, response = await app.asgi_client.get("/wrong-path") | ||||
|     assert response.status_code == 404 | ||||
|  | ||||
|     _, response = await app.asgi_client.get("/error-prone") | ||||
|     assert response.status_code == 503 | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_request_exception_suppressed_by_middleware(app): | ||||
|     @app.get("/error-prone") | ||||
|     def _request(request): | ||||
|         raise ServiceUnavailable(message="Service unavailable") | ||||
|  | ||||
|     @app.on_request | ||||
|     def forbidden(request): | ||||
|         raise Forbidden(message="forbidden") | ||||
|  | ||||
|     _, response = await app.asgi_client.get("/wrong-path") | ||||
|     assert response.status_code == 403 | ||||
|  | ||||
|     _, response = await app.asgi_client.get("/error-prone") | ||||
|     assert response.status_code == 403 | ||||
|   | ||||
| @@ -20,4 +20,4 @@ def test_bad_request_response(app): | ||||
|  | ||||
|     app.run(host="127.0.0.1", port=42101, debug=False) | ||||
|     assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" | ||||
|     assert b"Bad Request" in lines[-1] | ||||
|     assert b"Bad Request" in lines[-2] | ||||
|   | ||||
							
								
								
									
										70
									
								
								tests/test_blueprint_copy.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								tests/test_blueprint_copy.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,70 @@ | ||||
| from copy import deepcopy | ||||
|  | ||||
| from sanic import Blueprint, Sanic, blueprints, response | ||||
| from sanic.response import text | ||||
|  | ||||
|  | ||||
| def test_bp_copy(app: Sanic): | ||||
|     bp1 = Blueprint("test_bp1", version=1) | ||||
|     bp1.ctx.test = 1 | ||||
|     assert hasattr(bp1.ctx, "test") | ||||
|  | ||||
|     @bp1.route("/page") | ||||
|     def handle_request(request): | ||||
|         return text("Hello world!") | ||||
|  | ||||
|     bp2 = bp1.copy(name="test_bp2", version=2) | ||||
|     assert id(bp1) != id(bp2) | ||||
|     assert bp1._apps == bp2._apps == set() | ||||
|     assert not hasattr(bp2.ctx, "test") | ||||
|     assert len(bp2._future_exceptions) == len(bp1._future_exceptions) | ||||
|     assert len(bp2._future_listeners) == len(bp1._future_listeners) | ||||
|     assert len(bp2._future_middleware) == len(bp1._future_middleware) | ||||
|     assert len(bp2._future_routes) == len(bp1._future_routes) | ||||
|     assert len(bp2._future_signals) == len(bp1._future_signals) | ||||
|  | ||||
|     app.blueprint(bp1) | ||||
|     app.blueprint(bp2) | ||||
|  | ||||
|     bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True) | ||||
|     assert id(bp1) != id(bp3) | ||||
|     assert bp1._apps == bp3._apps and bp3._apps | ||||
|     assert not hasattr(bp3.ctx, "test") | ||||
|  | ||||
|     bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True) | ||||
|     assert id(bp1) != id(bp4) | ||||
|     assert bp4.ctx.test == 1 | ||||
|  | ||||
|     bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False) | ||||
|     assert id(bp1) != id(bp5) | ||||
|     assert not bp5._apps | ||||
|     assert bp1._apps != set() | ||||
|  | ||||
|     app.blueprint(bp5) | ||||
|  | ||||
|     bp6 = bp1.copy( | ||||
|         name="test_bp6", | ||||
|         version=6, | ||||
|         with_registration=True, | ||||
|         version_prefix="/version", | ||||
|     ) | ||||
|     assert bp6._apps | ||||
|     assert bp6.version_prefix == "/version" | ||||
|  | ||||
|     _, response = app.test_client.get("/v1/page") | ||||
|     assert "Hello world!" in response.text | ||||
|  | ||||
|     _, response = app.test_client.get("/v2/page") | ||||
|     assert "Hello world!" in response.text | ||||
|  | ||||
|     _, response = app.test_client.get("/v3/page") | ||||
|     assert "Hello world!" in response.text | ||||
|  | ||||
|     _, response = app.test_client.get("/v4/page") | ||||
|     assert "Hello world!" in response.text | ||||
|  | ||||
|     _, response = app.test_client.get("/v5/page") | ||||
|     assert "Hello world!" in response.text | ||||
|  | ||||
|     _, response = app.test_client.get("/version6/page") | ||||
|     assert "Hello world!" in response.text | ||||
| @@ -3,6 +3,12 @@ from pytest import raises | ||||
| from sanic.app import Sanic | ||||
| from sanic.blueprint_group import BlueprintGroup | ||||
| from sanic.blueprints import Blueprint | ||||
| from sanic.exceptions import ( | ||||
|     Forbidden, | ||||
|     InvalidUsage, | ||||
|     SanicException, | ||||
|     ServerError, | ||||
| ) | ||||
| from sanic.request import Request | ||||
| from sanic.response import HTTPResponse, text | ||||
|  | ||||
| @@ -96,16 +102,28 @@ def test_bp_group(app: Sanic): | ||||
|     def blueprint_1_default_route(request): | ||||
|         return text("BP1_OK") | ||||
|  | ||||
|     @blueprint_1.route("/invalid") | ||||
|     def blueprint_1_error(request: Request): | ||||
|         raise InvalidUsage("Invalid") | ||||
|  | ||||
|     @blueprint_2.route("/") | ||||
|     def blueprint_2_default_route(request): | ||||
|         return text("BP2_OK") | ||||
|  | ||||
|     @blueprint_2.route("/error") | ||||
|     def blueprint_2_error(request: Request): | ||||
|         raise ServerError("Error") | ||||
|  | ||||
|     blueprint_group_1 = Blueprint.group( | ||||
|         blueprint_1, blueprint_2, url_prefix="/bp" | ||||
|     ) | ||||
|  | ||||
|     blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3") | ||||
|  | ||||
|     @blueprint_group_1.exception(InvalidUsage) | ||||
|     def handle_group_exception(request, exception): | ||||
|         return text("BP1_ERR_OK") | ||||
|  | ||||
|     @blueprint_group_1.middleware("request") | ||||
|     def blueprint_group_1_middleware(request): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
| @@ -116,19 +134,47 @@ def test_bp_group(app: Sanic): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
|         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 | ||||
|  | ||||
|     @blueprint_group_1.on_request | ||||
|     def blueprint_group_1_convenience_1(request): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
|         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 | ||||
|  | ||||
|     @blueprint_group_1.on_request() | ||||
|     def blueprint_group_1_convenience_2(request): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
|         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 | ||||
|  | ||||
|     @blueprint_3.route("/") | ||||
|     def blueprint_3_default_route(request): | ||||
|         return text("BP3_OK") | ||||
|  | ||||
|     @blueprint_3.route("/forbidden") | ||||
|     def blueprint_3_forbidden(request: Request): | ||||
|         raise Forbidden("Forbidden") | ||||
|  | ||||
|     blueprint_group_2 = Blueprint.group( | ||||
|         blueprint_group_1, blueprint_3, url_prefix="/api" | ||||
|     ) | ||||
|  | ||||
|     @blueprint_group_2.exception(SanicException) | ||||
|     def handle_non_handled_exception(request, exception): | ||||
|         return text("BP2_ERR_OK") | ||||
|  | ||||
|     @blueprint_group_2.middleware("response") | ||||
|     def blueprint_group_2_middleware(request, response): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
|         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 | ||||
|  | ||||
|     @blueprint_group_2.on_response | ||||
|     def blueprint_group_2_middleware_convenience_1(request, response): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
|         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 | ||||
|  | ||||
|     @blueprint_group_2.on_response() | ||||
|     def blueprint_group_2_middleware_convenience_2(request, response): | ||||
|         global MIDDLEWARE_INVOKE_COUNTER | ||||
|         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 | ||||
|  | ||||
|     app.blueprint(blueprint_group_2) | ||||
|  | ||||
|     @app.route("/") | ||||
| @@ -141,14 +187,23 @@ def test_bp_group(app: Sanic): | ||||
|     _, response = app.test_client.get("/api/bp/bp1") | ||||
|     assert response.text == "BP1_OK" | ||||
|  | ||||
|     _, response = app.test_client.get("/api/bp/bp1/invalid") | ||||
|     assert response.text == "BP1_ERR_OK" | ||||
|  | ||||
|     _, response = app.test_client.get("/api/bp/bp2") | ||||
|     assert response.text == "BP2_OK" | ||||
|  | ||||
|     _, response = app.test_client.get("/api/bp/bp2/error") | ||||
|     assert response.text == "BP2_ERR_OK" | ||||
|  | ||||
|     _, response = app.test_client.get("/api/bp3") | ||||
|     assert response.text == "BP3_OK" | ||||
|  | ||||
|     assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3 | ||||
|     assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4 | ||||
|     _, response = app.test_client.get("/api/bp3/forbidden") | ||||
|     assert response.text == "BP2_ERR_OK" | ||||
|  | ||||
|     assert MIDDLEWARE_INVOKE_COUNTER["response"] == 18 | ||||
|     assert MIDDLEWARE_INVOKE_COUNTER["request"] == 16 | ||||
|  | ||||
|  | ||||
| def test_bp_group_list_operations(app: Sanic): | ||||
|   | ||||
| @@ -83,7 +83,6 @@ def test_versioned_routes_get(app, method): | ||||
|             return text("OK") | ||||
|  | ||||
|     else: | ||||
|         print(func) | ||||
|         raise Exception(f"{func} is not callable") | ||||
|  | ||||
|     app.blueprint(bp) | ||||
| @@ -477,6 +476,58 @@ def test_bp_exception_handler(app): | ||||
|     assert response.status == 200 | ||||
|  | ||||
|  | ||||
| def test_bp_exception_handler_applied(app): | ||||
|     class Error(Exception): | ||||
|         pass | ||||
|  | ||||
|     handled = Blueprint("handled") | ||||
|     nothandled = Blueprint("nothandled") | ||||
|  | ||||
|     @handled.exception(Error) | ||||
|     def handle_error(req, e): | ||||
|         return text("handled {}".format(e)) | ||||
|  | ||||
|     @handled.route("/ok") | ||||
|     def ok(request): | ||||
|         raise Error("uh oh") | ||||
|  | ||||
|     @nothandled.route("/notok") | ||||
|     def notok(request): | ||||
|         raise Error("uh oh") | ||||
|  | ||||
|     app.blueprint(handled) | ||||
|     app.blueprint(nothandled) | ||||
|  | ||||
|     _, response = app.test_client.get("/ok") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "handled uh oh" | ||||
|  | ||||
|     _, response = app.test_client.get("/notok") | ||||
|     assert response.status == 500 | ||||
|  | ||||
|  | ||||
| def test_bp_exception_handler_not_applied(app): | ||||
|     class Error(Exception): | ||||
|         pass | ||||
|  | ||||
|     handled = Blueprint("handled") | ||||
|     nothandled = Blueprint("nothandled") | ||||
|  | ||||
|     @handled.exception(Error) | ||||
|     def handle_error(req, e): | ||||
|         return text("handled {}".format(e)) | ||||
|  | ||||
|     @nothandled.route("/notok") | ||||
|     def notok(request): | ||||
|         raise Error("uh oh") | ||||
|  | ||||
|     app.blueprint(handled) | ||||
|     app.blueprint(nothandled) | ||||
|  | ||||
|     _, response = app.test_client.get("/notok") | ||||
|     assert response.status == 500 | ||||
|  | ||||
|  | ||||
| def test_bp_listeners(app): | ||||
|     app.route("/")(lambda x: x) | ||||
|     blueprint = Blueprint("test_middleware") | ||||
| @@ -1034,6 +1085,6 @@ def test_bp_set_attribute_warning(): | ||||
|     assert len(record) == 1 | ||||
|     assert record[0].message.args[0] == ( | ||||
|         "Setting variables on Blueprint instances is deprecated " | ||||
|         "and will be removed in version 21.9. You should change your " | ||||
|         "and will be removed in version 21.12. You should change your " | ||||
|         "Blueprint instance to use instance.ctx.foo instead." | ||||
|     ) | ||||
|   | ||||
| @@ -89,7 +89,7 @@ def test_debug(cmd): | ||||
|     out, err, exitcode = capture(command) | ||||
|     lines = out.split(b"\n") | ||||
|  | ||||
|     app_info = lines[9] | ||||
|     app_info = lines[26] | ||||
|     info = json.loads(app_info) | ||||
|  | ||||
|     assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO | ||||
| @@ -103,7 +103,7 @@ def test_auto_reload(cmd): | ||||
|     out, err, exitcode = capture(command) | ||||
|     lines = out.split(b"\n") | ||||
|  | ||||
|     app_info = lines[9] | ||||
|     app_info = lines[26] | ||||
|     info = json.loads(app_info) | ||||
|  | ||||
|     assert info["debug"] is False | ||||
| @@ -118,7 +118,7 @@ def test_access_logs(cmd, expected): | ||||
|     out, err, exitcode = capture(command) | ||||
|     lines = out.split(b"\n") | ||||
|  | ||||
|     app_info = lines[9] | ||||
|     app_info = lines[26] | ||||
|     info = json.loads(app_info) | ||||
|  | ||||
|     assert info["access_log"] is expected | ||||
|   | ||||
| @@ -3,6 +3,7 @@ from os import environ | ||||
| from pathlib import Path | ||||
| from tempfile import TemporaryDirectory | ||||
| from textwrap import dedent | ||||
| from unittest.mock import Mock | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| @@ -13,7 +14,7 @@ from sanic.exceptions import PyFileError | ||||
|  | ||||
| @contextmanager | ||||
| def temp_path(): | ||||
|     """ a simple cross platform replacement for NamedTemporaryFile """ | ||||
|     """a simple cross platform replacement for NamedTemporaryFile""" | ||||
|     with TemporaryDirectory() as td: | ||||
|         yield Path(td, "file") | ||||
|  | ||||
| @@ -350,3 +351,40 @@ def test_update_from_lowercase_key(app): | ||||
|     d = {"test_setting_value": 1} | ||||
|     app.update_config(d) | ||||
|     assert "test_setting_value" not in app.config | ||||
|  | ||||
|  | ||||
| def test_deprecation_notice_when_setting_logo(app): | ||||
|     message = ( | ||||
|         "Setting the config.LOGO is deprecated and will no longer be " | ||||
|         "supported starting in v22.6." | ||||
|     ) | ||||
|     with pytest.warns(DeprecationWarning, match=message): | ||||
|         app.config.LOGO = "My Custom Logo" | ||||
|  | ||||
|  | ||||
| def test_config_set_methods(app, monkeypatch): | ||||
|     post_set = Mock() | ||||
|     monkeypatch.setattr(Config, "_post_set", post_set) | ||||
|  | ||||
|     app.config.FOO = 1 | ||||
|     post_set.assert_called_once_with("FOO", 1) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config["FOO"] = 2 | ||||
|     post_set.assert_called_once_with("FOO", 2) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update({"FOO": 3}) | ||||
|     post_set.assert_called_once_with("FOO", 3) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update([("FOO", 4)]) | ||||
|     post_set.assert_called_once_with("FOO", 4) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update(FOO=5) | ||||
|     post_set.assert_called_once_with("FOO", 5) | ||||
|     post_set.reset_mock() | ||||
|  | ||||
|     app.config.update_config({"FOO": 6}) | ||||
|     post_set.assert_called_once_with("FOO", 6) | ||||
|   | ||||
| @@ -1,6 +1,4 @@ | ||||
| from crypt import methods | ||||
|  | ||||
| from sanic import text | ||||
| from sanic import Sanic, text | ||||
| from sanic.constants import HTTP_METHODS, HTTPMethod | ||||
|  | ||||
|  | ||||
| @@ -14,7 +12,7 @@ def test_string_compat(): | ||||
|     assert HTTPMethod.GET.upper() == "GET" | ||||
|  | ||||
|  | ||||
| def test_use_in_routes(app): | ||||
| def test_use_in_routes(app: Sanic): | ||||
|     @app.route("/", methods=[HTTPMethod.GET, HTTPMethod.POST]) | ||||
|     def handler(_): | ||||
|         return text("It works") | ||||
|   | ||||
| @@ -1,6 +1,5 @@ | ||||
| import asyncio | ||||
|  | ||||
| from queue import Queue | ||||
| from threading import Event | ||||
|  | ||||
| from sanic.response import text | ||||
| @@ -13,8 +12,6 @@ def test_create_task(app): | ||||
|         await asyncio.sleep(0.05) | ||||
|         e.set() | ||||
|  | ||||
|     app.add_task(coro) | ||||
|  | ||||
|     @app.route("/early") | ||||
|     def not_set(request): | ||||
|         return text(str(e.is_set())) | ||||
| @@ -24,24 +21,30 @@ def test_create_task(app): | ||||
|         await asyncio.sleep(0.1) | ||||
|         return text(str(e.is_set())) | ||||
|  | ||||
|     app.add_task(coro) | ||||
|  | ||||
|     request, response = app.test_client.get("/early") | ||||
|     assert response.body == b"False" | ||||
|  | ||||
|     app.signal_router.reset() | ||||
|     app.add_task(coro) | ||||
|     request, response = app.test_client.get("/late") | ||||
|     assert response.body == b"True" | ||||
|  | ||||
|  | ||||
| def test_create_task_with_app_arg(app): | ||||
|     q = Queue() | ||||
|     @app.after_server_start | ||||
|     async def setup_q(app, _): | ||||
|         app.ctx.q = asyncio.Queue() | ||||
|  | ||||
|     @app.route("/") | ||||
|     def not_set(request): | ||||
|         return "hello" | ||||
|     async def not_set(request): | ||||
|         return text(await request.app.ctx.q.get()) | ||||
|  | ||||
|     async def coro(app): | ||||
|         q.put(app.name) | ||||
|         await app.ctx.q.put(app.name) | ||||
|  | ||||
|     app.add_task(coro) | ||||
|  | ||||
|     request, response = app.test_client.get("/") | ||||
|     assert q.get() == "test_create_task_with_app_arg" | ||||
|     _, response = app.test_client.get("/") | ||||
|     assert response.text == "test_create_task_with_app_arg" | ||||
|   | ||||
| @@ -1,10 +1,12 @@ | ||||
| import pytest | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.errorpages import exception_response | ||||
| from sanic.exceptions import NotFound | ||||
| from sanic.config import Config | ||||
| from sanic.errorpages import HTMLRenderer, exception_response | ||||
| from sanic.exceptions import NotFound, SanicException | ||||
| from sanic.handlers import ErrorHandler | ||||
| from sanic.request import Request | ||||
| from sanic.response import HTTPResponse | ||||
| from sanic.response import HTTPResponse, html, json, text | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @@ -20,7 +22,7 @@ def app(): | ||||
|  | ||||
| @pytest.fixture | ||||
| def fake_request(app): | ||||
|     return Request(b"/foobar", {}, "1.1", "GET", None, app) | ||||
|     return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
| @@ -47,7 +49,13 @@ def test_should_return_html_valid_setting( | ||||
|     try: | ||||
|         raise exception("bad stuff") | ||||
|     except Exception as e: | ||||
|         response = exception_response(fake_request, e, True) | ||||
|         response = exception_response( | ||||
|             fake_request, | ||||
|             e, | ||||
|             True, | ||||
|             base=HTMLRenderer, | ||||
|             fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT, | ||||
|         ) | ||||
|  | ||||
|     assert isinstance(response, HTTPResponse) | ||||
|     assert response.status == status | ||||
| @@ -74,13 +82,263 @@ def test_auto_fallback_with_content_type(app): | ||||
|     app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/error", headers={"content-type": "application/json"} | ||||
|         "/error", headers={"content-type": "application/json", "accept": "*/*"} | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "application/json" | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/error", headers={"content-type": "text/plain"} | ||||
|         "/error", headers={"content-type": "foo/bar", "accept": "*/*"} | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/html; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_route_error_format_set_on_auto(app): | ||||
|     @app.get("/text") | ||||
|     def text_response(request): | ||||
|         return text(request.route.ctx.error_format) | ||||
|  | ||||
|     @app.get("/json") | ||||
|     def json_response(request): | ||||
|         return json({"format": request.route.ctx.error_format}) | ||||
|  | ||||
|     @app.get("/html") | ||||
|     def html_response(request): | ||||
|         return html(request.route.ctx.error_format) | ||||
|  | ||||
|     _, response = app.test_client.get("/text") | ||||
|     assert response.text == "text" | ||||
|  | ||||
|     _, response = app.test_client.get("/json") | ||||
|     assert response.json["format"] == "json" | ||||
|  | ||||
|     _, response = app.test_client.get("/html") | ||||
|     assert response.text == "html" | ||||
|  | ||||
|  | ||||
| def test_route_error_response_from_auto_route(app): | ||||
|     @app.get("/text") | ||||
|     def text_response(request): | ||||
|         raise Exception("oops") | ||||
|         return text("Never gonna see this") | ||||
|  | ||||
|     @app.get("/json") | ||||
|     def json_response(request): | ||||
|         raise Exception("oops") | ||||
|         return json({"message": "Never gonna see this"}) | ||||
|  | ||||
|     @app.get("/html") | ||||
|     def html_response(request): | ||||
|         raise Exception("oops") | ||||
|         return html("<h1>Never gonna see this</h1>") | ||||
|  | ||||
|     _, response = app.test_client.get("/text") | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|     _, response = app.test_client.get("/json") | ||||
|     assert response.content_type == "application/json" | ||||
|  | ||||
|     _, response = app.test_client.get("/html") | ||||
|     assert response.content_type == "text/html; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_route_error_response_from_explicit_format(app): | ||||
|     @app.get("/text", error_format="json") | ||||
|     def text_response(request): | ||||
|         raise Exception("oops") | ||||
|         return text("Never gonna see this") | ||||
|  | ||||
|     @app.get("/json", error_format="text") | ||||
|     def json_response(request): | ||||
|         raise Exception("oops") | ||||
|         return json({"message": "Never gonna see this"}) | ||||
|  | ||||
|     _, response = app.test_client.get("/text") | ||||
|     assert response.content_type == "application/json" | ||||
|  | ||||
|     _, response = app.test_client.get("/json") | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_unknown_fallback_format(app): | ||||
|     with pytest.raises(SanicException, match="Unknown format: bad"): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "bad" | ||||
|  | ||||
|  | ||||
| def test_route_error_format_unknown(app): | ||||
|     with pytest.raises(SanicException, match="Unknown format: bad"): | ||||
|  | ||||
|         @app.get("/text", error_format="bad") | ||||
|         def handler(request): | ||||
|             ... | ||||
|  | ||||
|  | ||||
| def test_fallback_with_content_type_mismatch_accept(app): | ||||
|     app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/error", | ||||
|         headers={"content-type": "application/json", "accept": "text/plain"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/error", | ||||
|         headers={"content-type": "text/plain", "accept": "foo/bar"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/html; charset=utf-8" | ||||
|  | ||||
|     app.router.reset() | ||||
|  | ||||
|     @app.route("/alt1") | ||||
|     @app.route("/alt2", error_format="text") | ||||
|     @app.route("/alt3", error_format="html") | ||||
|     def handler(_): | ||||
|         raise Exception("problem here") | ||||
|         # Yes, we know this return value is unreachable. This is on purpose. | ||||
|         return json({}) | ||||
|  | ||||
|     app.router.finalize() | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/alt1", | ||||
|         headers={"accept": "foo/bar"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/html; charset=utf-8" | ||||
|     _, response = app.test_client.get( | ||||
|         "/alt1", | ||||
|         headers={"accept": "foo/bar,*/*"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "application/json" | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/alt2", | ||||
|         headers={"accept": "foo/bar"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/html; charset=utf-8" | ||||
|     _, response = app.test_client.get( | ||||
|         "/alt2", | ||||
|         headers={"accept": "foo/bar,*/*"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|     _, response = app.test_client.get( | ||||
|         "/alt3", | ||||
|         headers={"accept": "foo/bar"}, | ||||
|     ) | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/html; charset=utf-8" | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "accept,content_type,expected", | ||||
|     ( | ||||
|         (None, None, "text/plain; charset=utf-8"), | ||||
|         ("foo/bar", None, "text/html; charset=utf-8"), | ||||
|         ("application/json", None, "application/json"), | ||||
|         ("application/json,text/plain", None, "application/json"), | ||||
|         ("text/plain,application/json", None, "application/json"), | ||||
|         ("text/plain,foo/bar", None, "text/plain; charset=utf-8"), | ||||
|         # Following test is valid after v22.3 | ||||
|         # ("text/plain,text/html", None, "text/plain; charset=utf-8"), | ||||
|         ("*/*", "foo/bar", "text/html; charset=utf-8"), | ||||
|         ("*/*", "application/json", "application/json"), | ||||
|     ), | ||||
| ) | ||||
| def test_combinations_for_auto(fake_request, accept, content_type, expected): | ||||
|     if accept: | ||||
|         fake_request.headers["accept"] = accept | ||||
|     else: | ||||
|         del fake_request.headers["accept"] | ||||
|  | ||||
|     if content_type: | ||||
|         fake_request.headers["content-type"] = content_type | ||||
|  | ||||
|     try: | ||||
|         raise Exception("bad stuff") | ||||
|     except Exception as e: | ||||
|         response = exception_response( | ||||
|             fake_request, | ||||
|             e, | ||||
|             True, | ||||
|             base=HTMLRenderer, | ||||
|             fallback="auto", | ||||
|         ) | ||||
|  | ||||
|     assert response.content_type == expected | ||||
|  | ||||
|  | ||||
| def test_allow_fallback_error_format_set_main_process_start(app): | ||||
|     @app.main_process_start | ||||
|     async def start(app, _): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     request, response = app.test_client.get("/error") | ||||
|     assert request.app.error_handler.fallback == "text" | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_setting_fallback_to_non_default_raise_warning(app): | ||||
|     app.error_handler = ErrorHandler(fallback="text") | ||||
|  | ||||
|     assert app.error_handler.fallback == "text" | ||||
|  | ||||
|     with pytest.warns( | ||||
|         UserWarning, | ||||
|         match=( | ||||
|             "Overriding non-default ErrorHandler fallback value. " | ||||
|             "Changing from text to auto." | ||||
|         ), | ||||
|     ): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||
|  | ||||
|     assert app.error_handler.fallback == "auto" | ||||
|  | ||||
|     app.config.FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     with pytest.warns( | ||||
|         UserWarning, | ||||
|         match=( | ||||
|             "Overriding non-default ErrorHandler fallback value. " | ||||
|             "Changing from text to json." | ||||
|         ), | ||||
|     ): | ||||
|         app.config.FALLBACK_ERROR_FORMAT = "json" | ||||
|  | ||||
|     assert app.error_handler.fallback == "json" | ||||
|  | ||||
|  | ||||
| def test_allow_fallback_error_format_in_config_injection(): | ||||
|     class MyConfig(Config): | ||||
|         FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     app = Sanic("test", config=MyConfig()) | ||||
|  | ||||
|     @app.route("/error", methods=["GET", "POST"]) | ||||
|     def err(request): | ||||
|         raise Exception("something went wrong") | ||||
|  | ||||
|     request, response = app.test_client.get("/error") | ||||
|     assert request.app.error_handler.fallback == "text" | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|  | ||||
|  | ||||
| def test_allow_fallback_error_format_in_config_replacement(app): | ||||
|     class MyConfig(Config): | ||||
|         FALLBACK_ERROR_FORMAT = "text" | ||||
|  | ||||
|     app.config = MyConfig() | ||||
|  | ||||
|     request, response = app.test_client.get("/error") | ||||
|     assert request.app.error_handler.fallback == "text" | ||||
|     assert response.status == 500 | ||||
|     assert response.content_type == "text/plain; charset=utf-8" | ||||
|   | ||||
| @@ -1,8 +1,10 @@ | ||||
| import logging | ||||
| import warnings | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from bs4 import BeautifulSoup | ||||
| from websockets.version import version as websockets_version | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.exceptions import ( | ||||
| @@ -232,3 +234,41 @@ def test_sanic_exception(exception_app): | ||||
|         request, response = exception_app.test_client.get("/old_abort") | ||||
|     assert response.status == 500 | ||||
|     assert len(w) == 1 and "deprecated" in w[0].message.args[0] | ||||
|  | ||||
|  | ||||
| def test_custom_exception_default_message(exception_app): | ||||
|     class TeaError(SanicException): | ||||
|         message = "Tempest in a teapot" | ||||
|         status_code = 418 | ||||
|  | ||||
|     exception_app.router.reset() | ||||
|  | ||||
|     @exception_app.get("/tempest") | ||||
|     def tempest(_): | ||||
|         raise TeaError | ||||
|  | ||||
|     _, response = exception_app.test_client.get("/tempest", debug=True) | ||||
|     assert response.status == 418 | ||||
|     assert b"Tempest in a teapot" in response.body | ||||
|  | ||||
|  | ||||
| def test_exception_in_ws_logged(caplog): | ||||
|     app = Sanic(__file__) | ||||
|  | ||||
|     @app.websocket("/feed") | ||||
|     async def feed(request, ws): | ||||
|         raise Exception("...") | ||||
|  | ||||
|     with caplog.at_level(logging.INFO): | ||||
|         app.test_client.websocket("/feed") | ||||
|     # Websockets v10.0 and above output an additional | ||||
|     # INFO message when a ws connection is accepted | ||||
|     ws_version_parts = websockets_version.split(".") | ||||
|     ws_major = int(ws_version_parts[0]) | ||||
|     record_index = 2 if ws_major >= 10 else 1 | ||||
|     assert caplog.record_tuples[record_index][0] == "sanic.error" | ||||
|     assert caplog.record_tuples[record_index][1] == logging.ERROR | ||||
|     assert ( | ||||
|         "Exception occurred while handling uri:" | ||||
|         in caplog.record_tuples[record_index][2] | ||||
|     ) | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| import asyncio | ||||
| import logging | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from bs4 import BeautifulSoup | ||||
|  | ||||
| @@ -8,9 +11,6 @@ from sanic.handlers import ErrorHandler | ||||
| from sanic.response import stream, text | ||||
|  | ||||
|  | ||||
| exception_handler_app = Sanic("test_exception_handler") | ||||
|  | ||||
|  | ||||
| async def sample_streaming_fn(response): | ||||
|     await response.write("foo,") | ||||
|     await asyncio.sleep(0.001) | ||||
| @@ -21,113 +21,107 @@ class ErrorWithRequestCtx(ServerError): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/1") | ||||
| def handler_1(request): | ||||
|     raise InvalidUsage("OK") | ||||
| @pytest.fixture | ||||
| def exception_handler_app(): | ||||
|     exception_handler_app = Sanic("test_exception_handler") | ||||
|  | ||||
|     @exception_handler_app.route("/1", error_format="html") | ||||
|     def handler_1(request): | ||||
|         raise InvalidUsage("OK") | ||||
|  | ||||
|     @exception_handler_app.route("/2", error_format="html") | ||||
|     def handler_2(request): | ||||
|         raise ServerError("OK") | ||||
|  | ||||
|     @exception_handler_app.route("/3", error_format="html") | ||||
|     def handler_3(request): | ||||
|         raise NotFound("OK") | ||||
|  | ||||
|     @exception_handler_app.route("/4", error_format="html") | ||||
|     def handler_4(request): | ||||
|         foo = bar  # noqa -- F821 | ||||
|         return text(foo) | ||||
|  | ||||
|     @exception_handler_app.route("/5", error_format="html") | ||||
|     def handler_5(request): | ||||
|         class CustomServerError(ServerError): | ||||
|             pass | ||||
|  | ||||
|         raise CustomServerError("Custom server error") | ||||
|  | ||||
|     @exception_handler_app.route("/6/<arg:int>", error_format="html") | ||||
|     def handler_6(request, arg): | ||||
|         try: | ||||
|             foo = 1 / arg | ||||
|         except Exception as e: | ||||
|             raise e from ValueError(f"{arg}") | ||||
|         return text(foo) | ||||
|  | ||||
|     @exception_handler_app.route("/7", error_format="html") | ||||
|     def handler_7(request): | ||||
|         raise Forbidden("go away!") | ||||
|  | ||||
|     @exception_handler_app.route("/8", error_format="html") | ||||
|     def handler_8(request): | ||||
|  | ||||
|         raise ErrorWithRequestCtx("OK") | ||||
|  | ||||
|     @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) | ||||
|     def handler_exception_with_ctx(request, exception): | ||||
|         return text(request.ctx.middleware_ran) | ||||
|  | ||||
|     @exception_handler_app.exception(ServerError) | ||||
|     def handler_exception(request, exception): | ||||
|         return text("OK") | ||||
|  | ||||
|     @exception_handler_app.exception(Forbidden) | ||||
|     async def async_handler_exception(request, exception): | ||||
|         return stream( | ||||
|             sample_streaming_fn, | ||||
|             content_type="text/csv", | ||||
|         ) | ||||
|  | ||||
|     @exception_handler_app.middleware | ||||
|     async def some_request_middleware(request): | ||||
|         request.ctx.middleware_ran = "Done." | ||||
|  | ||||
|     return exception_handler_app | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/2") | ||||
| def handler_2(request): | ||||
|     raise ServerError("OK") | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/3") | ||||
| def handler_3(request): | ||||
|     raise NotFound("OK") | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/4") | ||||
| def handler_4(request): | ||||
|     foo = bar  # noqa -- F821 undefined name 'bar' is done to throw exception | ||||
|     return text(foo) | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/5") | ||||
| def handler_5(request): | ||||
|     class CustomServerError(ServerError): | ||||
|         pass | ||||
|  | ||||
|     raise CustomServerError("Custom server error") | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/6/<arg:int>") | ||||
| def handler_6(request, arg): | ||||
|     try: | ||||
|         foo = 1 / arg | ||||
|     except Exception as e: | ||||
|         raise e from ValueError(f"{arg}") | ||||
|     return text(foo) | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/7") | ||||
| def handler_7(request): | ||||
|     raise Forbidden("go away!") | ||||
|  | ||||
|  | ||||
| @exception_handler_app.route("/8") | ||||
| def handler_8(request): | ||||
|  | ||||
|     raise ErrorWithRequestCtx("OK") | ||||
|  | ||||
|  | ||||
| @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) | ||||
| def handler_exception_with_ctx(request, exception): | ||||
|     return text(request.ctx.middleware_ran) | ||||
|  | ||||
|  | ||||
| @exception_handler_app.exception(ServerError) | ||||
| def handler_exception(request, exception): | ||||
|     return text("OK") | ||||
|  | ||||
|  | ||||
| @exception_handler_app.exception(Forbidden) | ||||
| async def async_handler_exception(request, exception): | ||||
|     return stream( | ||||
|         sample_streaming_fn, | ||||
|         content_type="text/csv", | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @exception_handler_app.middleware | ||||
| async def some_request_middleware(request): | ||||
|     request.ctx.middleware_ran = "Done." | ||||
|  | ||||
|  | ||||
| def test_invalid_usage_exception_handler(): | ||||
| def test_invalid_usage_exception_handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/1") | ||||
|     assert response.status == 400 | ||||
|  | ||||
|  | ||||
| def test_server_error_exception_handler(): | ||||
| def test_server_error_exception_handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/2") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "OK" | ||||
|  | ||||
|  | ||||
| def test_not_found_exception_handler(): | ||||
| def test_not_found_exception_handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/3") | ||||
|     assert response.status == 200 | ||||
|  | ||||
|  | ||||
| def test_text_exception__handler(): | ||||
| def test_text_exception__handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/random") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "Done." | ||||
|  | ||||
|  | ||||
| def test_async_exception_handler(): | ||||
| def test_async_exception_handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/7") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "foo,bar" | ||||
|  | ||||
|  | ||||
| def test_html_traceback_output_in_debug_mode(): | ||||
| def test_html_traceback_output_in_debug_mode(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/4", debug=True) | ||||
|     assert response.status == 500 | ||||
|     soup = BeautifulSoup(response.body, "html.parser") | ||||
|     html = str(soup) | ||||
|  | ||||
|     assert "response = handler(request, **kwargs)" in html | ||||
|     assert "handler_4" in html | ||||
|     assert "foo = bar" in html | ||||
|  | ||||
| @@ -137,12 +131,12 @@ def test_html_traceback_output_in_debug_mode(): | ||||
|     ) == summary_text | ||||
|  | ||||
|  | ||||
| def test_inherited_exception_handler(): | ||||
| def test_inherited_exception_handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/5") | ||||
|     assert response.status == 200 | ||||
|  | ||||
|  | ||||
| def test_chained_exception_handler(): | ||||
| def test_chained_exception_handler(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get( | ||||
|         "/6/0", debug=True | ||||
|     ) | ||||
| @@ -151,11 +145,9 @@ def test_chained_exception_handler(): | ||||
|     soup = BeautifulSoup(response.body, "html.parser") | ||||
|     html = str(soup) | ||||
|  | ||||
|     assert "response = handler(request, **kwargs)" in html | ||||
|     assert "handler_6" in html | ||||
|     assert "foo = 1 / arg" in html | ||||
|     assert "ValueError" in html | ||||
|     assert "The above exception was the direct cause" in html | ||||
|  | ||||
|     summary_text = " ".join(soup.select(".summary")[0].text.split()) | ||||
|     assert ( | ||||
| @@ -163,7 +155,7 @@ def test_chained_exception_handler(): | ||||
|     ) == summary_text | ||||
|  | ||||
|  | ||||
| def test_exception_handler_lookup(): | ||||
| def test_exception_handler_lookup(exception_handler_app): | ||||
|     class CustomError(Exception): | ||||
|         pass | ||||
|  | ||||
| @@ -186,26 +178,52 @@ def test_exception_handler_lookup(): | ||||
|         class ModuleNotFoundError(ImportError): | ||||
|             pass | ||||
|  | ||||
|     handler = ErrorHandler() | ||||
|     handler = ErrorHandler("auto") | ||||
|     handler.add(ImportError, import_error_handler) | ||||
|     handler.add(CustomError, custom_error_handler) | ||||
|     handler.add(ServerError, server_error_handler) | ||||
|  | ||||
|     assert handler.lookup(ImportError()) == import_error_handler | ||||
|     assert handler.lookup(ModuleNotFoundError()) == import_error_handler | ||||
|     assert handler.lookup(CustomError()) == custom_error_handler | ||||
|     assert handler.lookup(ServerError("Error")) == server_error_handler | ||||
|     assert handler.lookup(CustomServerError("Error")) == server_error_handler | ||||
|     assert handler.lookup(ImportError(), None) == import_error_handler | ||||
|     assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler | ||||
|     assert handler.lookup(CustomError(), None) == custom_error_handler | ||||
|     assert handler.lookup(ServerError("Error"), None) == server_error_handler | ||||
|     assert ( | ||||
|         handler.lookup(CustomServerError("Error"), None) | ||||
|         == server_error_handler | ||||
|     ) | ||||
|  | ||||
|     # once again to ensure there is no caching bug | ||||
|     assert handler.lookup(ImportError()) == import_error_handler | ||||
|     assert handler.lookup(ModuleNotFoundError()) == import_error_handler | ||||
|     assert handler.lookup(CustomError()) == custom_error_handler | ||||
|     assert handler.lookup(ServerError("Error")) == server_error_handler | ||||
|     assert handler.lookup(CustomServerError("Error")) == server_error_handler | ||||
|     assert handler.lookup(ImportError(), None) == import_error_handler | ||||
|     assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler | ||||
|     assert handler.lookup(CustomError(), None) == custom_error_handler | ||||
|     assert handler.lookup(ServerError("Error"), None) == server_error_handler | ||||
|     assert ( | ||||
|         handler.lookup(CustomServerError("Error"), None) | ||||
|         == server_error_handler | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_exception_handler_processed_request_middleware(): | ||||
| def test_exception_handler_processed_request_middleware(exception_handler_app): | ||||
|     request, response = exception_handler_app.test_client.get("/8") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "Done." | ||||
|  | ||||
|  | ||||
| def test_single_arg_exception_handler_notice(exception_handler_app, caplog): | ||||
|     class CustomErrorHandler(ErrorHandler): | ||||
|         def lookup(self, exception): | ||||
|             return super().lookup(exception, None) | ||||
|  | ||||
|     exception_handler_app.error_handler = CustomErrorHandler() | ||||
|  | ||||
|     with caplog.at_level(logging.WARNING): | ||||
|         _, response = exception_handler_app.test_client.get("/1") | ||||
|  | ||||
|     assert caplog.records[0].message == ( | ||||
|         "You are using a deprecated error handler. The lookup method should " | ||||
|         "accept two positional parameters: (exception, route_name: " | ||||
|         "Optional[str]). Until you upgrade your ErrorHandler.lookup, " | ||||
|         "Blueprint specific exceptions will not work properly. Beginning in " | ||||
|         "v22.3, the legacy style lookup method will not work at all." | ||||
|     ) | ||||
|     assert response.status == 400 | ||||
|   | ||||
							
								
								
									
										46
									
								
								tests/test_graceful_shutdown.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/test_graceful_shutdown.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| import time | ||||
|  | ||||
| from collections import Counter | ||||
| from multiprocessing import Process | ||||
|  | ||||
| import httpx | ||||
|  | ||||
|  | ||||
| PORT = 42101 | ||||
|  | ||||
|  | ||||
| def test_no_exceptions_when_cancel_pending_request(app, caplog): | ||||
|     app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1 | ||||
|  | ||||
|     @app.get("/") | ||||
|     async def handler(request): | ||||
|         await asyncio.sleep(5) | ||||
|  | ||||
|     @app.after_server_start | ||||
|     def shutdown(app, _): | ||||
|         time.sleep(0.2) | ||||
|         app.stop() | ||||
|  | ||||
|     def ping(): | ||||
|         time.sleep(0.1) | ||||
|         response = httpx.get("http://127.0.0.1:8000") | ||||
|         print(response.status_code) | ||||
|  | ||||
|     p = Process(target=ping) | ||||
|     p.start() | ||||
|  | ||||
|     with caplog.at_level(logging.INFO): | ||||
|         app.run() | ||||
|  | ||||
|     p.kill() | ||||
|  | ||||
|     counter = Counter([r[1] for r in caplog.record_tuples]) | ||||
|  | ||||
|     assert counter[logging.INFO] == 5 | ||||
|     assert logging.ERROR not in counter | ||||
|     assert ( | ||||
|         caplog.record_tuples[3][2] | ||||
|         == "Request: GET http://127.0.0.1:8000/ stopped. Transport is closed." | ||||
|     ) | ||||
							
								
								
									
										39
									
								
								tests/test_handler_annotations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								tests/test_handler_annotations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| from uuid import UUID | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from sanic import json | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "idx,path,expectation", | ||||
|     ( | ||||
|         (0, "/abc", "str"), | ||||
|         (1, "/123", "int"), | ||||
|         (2, "/123.5", "float"), | ||||
|         (3, "/8af729fe-2b94-4a95-a168-c07068568429", "UUID"), | ||||
|     ), | ||||
| ) | ||||
| def test_annotated_handlers(app, idx, path, expectation): | ||||
|     def build_response(num, foo): | ||||
|         return json({"num": num, "type": type(foo).__name__}) | ||||
|  | ||||
|     @app.get("/<foo>") | ||||
|     def handler0(_, foo: str): | ||||
|         return build_response(0, foo) | ||||
|  | ||||
|     @app.get("/<foo>") | ||||
|     def handler1(_, foo: int): | ||||
|         return build_response(1, foo) | ||||
|  | ||||
|     @app.get("/<foo>") | ||||
|     def handler2(_, foo: float): | ||||
|         return build_response(2, foo) | ||||
|  | ||||
|     @app.get("/<foo>") | ||||
|     def handler3(_, foo: UUID): | ||||
|         return build_response(3, foo) | ||||
|  | ||||
|     _, response = app.test_client.get(path) | ||||
|     assert response.json["num"] == idx | ||||
|     assert response.json["type"] == expectation | ||||
| @@ -3,8 +3,9 @@ from unittest.mock import Mock | ||||
| import pytest | ||||
|  | ||||
| from sanic import headers, text | ||||
| from sanic.exceptions import PayloadTooLarge | ||||
| from sanic.exceptions import InvalidHeader, PayloadTooLarge | ||||
| from sanic.http import Http | ||||
| from sanic.request import Request | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @@ -182,3 +183,187 @@ def test_request_line(app): | ||||
|     ) | ||||
|  | ||||
|     assert request.request_line == b"GET / HTTP/1.1" | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "raw", | ||||
|     ( | ||||
|         "show/first, show/second", | ||||
|         "show/*, show/first", | ||||
|         "*/*, show/first", | ||||
|         "*/*, show/*", | ||||
|         "other/*; q=0.1, show/*; q=0.2", | ||||
|         "show/first; q=0.5, show/second; q=0.5", | ||||
|         "show/first; foo=bar, show/second; foo=bar", | ||||
|         "show/second, show/first; foo=bar", | ||||
|         "show/second; q=0.5, show/first; foo=bar; q=0.5", | ||||
|         "show/second; q=0.5, show/first; q=1.0", | ||||
|         "show/first, show/second; q=1.0", | ||||
|     ), | ||||
| ) | ||||
| def test_parse_accept_ordered_okay(raw): | ||||
|     ordered = headers.parse_accept(raw) | ||||
|     expected_subtype = ( | ||||
|         "*" if all(q.subtype.is_wildcard for q in ordered) else "first" | ||||
|     ) | ||||
|     assert ordered[0].type_ == "show" | ||||
|     assert ordered[0].subtype == expected_subtype | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "raw", | ||||
|     ( | ||||
|         "missing", | ||||
|         "missing/", | ||||
|         "/missing", | ||||
|     ), | ||||
| ) | ||||
| def test_bad_accept(raw): | ||||
|     with pytest.raises(InvalidHeader): | ||||
|         headers.parse_accept(raw) | ||||
|  | ||||
|  | ||||
| def test_empty_accept(): | ||||
|     assert headers.parse_accept("") == [] | ||||
|  | ||||
|  | ||||
| def test_wildcard_accept_set_ok(): | ||||
|     accept = headers.parse_accept("*/*")[0] | ||||
|     assert accept.type_.is_wildcard | ||||
|     assert accept.subtype.is_wildcard | ||||
|  | ||||
|     accept = headers.parse_accept("foo/bar")[0] | ||||
|     assert not accept.type_.is_wildcard | ||||
|     assert not accept.subtype.is_wildcard | ||||
|  | ||||
|  | ||||
| def test_accept_parsed_against_str(): | ||||
|     accept = headers.Accept.parse("foo/bar") | ||||
|     assert accept > "foo/bar; q=0.1" | ||||
|  | ||||
|  | ||||
| def test_media_type_equality(): | ||||
|     assert headers.MediaType("foo") == headers.MediaType("foo") == "foo" | ||||
|     assert headers.MediaType("foo") == headers.MediaType("*") == "*" | ||||
|     assert headers.MediaType("foo") != headers.MediaType("bar") | ||||
|     assert headers.MediaType("foo") != "bar" | ||||
|  | ||||
|  | ||||
| def test_media_type_matching(): | ||||
|     assert headers.MediaType("foo").match(headers.MediaType("foo")) | ||||
|     assert headers.MediaType("foo").match("foo") | ||||
|  | ||||
|     assert not headers.MediaType("foo").match(headers.MediaType("*")) | ||||
|     assert not headers.MediaType("foo").match("*") | ||||
|  | ||||
|     assert not headers.MediaType("foo").match(headers.MediaType("bar")) | ||||
|     assert not headers.MediaType("foo").match("bar") | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "value,other,outcome,allow_type,allow_subtype", | ||||
|     ( | ||||
|         # ALLOW BOTH | ||||
|         ("foo/bar", "foo/bar", True, True, True), | ||||
|         ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), | ||||
|         ("foo/bar", "foo/*", True, True, True), | ||||
|         ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), | ||||
|         ("foo/bar", "*/*", True, True, True), | ||||
|         ("foo/bar", headers.Accept.parse("*/*"), True, True, True), | ||||
|         ("foo/*", "foo/bar", True, True, True), | ||||
|         ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), | ||||
|         ("foo/*", "foo/*", True, True, True), | ||||
|         ("foo/*", headers.Accept.parse("foo/*"), True, True, True), | ||||
|         ("foo/*", "*/*", True, True, True), | ||||
|         ("foo/*", headers.Accept.parse("*/*"), True, True, True), | ||||
|         ("*/*", "foo/bar", True, True, True), | ||||
|         ("*/*", headers.Accept.parse("foo/bar"), True, True, True), | ||||
|         ("*/*", "foo/*", True, True, True), | ||||
|         ("*/*", headers.Accept.parse("foo/*"), True, True, True), | ||||
|         ("*/*", "*/*", True, True, True), | ||||
|         ("*/*", headers.Accept.parse("*/*"), True, True, True), | ||||
|         # ALLOW TYPE | ||||
|         ("foo/bar", "foo/bar", True, True, False), | ||||
|         ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), | ||||
|         ("foo/bar", "foo/*", False, True, False), | ||||
|         ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), | ||||
|         ("foo/bar", "*/*", False, True, False), | ||||
|         ("foo/bar", headers.Accept.parse("*/*"), False, True, False), | ||||
|         ("foo/*", "foo/bar", False, True, False), | ||||
|         ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), | ||||
|         ("foo/*", "foo/*", False, True, False), | ||||
|         ("foo/*", headers.Accept.parse("foo/*"), False, True, False), | ||||
|         ("foo/*", "*/*", False, True, False), | ||||
|         ("foo/*", headers.Accept.parse("*/*"), False, True, False), | ||||
|         ("*/*", "foo/bar", False, True, False), | ||||
|         ("*/*", headers.Accept.parse("foo/bar"), False, True, False), | ||||
|         ("*/*", "foo/*", False, True, False), | ||||
|         ("*/*", headers.Accept.parse("foo/*"), False, True, False), | ||||
|         ("*/*", "*/*", False, True, False), | ||||
|         ("*/*", headers.Accept.parse("*/*"), False, True, False), | ||||
|         # ALLOW SUBTYPE | ||||
|         ("foo/bar", "foo/bar", True, False, True), | ||||
|         ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), | ||||
|         ("foo/bar", "foo/*", True, False, True), | ||||
|         ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), | ||||
|         ("foo/bar", "*/*", False, False, True), | ||||
|         ("foo/bar", headers.Accept.parse("*/*"), False, False, True), | ||||
|         ("foo/*", "foo/bar", True, False, True), | ||||
|         ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), | ||||
|         ("foo/*", "foo/*", True, False, True), | ||||
|         ("foo/*", headers.Accept.parse("foo/*"), True, False, True), | ||||
|         ("foo/*", "*/*", False, False, True), | ||||
|         ("foo/*", headers.Accept.parse("*/*"), False, False, True), | ||||
|         ("*/*", "foo/bar", False, False, True), | ||||
|         ("*/*", headers.Accept.parse("foo/bar"), False, False, True), | ||||
|         ("*/*", "foo/*", False, False, True), | ||||
|         ("*/*", headers.Accept.parse("foo/*"), False, False, True), | ||||
|         ("*/*", "*/*", False, False, True), | ||||
|         ("*/*", headers.Accept.parse("*/*"), False, False, True), | ||||
|     ), | ||||
| ) | ||||
| def test_accept_matching(value, other, outcome, allow_type, allow_subtype): | ||||
|     assert ( | ||||
|         headers.Accept.parse(value).match( | ||||
|             other, | ||||
|             allow_type_wildcard=allow_type, | ||||
|             allow_subtype_wildcard=allow_subtype, | ||||
|         ) | ||||
|         is outcome | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) | ||||
| def test_value_in_accept(value): | ||||
|     acceptable = headers.parse_accept(value) | ||||
|     assert "foo/bar" in acceptable | ||||
|     assert "foo/*" in acceptable | ||||
|     assert "*/*" in acceptable | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("value", ("foo/bar", "foo/*")) | ||||
| def test_value_not_in_accept(value): | ||||
|     acceptable = headers.parse_accept(value) | ||||
|     assert "no/match" not in acceptable | ||||
|     assert "no/*" not in acceptable | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "header,expected", | ||||
|     ( | ||||
|         ( | ||||
|             "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",  # noqa: E501 | ||||
|             [ | ||||
|                 "text/html", | ||||
|                 "application/xhtml+xml", | ||||
|                 "image/avif", | ||||
|                 "image/webp", | ||||
|                 "application/xml;q=0.9", | ||||
|                 "*/*;q=0.8", | ||||
|             ], | ||||
|         ), | ||||
|     ), | ||||
| ) | ||||
| def test_browser_headers(header, expected): | ||||
|     request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) | ||||
|     assert request.accept == expected | ||||
|   | ||||
							
								
								
									
										137
									
								
								tests/test_http.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								tests/test_http.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,137 @@ | ||||
| import asyncio | ||||
| import json as stdjson | ||||
|  | ||||
| from collections import namedtuple | ||||
| from textwrap import dedent | ||||
| from typing import AnyStr | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from sanic_testing.reusable import ReusableClient | ||||
|  | ||||
| from sanic import json, text | ||||
| from sanic.app import Sanic | ||||
|  | ||||
|  | ||||
| PORT = 1234 | ||||
|  | ||||
|  | ||||
| class RawClient: | ||||
|     CRLF = b"\r\n" | ||||
|  | ||||
|     def __init__(self, host: str, port: int): | ||||
|         self.reader = None | ||||
|         self.writer = None | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|  | ||||
|     async def connect(self): | ||||
|         self.reader, self.writer = await asyncio.open_connection( | ||||
|             self.host, self.port | ||||
|         ) | ||||
|  | ||||
|     async def close(self): | ||||
|         self.writer.close() | ||||
|         await self.writer.wait_closed() | ||||
|  | ||||
|     async def send(self, message: AnyStr): | ||||
|         if isinstance(message, str): | ||||
|             msg = self._clean(message).encode("utf-8") | ||||
|         else: | ||||
|             msg = message | ||||
|         await self._send(msg) | ||||
|  | ||||
|     async def _send(self, message: bytes): | ||||
|         if not self.writer: | ||||
|             raise Exception("No open write stream") | ||||
|         self.writer.write(message) | ||||
|  | ||||
|     async def recv(self, nbytes: int = -1) -> bytes: | ||||
|         if not self.reader: | ||||
|             raise Exception("No open read stream") | ||||
|         return await self.reader.read(nbytes) | ||||
|  | ||||
|     def _clean(self, message: str) -> str: | ||||
|         return ( | ||||
|             dedent(message) | ||||
|             .lstrip("\n") | ||||
|             .replace("\n", self.CRLF.decode("utf-8")) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def test_app(app: Sanic): | ||||
|     app.config.KEEP_ALIVE_TIMEOUT = 1 | ||||
|  | ||||
|     @app.get("/") | ||||
|     async def base_handler(request): | ||||
|         return text("111122223333444455556666777788889999") | ||||
|  | ||||
|     @app.post("/upload", stream=True) | ||||
|     async def upload_handler(request): | ||||
|         data = [part.decode("utf-8") async for part in request.stream] | ||||
|         return json(data) | ||||
|  | ||||
|     return app | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def runner(test_app): | ||||
|     client = ReusableClient(test_app, port=PORT) | ||||
|     client.run() | ||||
|     yield client | ||||
|     client.stop() | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def client(runner): | ||||
|     client = namedtuple("Client", ("raw", "send", "recv")) | ||||
|  | ||||
|     raw = RawClient(runner.host, runner.port) | ||||
|     runner._run(raw.connect()) | ||||
|  | ||||
|     def send(msg): | ||||
|         nonlocal runner | ||||
|         nonlocal raw | ||||
|         runner._run(raw.send(msg)) | ||||
|  | ||||
|     def recv(**kwargs): | ||||
|         nonlocal runner | ||||
|         nonlocal raw | ||||
|         method = raw.recv_until if "until" in kwargs else raw.recv | ||||
|         return runner._run(method(**kwargs)) | ||||
|  | ||||
|     yield client(raw, send, recv) | ||||
|  | ||||
|     runner._run(raw.close()) | ||||
|  | ||||
|  | ||||
| def test_full_message(client): | ||||
|     client.send( | ||||
|         """ | ||||
|         GET / HTTP/1.1 | ||||
|         host: localhost:7777 | ||||
|  | ||||
|         """ | ||||
|     ) | ||||
|     response = client.recv() | ||||
|     assert len(response) == 140 | ||||
|     assert b"200 OK" in response | ||||
|  | ||||
|  | ||||
| def test_transfer_chunked(client): | ||||
|     client.send( | ||||
|         """ | ||||
|         POST /upload HTTP/1.1 | ||||
|         transfer-encoding: chunked | ||||
|  | ||||
|         """ | ||||
|     ) | ||||
|     client.send(b"3\r\nfoo\r\n") | ||||
|     client.send(b"3\r\nbar\r\n") | ||||
|     client.send(b"0\r\n\r\n") | ||||
|     response = client.recv() | ||||
|     _, body = response.rsplit(b"\r\n\r\n", 1) | ||||
|     data = stdjson.loads(body) | ||||
|  | ||||
|     assert data == ["foo", "bar"] | ||||
| @@ -2,16 +2,13 @@ import asyncio | ||||
| import platform | ||||
|  | ||||
| from asyncio import sleep as aio_sleep | ||||
| from json import JSONDecodeError | ||||
| from os import environ | ||||
|  | ||||
| import httpcore | ||||
| import httpx | ||||
| import pytest | ||||
|  | ||||
| from sanic_testing.testing import HOST, SanicTestClient | ||||
| from sanic_testing.reusable import ReusableClient | ||||
|  | ||||
| from sanic import Sanic, server | ||||
| from sanic import Sanic | ||||
| from sanic.compat import OS_IS_WINDOWS | ||||
| from sanic.response import text | ||||
|  | ||||
| @@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} | ||||
| PORT = 42101  # test_keep_alive_timeout_reuse doesn't work with random port | ||||
|  | ||||
|  | ||||
| class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool): | ||||
|     last_reused_connection = None | ||||
|  | ||||
|     async def _get_connection_from_pool(self, *args, **kwargs): | ||||
|         conn = await super()._get_connection_from_pool(*args, **kwargs) | ||||
|         self.__class__.last_reused_connection = conn | ||||
|         return conn | ||||
|  | ||||
|  | ||||
| class ResusableSanicSession(httpx.AsyncClient): | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         transport = ReusableSanicConnectionPool() | ||||
|         super().__init__(transport=transport, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class ReuseableSanicTestClient(SanicTestClient): | ||||
|     def __init__(self, app, loop=None): | ||||
|         super().__init__(app) | ||||
|         if loop is None: | ||||
|             loop = asyncio.get_event_loop() | ||||
|         self._loop = loop | ||||
|         self._server = None | ||||
|         self._tcp_connector = None | ||||
|         self._session = None | ||||
|  | ||||
|     def get_new_session(self): | ||||
|         return ResusableSanicSession() | ||||
|  | ||||
|     # Copied from SanicTestClient, but with some changes to reuse the | ||||
|     # same loop for the same app. | ||||
|     def _sanic_endpoint_test( | ||||
|         self, | ||||
|         method="get", | ||||
|         uri="/", | ||||
|         gather_request=True, | ||||
|         debug=False, | ||||
|         server_kwargs=None, | ||||
|         *request_args, | ||||
|         **request_kwargs, | ||||
|     ): | ||||
|         loop = self._loop | ||||
|         results = [None, None] | ||||
|         exceptions = [] | ||||
|         server_kwargs = server_kwargs or {"return_asyncio_server": True} | ||||
|         if gather_request: | ||||
|  | ||||
|             def _collect_request(request): | ||||
|                 if results[0] is None: | ||||
|                     results[0] = request | ||||
|  | ||||
|             self.app.request_middleware.appendleft(_collect_request) | ||||
|  | ||||
|         if uri.startswith( | ||||
|             ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") | ||||
|         ): | ||||
|             url = uri | ||||
|         else: | ||||
|             uri = uri if uri.startswith("/") else f"/{uri}" | ||||
|             scheme = "http" | ||||
|             url = f"{scheme}://{HOST}:{PORT}{uri}" | ||||
|  | ||||
|         @self.app.listener("after_server_start") | ||||
|         async def _collect_response(loop): | ||||
|             try: | ||||
|                 response = await self._local_request( | ||||
|                     method, url, *request_args, **request_kwargs | ||||
|                 ) | ||||
|                 results[-1] = response | ||||
|             except Exception as e2: | ||||
|                 exceptions.append(e2) | ||||
|  | ||||
|         if self._server is not None: | ||||
|             _server = self._server | ||||
|         else: | ||||
|             _server_co = self.app.create_server( | ||||
|                 host=HOST, debug=debug, port=PORT, **server_kwargs | ||||
|             ) | ||||
|  | ||||
|             server.trigger_events( | ||||
|                 self.app.listeners["before_server_start"], loop | ||||
|             ) | ||||
|  | ||||
|             try: | ||||
|                 loop._stopping = False | ||||
|                 _server = loop.run_until_complete(_server_co) | ||||
|             except Exception as e1: | ||||
|                 raise e1 | ||||
|             self._server = _server | ||||
|         server.trigger_events(self.app.listeners["after_server_start"], loop) | ||||
|         self.app.listeners["after_server_start"].pop() | ||||
|  | ||||
|         if exceptions: | ||||
|             raise ValueError(f"Exception during request: {exceptions}") | ||||
|  | ||||
|         if gather_request: | ||||
|             self.app.request_middleware.pop() | ||||
|             try: | ||||
|                 request, response = results | ||||
|                 return request, response | ||||
|             except Exception: | ||||
|                 raise ValueError( | ||||
|                     f"Request and response object expected, got ({results})" | ||||
|                 ) | ||||
|         else: | ||||
|             try: | ||||
|                 return results[-1] | ||||
|             except Exception: | ||||
|                 raise ValueError(f"Request object expected, got ({results})") | ||||
|  | ||||
|     def kill_server(self): | ||||
|         try: | ||||
|             if self._server: | ||||
|                 self._server.close() | ||||
|                 self._loop.run_until_complete(self._server.wait_closed()) | ||||
|                 self._server = None | ||||
|  | ||||
|             if self._session: | ||||
|                 self._loop.run_until_complete(self._session.aclose()) | ||||
|                 self._session = None | ||||
|  | ||||
|         except Exception as e3: | ||||
|             raise e3 | ||||
|  | ||||
|     # Copied from SanicTestClient, but with some changes to reuse the | ||||
|     # same TCPConnection and the sane ClientSession more than once. | ||||
|     # Note, you cannot use the same session if you are in a _different_ | ||||
|     # loop, so the changes above are required too. | ||||
|     async def _local_request(self, method, url, *args, **kwargs): | ||||
|         raw_cookies = kwargs.pop("raw_cookies", None) | ||||
|         request_keepalive = kwargs.pop( | ||||
|             "request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"] | ||||
|         ) | ||||
|         if not self._session: | ||||
|             self._session = self.get_new_session() | ||||
|         try: | ||||
|             response = await getattr(self._session, method.lower())( | ||||
|                 url, timeout=request_keepalive, *args, **kwargs | ||||
|             ) | ||||
|         except NameError: | ||||
|             raise Exception(response.status_code) | ||||
|  | ||||
|         try: | ||||
|             response.json = response.json() | ||||
|         except (JSONDecodeError, UnicodeDecodeError): | ||||
|             response.json = None | ||||
|  | ||||
|         response.body = await response.aread() | ||||
|         response.status = response.status_code | ||||
|         response.content_type = response.headers.get("content-type") | ||||
|  | ||||
|         if raw_cookies: | ||||
|             response.raw_cookies = {} | ||||
|             for cookie in response.cookies: | ||||
|                 response.raw_cookies[cookie.name] = cookie | ||||
|  | ||||
|         return response | ||||
|  | ||||
|  | ||||
| keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") | ||||
| keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") | ||||
| keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") | ||||
| @@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse(): | ||||
|     """If the server keep-alive timeout and client keep-alive timeout are | ||||
|     both longer than the delay, the client _and_ server will successfully | ||||
|     reuse the existing connection.""" | ||||
|     try: | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT) | ||||
|     with client: | ||||
|         headers = {"Connection": "keep-alive"} | ||||
|         request, response = client.get("/1", headers=headers) | ||||
|         assert response.status == 200 | ||||
|         assert response.text == "OK" | ||||
|         assert request.protocol.state["requests_count"] == 1 | ||||
|  | ||||
|         loop.run_until_complete(aio_sleep(1)) | ||||
|  | ||||
|         request, response = client.get("/1") | ||||
|         assert response.status == 200 | ||||
|         assert response.text == "OK" | ||||
|         assert ReusableSanicConnectionPool.last_reused_connection | ||||
|     finally: | ||||
|         client.kill_server() | ||||
|         assert request.protocol.state["requests_count"] == 2 | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif( | ||||
| @@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse(): | ||||
| def test_keep_alive_client_timeout(): | ||||
|     """If the server keep-alive timeout is longer than the client | ||||
|     keep-alive timeout, client will try to create a new connection here.""" | ||||
|     try: | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     client = ReusableClient( | ||||
|         keep_alive_app_client_timeout, loop=loop, port=PORT | ||||
|     ) | ||||
|     with client: | ||||
|         headers = {"Connection": "keep-alive"} | ||||
|         _, response = client.get("/1", headers=headers, request_keepalive=1) | ||||
|         request, response = client.get("/1", headers=headers, timeout=1) | ||||
|  | ||||
|         assert response.status == 200 | ||||
|         assert response.text == "OK" | ||||
|         assert request.protocol.state["requests_count"] == 1 | ||||
|  | ||||
|         loop.run_until_complete(aio_sleep(2)) | ||||
|         _, response = client.get("/1", request_keepalive=1) | ||||
|  | ||||
|         assert ReusableSanicConnectionPool.last_reused_connection is None | ||||
|     finally: | ||||
|         client.kill_server() | ||||
|         request, response = client.get("/1", timeout=1) | ||||
|         assert request.protocol.state["requests_count"] == 1 | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif( | ||||
| @@ -277,22 +117,23 @@ def test_keep_alive_server_timeout(): | ||||
|     keep-alive timeout, the client will either a 'Connection reset' error | ||||
|     _or_ a new connection. Depending on how the event-loop handles the | ||||
|     broken server connection.""" | ||||
|     try: | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     client = ReusableClient( | ||||
|         keep_alive_app_server_timeout, loop=loop, port=PORT | ||||
|     ) | ||||
|     with client: | ||||
|         headers = {"Connection": "keep-alive"} | ||||
|         _, response = client.get("/1", headers=headers, request_keepalive=60) | ||||
|         request, response = client.get("/1", headers=headers, timeout=60) | ||||
|  | ||||
|         assert response.status == 200 | ||||
|         assert response.text == "OK" | ||||
|         assert request.protocol.state["requests_count"] == 1 | ||||
|  | ||||
|         loop.run_until_complete(aio_sleep(3)) | ||||
|         _, response = client.get("/1", request_keepalive=60) | ||||
|         request, response = client.get("/1", timeout=60) | ||||
|  | ||||
|         assert ReusableSanicConnectionPool.last_reused_connection is None | ||||
|     finally: | ||||
|         client.kill_server() | ||||
|         assert request.protocol.state["requests_count"] == 1 | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif( | ||||
| @@ -300,10 +141,10 @@ def test_keep_alive_server_timeout(): | ||||
|     reason="Not testable with current client", | ||||
| ) | ||||
| def test_keep_alive_connection_context(): | ||||
|     try: | ||||
|         loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(loop) | ||||
|         client = ReuseableSanicTestClient(keep_alive_app_context, loop) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT) | ||||
|     with client: | ||||
|         headers = {"Connection": "keep-alive"} | ||||
|         request1, _ = client.post("/ctx", headers=headers) | ||||
|  | ||||
| @@ -315,5 +156,4 @@ def test_keep_alive_connection_context(): | ||||
|         assert ( | ||||
|             request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" | ||||
|         ) | ||||
|     finally: | ||||
|         client.kill_server() | ||||
|         assert request2.protocol.state["requests_count"] == 2 | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import uuid | ||||
|  | ||||
| from importlib import reload | ||||
| from io import StringIO | ||||
| from unittest.mock import Mock | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| @@ -51,7 +52,7 @@ def test_log(app): | ||||
|  | ||||
| def test_logging_defaults(): | ||||
|     # reset_logging() | ||||
|     app = Sanic("test_logging") | ||||
|     Sanic("test_logging") | ||||
|  | ||||
|     for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: | ||||
|         assert ( | ||||
| @@ -87,7 +88,7 @@ def test_logging_pass_customer_logconfig(): | ||||
|         "format" | ||||
|     ] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s" | ||||
|  | ||||
|     app = Sanic("test_logging", log_config=modified_config) | ||||
|     Sanic("test_logging", log_config=modified_config) | ||||
|  | ||||
|     for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: | ||||
|         assert fmt._fmt == modified_config["formatters"]["generic"]["format"] | ||||
| @@ -111,11 +112,13 @@ def test_logging_pass_customer_logconfig(): | ||||
|     ), | ||||
| ) | ||||
| def test_log_connection_lost(app, debug, monkeypatch): | ||||
|     """ Should not log Connection lost exception on non debug """ | ||||
|     """Should not log Connection lost exception on non debug""" | ||||
|     stream = StringIO() | ||||
|     error = logging.getLogger("sanic.error") | ||||
|     error.addHandler(logging.StreamHandler(stream)) | ||||
|     monkeypatch.setattr(sanic.server, "error_logger", error) | ||||
|     monkeypatch.setattr( | ||||
|         sanic.server.protocols.http_protocol, "error_logger", error | ||||
|     ) | ||||
|  | ||||
|     @app.route("/conn_lost") | ||||
|     async def conn_lost(request): | ||||
| @@ -208,6 +211,56 @@ def test_logging_modified_root_logger_config(): | ||||
|     modified_config = LOGGING_CONFIG_DEFAULTS | ||||
|     modified_config["loggers"]["sanic.root"]["level"] = "DEBUG" | ||||
|  | ||||
|     app = Sanic("test_logging", log_config=modified_config) | ||||
|     Sanic("test_logging", log_config=modified_config) | ||||
|  | ||||
|     assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG | ||||
|  | ||||
|  | ||||
| def test_access_log_client_ip_remote_addr(monkeypatch): | ||||
|     access = Mock() | ||||
|     monkeypatch.setattr(sanic.http, "access_logger", access) | ||||
|  | ||||
|     app = Sanic("test_logging") | ||||
|     app.config.PROXIES_COUNT = 2 | ||||
|  | ||||
|     @app.route("/") | ||||
|     async def handler(request): | ||||
|         return text(request.remote_addr) | ||||
|  | ||||
|     headers = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"} | ||||
|  | ||||
|     request, response = app.test_client.get("/", headers=headers) | ||||
|  | ||||
|     assert request.remote_addr == "1.1.1.1" | ||||
|     access.info.assert_called_with( | ||||
|         "", | ||||
|         extra={ | ||||
|             "status": 200, | ||||
|             "byte": len(response.content), | ||||
|             "host": f"{request.remote_addr}:{request.port}", | ||||
|             "request": f"GET {request.scheme}://{request.host}/", | ||||
|         }, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_access_log_client_ip_reqip(monkeypatch): | ||||
|     access = Mock() | ||||
|     monkeypatch.setattr(sanic.http, "access_logger", access) | ||||
|  | ||||
|     app = Sanic("test_logging") | ||||
|  | ||||
|     @app.route("/") | ||||
|     async def handler(request): | ||||
|         return text(request.ip) | ||||
|  | ||||
|     request, response = app.test_client.get("/") | ||||
|  | ||||
|     access.info.assert_called_with( | ||||
|         "", | ||||
|         extra={ | ||||
|             "status": 200, | ||||
|             "byte": len(response.content), | ||||
|             "host": f"{request.ip}:{request.port}", | ||||
|             "request": f"GET {request.scheme}://{request.host}/", | ||||
|         }, | ||||
|     ) | ||||
|   | ||||
| @@ -6,85 +6,37 @@ from sanic_testing.testing import PORT | ||||
| from sanic.config import BASE_LOGO | ||||
|  | ||||
|  | ||||
| def test_logo_base(app, caplog): | ||||
|     server = app.create_server( | ||||
|         debug=True, return_asyncio_server=True, port=PORT | ||||
|     ) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     loop._stopping = False | ||||
| def test_logo_base(app, run_startup): | ||||
|     logs = run_startup(app) | ||||
|  | ||||
|     with caplog.at_level(logging.DEBUG): | ||||
|         _server = loop.run_until_complete(server) | ||||
|  | ||||
|     _server.close() | ||||
|     loop.run_until_complete(_server.wait_closed()) | ||||
|     app.stop() | ||||
|  | ||||
|     assert caplog.record_tuples[0][1] == logging.DEBUG | ||||
|     assert caplog.record_tuples[0][2] == BASE_LOGO | ||||
|     assert logs[0][1] == logging.DEBUG | ||||
|     assert logs[0][2] == BASE_LOGO | ||||
|  | ||||
|  | ||||
| def test_logo_false(app, caplog): | ||||
| def test_logo_false(app, caplog, run_startup): | ||||
|     app.config.LOGO = False | ||||
|  | ||||
|     server = app.create_server( | ||||
|         debug=True, return_asyncio_server=True, port=PORT | ||||
|     ) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     loop._stopping = False | ||||
|     logs = run_startup(app) | ||||
|  | ||||
|     with caplog.at_level(logging.DEBUG): | ||||
|         _server = loop.run_until_complete(server) | ||||
|  | ||||
|     _server.close() | ||||
|     loop.run_until_complete(_server.wait_closed()) | ||||
|     app.stop() | ||||
|  | ||||
|     banner, port = caplog.record_tuples[0][2].rsplit(":", 1) | ||||
|     assert caplog.record_tuples[0][1] == logging.INFO | ||||
|     banner, port = logs[0][2].rsplit(":", 1) | ||||
|     assert logs[0][1] == logging.INFO | ||||
|     assert banner == "Goin' Fast @ http://127.0.0.1" | ||||
|     assert int(port) > 0 | ||||
|  | ||||
|  | ||||
| def test_logo_true(app, caplog): | ||||
| def test_logo_true(app, run_startup): | ||||
|     app.config.LOGO = True | ||||
|  | ||||
|     server = app.create_server( | ||||
|         debug=True, return_asyncio_server=True, port=PORT | ||||
|     ) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     loop._stopping = False | ||||
|     logs = run_startup(app) | ||||
|  | ||||
|     with caplog.at_level(logging.DEBUG): | ||||
|         _server = loop.run_until_complete(server) | ||||
|  | ||||
|     _server.close() | ||||
|     loop.run_until_complete(_server.wait_closed()) | ||||
|     app.stop() | ||||
|  | ||||
|     assert caplog.record_tuples[0][1] == logging.DEBUG | ||||
|     assert caplog.record_tuples[0][2] == BASE_LOGO | ||||
|     assert logs[0][1] == logging.DEBUG | ||||
|     assert logs[0][2] == BASE_LOGO | ||||
|  | ||||
|  | ||||
| def test_logo_custom(app, caplog): | ||||
| def test_logo_custom(app, run_startup): | ||||
|     app.config.LOGO = "My Custom Logo" | ||||
|  | ||||
|     server = app.create_server( | ||||
|         debug=True, return_asyncio_server=True, port=PORT | ||||
|     ) | ||||
|     loop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     loop._stopping = False | ||||
|     logs = run_startup(app) | ||||
|  | ||||
|     with caplog.at_level(logging.DEBUG): | ||||
|         _server = loop.run_until_complete(server) | ||||
|  | ||||
|     _server.close() | ||||
|     loop.run_until_complete(_server.wait_closed()) | ||||
|     app.stop() | ||||
|  | ||||
|     assert caplog.record_tuples[0][1] == logging.DEBUG | ||||
|     assert caplog.record_tuples[0][2] == "My Custom Logo" | ||||
|     assert logs[0][1] == logging.DEBUG | ||||
|     assert logs[0][2] == "My Custom Logo" | ||||
|   | ||||
| @@ -5,7 +5,7 @@ from itertools import count | ||||
|  | ||||
| from sanic.exceptions import NotFound | ||||
| from sanic.request import Request | ||||
| from sanic.response import HTTPResponse, text | ||||
| from sanic.response import HTTPResponse, json, text | ||||
|  | ||||
|  | ||||
| # ------------------------------------------------------------ # | ||||
| @@ -37,14 +37,19 @@ def test_middleware_request_as_convenience(app): | ||||
|     async def handler1(request): | ||||
|         results.append(request) | ||||
|  | ||||
|     @app.route("/") | ||||
|     @app.on_request() | ||||
|     async def handler2(request): | ||||
|         results.append(request) | ||||
|  | ||||
|     @app.route("/") | ||||
|     async def handler3(request): | ||||
|         return text("OK") | ||||
|  | ||||
|     request, response = app.test_client.get("/") | ||||
|  | ||||
|     assert response.text == "OK" | ||||
|     assert type(results[0]) is Request | ||||
|     assert type(results[1]) is Request | ||||
|  | ||||
|  | ||||
| def test_middleware_response(app): | ||||
| @@ -79,7 +84,12 @@ def test_middleware_response_as_convenience(app): | ||||
|         results.append(request) | ||||
|  | ||||
|     @app.on_response | ||||
|     async def process_response(request, response): | ||||
|     async def process_response_1(request, response): | ||||
|         results.append(request) | ||||
|         results.append(response) | ||||
|  | ||||
|     @app.on_response() | ||||
|     async def process_response_2(request, response): | ||||
|         results.append(request) | ||||
|         results.append(response) | ||||
|  | ||||
| @@ -93,6 +103,8 @@ def test_middleware_response_as_convenience(app): | ||||
|     assert type(results[0]) is Request | ||||
|     assert type(results[1]) is Request | ||||
|     assert isinstance(results[2], HTTPResponse) | ||||
|     assert type(results[3]) is Request | ||||
|     assert isinstance(results[4], HTTPResponse) | ||||
|  | ||||
|  | ||||
| def test_middleware_response_as_convenience_called(app): | ||||
| @@ -271,3 +283,17 @@ def test_request_middleware_executes_once(app): | ||||
|  | ||||
|     request, response = app.test_client.get("/") | ||||
|     assert next(i) == 3 | ||||
|  | ||||
|  | ||||
| def test_middleware_added_response(app): | ||||
|     @app.on_response | ||||
|     def display(_, response): | ||||
|         response["foo"] = "bar" | ||||
|         return json(response) | ||||
|  | ||||
|     @app.get("/") | ||||
|     async def handler(request): | ||||
|         return {} | ||||
|  | ||||
|     _, response = app.test_client.get("/") | ||||
|     assert response.json["foo"] == "bar" | ||||
|   | ||||
| @@ -140,3 +140,39 @@ def test_ipv6_address_is_not_wrapped(app): | ||||
|     assert resp.json["client"] == "[::1]" | ||||
|     assert resp.json["client_ip"] == "::1" | ||||
|     assert request.ip == "::1" | ||||
|  | ||||
|  | ||||
| def test_request_accept(): | ||||
|     app = Sanic("req-generator") | ||||
|  | ||||
|     @app.get("/") | ||||
|     async def get(request): | ||||
|         return response.empty() | ||||
|  | ||||
|     request, _ = app.test_client.get( | ||||
|         "/", | ||||
|         headers={ | ||||
|             "Accept": "text/*, text/plain, text/plain;format=flowed, */*" | ||||
|         }, | ||||
|     ) | ||||
|     assert request.accept == [ | ||||
|         "text/plain;format=flowed", | ||||
|         "text/plain", | ||||
|         "text/*", | ||||
|         "*/*", | ||||
|     ] | ||||
|  | ||||
|     request, _ = app.test_client.get( | ||||
|         "/", | ||||
|         headers={ | ||||
|             "Accept": ( | ||||
|                 "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c" | ||||
|             ) | ||||
|         }, | ||||
|     ) | ||||
|     assert request.accept == [ | ||||
|         "text/html", | ||||
|         "text/x-c", | ||||
|         "text/x-dvi; q=0.8", | ||||
|         "text/plain; q=0.5", | ||||
|     ] | ||||
|   | ||||
| @@ -1,99 +0,0 @@ | ||||
| import asyncio | ||||
|  | ||||
| import httpcore | ||||
| import httpx | ||||
|  | ||||
| from sanic_testing.testing import SanicTestClient | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.response import text | ||||
|  | ||||
|  | ||||
| class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection): | ||||
|     async def arequest(self, *args, **kwargs): | ||||
|         await asyncio.sleep(2) | ||||
|         return await super().arequest(*args, **kwargs) | ||||
|  | ||||
|     async def _open_socket(self, *args, **kwargs): | ||||
|         retval = await super()._open_socket(*args, **kwargs) | ||||
|         if self._request_delay: | ||||
|             await asyncio.sleep(self._request_delay) | ||||
|         return retval | ||||
|  | ||||
|  | ||||
| class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool): | ||||
|     def __init__(self, request_delay=None, *args, **kwargs): | ||||
|         self._request_delay = request_delay | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
|     async def _add_to_pool(self, connection, timeout): | ||||
|         connection.__class__ = DelayableHTTPConnection | ||||
|         connection._request_delay = self._request_delay | ||||
|         await super()._add_to_pool(connection, timeout) | ||||
|  | ||||
|  | ||||
| class DelayableSanicSession(httpx.AsyncClient): | ||||
|     def __init__(self, request_delay=None, *args, **kwargs) -> None: | ||||
|         transport = DelayableSanicConnectionPool(request_delay=request_delay) | ||||
|         super().__init__(transport=transport, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class DelayableSanicTestClient(SanicTestClient): | ||||
|     def __init__(self, app, request_delay=None): | ||||
|         super().__init__(app) | ||||
|         self._request_delay = request_delay | ||||
|         self._loop = None | ||||
|  | ||||
|     def get_new_session(self): | ||||
|         return DelayableSanicSession(request_delay=self._request_delay) | ||||
|  | ||||
|  | ||||
| request_timeout_default_app = Sanic("test_request_timeout_default") | ||||
| request_no_timeout_app = Sanic("test_request_no_timeout") | ||||
| request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6 | ||||
| request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6 | ||||
|  | ||||
|  | ||||
| @request_timeout_default_app.route("/1") | ||||
| async def handler1(request): | ||||
|     return text("OK") | ||||
|  | ||||
|  | ||||
| @request_no_timeout_app.route("/1") | ||||
| async def handler2(request): | ||||
|     return text("OK") | ||||
|  | ||||
|  | ||||
| @request_timeout_default_app.websocket("/ws1") | ||||
| async def ws_handler1(request, ws): | ||||
|     await ws.send("OK") | ||||
|  | ||||
|  | ||||
| def test_default_server_error_request_timeout(): | ||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) | ||||
|     request, response = client.get("/1") | ||||
|     assert response.status == 408 | ||||
|     assert "Request Timeout" in response.text | ||||
|  | ||||
|  | ||||
| def test_default_server_error_request_dont_timeout(): | ||||
|     client = DelayableSanicTestClient(request_no_timeout_app, 0.2) | ||||
|     request, response = client.get("/1") | ||||
|     assert response.status == 200 | ||||
|     assert response.text == "OK" | ||||
|  | ||||
|  | ||||
| def test_default_server_error_websocket_request_timeout(): | ||||
|  | ||||
|     headers = { | ||||
|         "Upgrade": "websocket", | ||||
|         "Connection": "upgrade", | ||||
|         "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", | ||||
|         "Sec-WebSocket-Version": "13", | ||||
|     } | ||||
|  | ||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) | ||||
|     request, response = client.get("/ws1", headers=headers) | ||||
|  | ||||
|     assert response.status == 408 | ||||
|     assert "Request Timeout" in response.text | ||||
| @@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app): | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.parametrize("url", ["/ws", "ws"]) | ||||
| async def test_websocket_route_asgi(app, url): | ||||
|     ev = asyncio.Event() | ||||
|     @app.after_server_start | ||||
|     async def setup_ev(app, _): | ||||
|         app.ctx.ev = asyncio.Event() | ||||
|  | ||||
|     @app.websocket(url) | ||||
|     async def handler(request, ws): | ||||
|         ev.set() | ||||
|         request.app.ctx.ev.set() | ||||
|  | ||||
|     request, response = await app.asgi_client.websocket(url) | ||||
|     assert ev.is_set() | ||||
|     @app.get("/ev") | ||||
|     async def check(request): | ||||
|         return json({"set": request.app.ctx.ev.is_set()}) | ||||
|  | ||||
|     _, response = await app.asgi_client.websocket(url) | ||||
|     _, response = await app.asgi_client.get("/") | ||||
|     assert response.json["set"] | ||||
|  | ||||
|  | ||||
| def test_websocket_route_with_subprotocols(app): | ||||
| @pytest.mark.parametrize( | ||||
|     "subprotocols,expected", | ||||
|     ( | ||||
|         (["one"], "one"), | ||||
|         (["three", "one"], "one"), | ||||
|         (["tree"], None), | ||||
|         (None, None), | ||||
|     ), | ||||
| ) | ||||
| def test_websocket_route_with_subprotocols(app, subprotocols, expected): | ||||
|     results = [] | ||||
|  | ||||
|     @app.websocket("/ws", subprotocols=["foo", "bar"]) | ||||
|     @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) | ||||
|     async def handler(request, ws): | ||||
|         results.append(ws.subprotocol) | ||||
|         nonlocal results | ||||
|         results = ws.subprotocol | ||||
|         assert ws.subprotocol is not None | ||||
|  | ||||
|     _, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"]) | ||||
|     assert response.opened is True | ||||
|     assert results == ["bar"] | ||||
|  | ||||
|     _, response = SanicTestClient(app).websocket( | ||||
|         "/ws", subprotocols=["bar", "foo"] | ||||
|         "/ws", subprotocols=subprotocols | ||||
|     ) | ||||
|     assert response.opened is True | ||||
|     assert results == ["bar", "bar"] | ||||
|  | ||||
|     _, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"]) | ||||
|     assert response.opened is True | ||||
|     assert results == ["bar", "bar", None] | ||||
|  | ||||
|     _, response = SanicTestClient(app).websocket("/ws") | ||||
|     assert response.opened is True | ||||
|     assert results == ["bar", "bar", None, None] | ||||
|     assert results == expected | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("strict_slashes", [True, False, None]) | ||||
|   | ||||
| @@ -8,7 +8,7 @@ import pytest | ||||
|  | ||||
| from sanic_testing.testing import HOST, PORT | ||||
|  | ||||
| from sanic.exceptions import InvalidUsage | ||||
| from sanic.exceptions import InvalidUsage, SanicException | ||||
|  | ||||
|  | ||||
| AVAILABLE_LISTENERS = [ | ||||
| @@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app): | ||||
|     async def init_db(app, loop): | ||||
|         app.db = MySanicDb() | ||||
|  | ||||
|     await app.create_server(debug=True, return_asyncio_server=True, port=PORT) | ||||
|     srv = await app.create_server( | ||||
|         debug=True, return_asyncio_server=True, port=PORT | ||||
|     ) | ||||
|     await srv.startup() | ||||
|     await srv.before_start() | ||||
|  | ||||
|     assert hasattr(app, "db") | ||||
|     assert isinstance(app.db, MySanicDb) | ||||
| @@ -157,14 +161,15 @@ def test_create_server_trigger_events(app): | ||||
|         serv_coro = app.create_server(return_asyncio_server=True, sock=sock) | ||||
|         serv_task = asyncio.ensure_future(serv_coro, loop=loop) | ||||
|         server = loop.run_until_complete(serv_task) | ||||
|         server.after_start() | ||||
|         loop.run_until_complete(server.startup()) | ||||
|         loop.run_until_complete(server.after_start()) | ||||
|         try: | ||||
|             loop.run_forever() | ||||
|         except KeyboardInterrupt as e: | ||||
|         except KeyboardInterrupt: | ||||
|             loop.stop() | ||||
|         finally: | ||||
|             # Run the on_stop function if provided | ||||
|             server.before_stop() | ||||
|             loop.run_until_complete(server.before_stop()) | ||||
|  | ||||
|             # Wait for server to close | ||||
|             close_task = server.close() | ||||
| @@ -174,5 +179,19 @@ def test_create_server_trigger_events(app): | ||||
|             signal.stopped = True | ||||
|             for connection in server.connections: | ||||
|                 connection.close_if_idle() | ||||
|             server.after_stop() | ||||
|             loop.run_until_complete(server.after_stop()) | ||||
|         assert flag1 and flag2 and flag3 | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_missing_startup_raises_exception(app): | ||||
|     @app.listener("before_server_start") | ||||
|     async def init_db(app, loop): | ||||
|         ... | ||||
|  | ||||
|     srv = await app.create_server( | ||||
|         debug=True, return_asyncio_server=True, port=PORT | ||||
|     ) | ||||
|  | ||||
|     with pytest.raises(SanicException): | ||||
|         await srv.before_start() | ||||
|   | ||||
| @@ -95,7 +95,7 @@ def test_windows_workaround(): | ||||
|         os.kill(os.getpid(), signal.SIGINT) | ||||
|         await asyncio.sleep(0.2) | ||||
|         assert app.is_stopping | ||||
|         assert app.stay_active_task.result() == None | ||||
|         assert app.stay_active_task.result() is None | ||||
|         # Second Ctrl+C should raise | ||||
|         with pytest.raises(KeyboardInterrupt): | ||||
|             os.kill(os.getpid(), signal.SIGINT) | ||||
|   | ||||
| @@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app): | ||||
|  | ||||
|     app.signal_router.finalize() | ||||
|  | ||||
|     assert len(app.signal_router.routes) == 3 | ||||
|     await app.dispatch("foo.bar.baz") | ||||
|     assert counter == 2 | ||||
|  | ||||
| @@ -331,7 +332,8 @@ def test_event_on_bp_not_registered(): | ||||
|     "event,expected", | ||||
|     ( | ||||
|         ("foo.bar.baz", True), | ||||
|         ("server.init.before", False), | ||||
|         ("server.init.before", True), | ||||
|         ("server.init.somethingelse", False), | ||||
|         ("http.request.start", False), | ||||
|         ("sanic.notice.anything", True), | ||||
|     ), | ||||
|   | ||||
| @@ -461,6 +461,22 @@ def test_nested_dir(app, static_file_directory): | ||||
|     assert response.text == "foo\n" | ||||
|  | ||||
|  | ||||
| def test_handle_is_a_directory_error(app, static_file_directory): | ||||
|     error_text = "Is a directory. Access denied" | ||||
|     app.static("/static", static_file_directory) | ||||
|  | ||||
|     @app.exception(Exception) | ||||
|     async def handleStaticDirError(request, exception): | ||||
|         if isinstance(exception, IsADirectoryError): | ||||
|             return text(error_text, status=403) | ||||
|         raise exception | ||||
|  | ||||
|     request, response = app.test_client.get("/static/") | ||||
|  | ||||
|     assert response.status == 403 | ||||
|     assert response.text == error_text | ||||
|  | ||||
|  | ||||
| def test_stack_trace_on_not_found(app, static_file_directory, caplog): | ||||
|     app.static("/static", static_file_directory) | ||||
|  | ||||
| @@ -507,3 +523,56 @@ def test_multiple_statics(app, static_file_directory): | ||||
|     assert response.body == get_file_content( | ||||
|         static_file_directory, "python.png" | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_resource_type_default(app, static_file_directory): | ||||
|     app.static("/static", static_file_directory) | ||||
|     app.static("/file", get_file_path(static_file_directory, "test.file")) | ||||
|  | ||||
|     _, response = app.test_client.get("/static") | ||||
|     assert response.status == 404 | ||||
|  | ||||
|     _, response = app.test_client.get("/file") | ||||
|     assert response.status == 200 | ||||
|     assert response.body == get_file_content( | ||||
|         static_file_directory, "test.file" | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_resource_type_file(app, static_file_directory): | ||||
|     app.static( | ||||
|         "/file", | ||||
|         get_file_path(static_file_directory, "test.file"), | ||||
|         resource_type="file", | ||||
|     ) | ||||
|  | ||||
|     _, response = app.test_client.get("/file") | ||||
|     assert response.status == 200 | ||||
|     assert response.body == get_file_content( | ||||
|         static_file_directory, "test.file" | ||||
|     ) | ||||
|  | ||||
|     with pytest.raises(TypeError): | ||||
|         app.static("/static", static_file_directory, resource_type="file") | ||||
|  | ||||
|  | ||||
| def test_resource_type_dir(app, static_file_directory): | ||||
|     app.static("/static", static_file_directory, resource_type="dir") | ||||
|  | ||||
|     _, response = app.test_client.get("/static/test.file") | ||||
|     assert response.status == 200 | ||||
|     assert response.body == get_file_content( | ||||
|         static_file_directory, "test.file" | ||||
|     ) | ||||
|  | ||||
|     with pytest.raises(TypeError): | ||||
|         app.static( | ||||
|             "/file", | ||||
|             get_file_path(static_file_directory, "test.file"), | ||||
|             resource_type="dir", | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def test_resource_type_unknown(app, static_file_directory, caplog): | ||||
|     with pytest.raises(ValueError): | ||||
|         app.static("/static", static_file_directory, resource_type="unknown") | ||||
|   | ||||
| @@ -26,6 +26,7 @@ def protocol(app, mock_transport): | ||||
|     protocol = HttpProtocol(loop=loop, app=app) | ||||
|     protocol.connection_made(mock_transport) | ||||
|     protocol._setup_connection() | ||||
|     protocol._http.init_for_request() | ||||
|     protocol._task = Mock(spec=asyncio.Task) | ||||
|     protocol._task.cancel = Mock() | ||||
|     return protocol | ||||
|   | ||||
							
								
								
									
										21
									
								
								tests/test_touchup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								tests/test_touchup.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| import logging | ||||
|  | ||||
| from sanic.signals import RESERVED_NAMESPACES | ||||
| from sanic.touchup import TouchUp | ||||
|  | ||||
|  | ||||
| def test_touchup_methods(app): | ||||
|     assert len(TouchUp._registry) == 9 | ||||
|  | ||||
|  | ||||
| async def test_ode_removes_dispatch_events(app, caplog): | ||||
|     with caplog.at_level(logging.DEBUG, logger="sanic.root"): | ||||
|         await app._startup() | ||||
|     logs = caplog.record_tuples | ||||
|  | ||||
|     for signal in RESERVED_NAMESPACES["http"]: | ||||
|         assert ( | ||||
|             "sanic.root", | ||||
|             logging.DEBUG, | ||||
|             f"Disabling event: {signal}", | ||||
|         ) in logs | ||||
| @@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_websocket_bp_route_name(app): | ||||
| @pytest.mark.parametrize( | ||||
|     "name,expected", | ||||
|     ( | ||||
|         ("test_route", "/bp/route"), | ||||
|         ("test_route2", "/bp/route2"), | ||||
|         ("foobar_3", "/bp/route3"), | ||||
|     ), | ||||
| ) | ||||
| def test_websocket_bp_route_name(app, name, expected): | ||||
|     """Tests that blueprint websocket route is named.""" | ||||
|     event = asyncio.Event() | ||||
|     bp = Blueprint("test_bp", url_prefix="/bp") | ||||
| @@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app): | ||||
|     uri = app.url_for("test_bp.main") | ||||
|     assert uri == "/bp/main" | ||||
|  | ||||
|     uri = app.url_for("test_bp.test_route") | ||||
|     assert uri == "/bp/route" | ||||
|     uri = app.url_for(f"test_bp.{name}") | ||||
|     assert uri == expected | ||||
|     request, response = SanicTestClient(app).websocket(uri) | ||||
|     assert response.opened is True | ||||
|     assert event.is_set() | ||||
|  | ||||
|     event.clear() | ||||
|     uri = app.url_for("test_bp.test_route2") | ||||
|     assert uri == "/bp/route2" | ||||
|     request, response = SanicTestClient(app).websocket(uri) | ||||
|     assert response.opened is True | ||||
|     assert event.is_set() | ||||
|  | ||||
|     uri = app.url_for("test_bp.foobar_3") | ||||
|     assert uri == "/bp/route3" | ||||
|  | ||||
|  | ||||
| # TODO: add test with a route with multiple hosts | ||||
| # TODO: add test with a route with _host in url_for | ||||
|   | ||||
| @@ -175,7 +175,7 @@ def test_worker_close(worker): | ||||
|     worker.wsgi = mock.Mock() | ||||
|     conn = mock.Mock() | ||||
|     conn.websocket = mock.Mock() | ||||
|     conn.websocket.close_connection = mock.Mock(wraps=_a_noop) | ||||
|     conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) | ||||
|     worker.connections = set([conn]) | ||||
|     worker.log = mock.Mock() | ||||
|     worker.loop = loop | ||||
| @@ -190,5 +190,5 @@ def test_worker_close(worker): | ||||
|     loop.run_until_complete(_close) | ||||
|  | ||||
|     assert worker.signal.stopped | ||||
|     assert conn.websocket.close_connection.called | ||||
|     assert conn.websocket.fail_connection.called | ||||
|     assert len(worker.servers) == 0 | ||||
|   | ||||
							
								
								
									
										55
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										55
									
								
								tox.ini
									
									
									
									
									
								
							| @@ -2,53 +2,28 @@ | ||||
| envlist = py37, py38, py39, pyNightly, pypy37, {py37,py38,py39,pyNightly,pypy37}-no-ext, lint, check, security, docs, type-checking | ||||
|  | ||||
| [testenv] | ||||
| usedevelop = True | ||||
| usedevelop = true | ||||
| setenv = | ||||
|     {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UJSON=1 | ||||
|     {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 | ||||
| deps = | ||||
|     sanic-testing>=0.6.0 | ||||
|     coverage==5.3 | ||||
|     pytest==5.2.1 | ||||
|     pytest-cov | ||||
|     pytest-sanic | ||||
|     pytest-sugar | ||||
|     pytest-benchmark | ||||
|     chardet==3.* | ||||
|     beautifulsoup4 | ||||
|     gunicorn==20.0.4 | ||||
|     uvicorn | ||||
|     websockets>=9.0 | ||||
| extras = test | ||||
| commands = | ||||
|     pytest {posargs:tests --cov sanic} | ||||
|     - coverage combine --append | ||||
|     coverage report -m | ||||
|     coverage report -m -i | ||||
|     coverage html -i | ||||
|  | ||||
| [testenv:lint] | ||||
| deps = | ||||
|     flake8 | ||||
|     black | ||||
|     isort>=5.0.0 | ||||
|     bandit | ||||
|  | ||||
| commands = | ||||
|     flake8 sanic | ||||
|     black --config ./.black.toml --check --verbose sanic/ | ||||
|     isort --check-only sanic --profile=black | ||||
|  | ||||
| [testenv:type-checking] | ||||
| deps = | ||||
|     mypy>=0.901 | ||||
|     types-ujson | ||||
|  | ||||
| commands = | ||||
|     mypy sanic | ||||
|  | ||||
| [testenv:check] | ||||
| deps = | ||||
|     docutils | ||||
|     pygments | ||||
| commands = | ||||
|     python setup.py check -r -s | ||||
|  | ||||
| @@ -60,8 +35,6 @@ markers = | ||||
|     asyncio | ||||
|  | ||||
| [testenv:security] | ||||
| deps = | ||||
|     bandit | ||||
|  | ||||
| commands = | ||||
|     bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py | ||||
| @@ -69,30 +42,10 @@ commands = | ||||
| [testenv:docs] | ||||
| platform = linux|linux2|darwin | ||||
| whitelist_externals = make | ||||
| deps = | ||||
|     sphinx>=2.1.2 | ||||
|     sphinx_rtd_theme>=0.4.3 | ||||
|     recommonmark>=0.5.0 | ||||
|     docutils | ||||
|     pygments | ||||
|     gunicorn==20.0.4 | ||||
| extras = docs | ||||
| commands = | ||||
|     make docs-test | ||||
|  | ||||
| [testenv:coverage] | ||||
| usedevelop = True | ||||
| deps = | ||||
|     sanic-testing>=0.6.0 | ||||
|     coverage==5.3 | ||||
|     pytest==5.2.1 | ||||
|     pytest-cov | ||||
|     pytest-sanic | ||||
|     pytest-sugar | ||||
|     pytest-benchmark | ||||
|     chardet==3.* | ||||
|     beautifulsoup4 | ||||
|     gunicorn==20.0.4 | ||||
|     uvicorn | ||||
|     websockets>=9.0 | ||||
| commands = | ||||
|     pytest tests --cov=./sanic --cov-report=xml | ||||
|   | ||||
		Reference in New Issue
	
	Block a user