Compare commits
	
		
			52 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 5e12edbc38 | ||
|   | 50a606adee | ||
|   | f995612073 | ||
|   | bc08383acd | ||
|   | b83a1a184c | ||
|   | 59dd6814f8 | ||
|   | f7abf3db1b | ||
|   | cf1d2148ac | ||
|   | b5f2bd9b0e | ||
|   | ba2670e99c | ||
|   | 6ffc4d9756 | ||
|   | 595d2c76ac | ||
|   | d9796e9b1e | ||
|   | 404c5f9f9e | ||
|   | a937e08ef0 | ||
|   | ef4f058a6c | ||
|   | 69c5dde9bf | ||
|   | 945885d501 | ||
|   | 9d0b54c90d | ||
|   | 2e5c288fea | ||
|   | f32ef20b74 | ||
|   | e2eefaac55 | ||
|   | e1cfbf0fd9 | ||
|   | 08c5689441 | ||
|   | 8dbda247d6 | ||
|   | 71a631237d | ||
|   | e22ff3828b | ||
|   | b1b12e004e | ||
|   | 5308fec354 | ||
|   | 0ba57d4701 | ||
|   | 54ca6a6178 | ||
|   | 7dd4a78cf2 | ||
|   | 52ff49512a | ||
|   | 5a48b94089 | ||
|   | ba1c73d947 | ||
|   | 4732b6bdfa | ||
|   | a6e78b70ab | ||
|   | bb1174afc5 | ||
|   | df8abe9cfd | ||
|   | c3bca97ee1 | ||
|   | c3b6fa1bba | ||
|   | 94d496afe1 | ||
|   | 7b7a572f9b | ||
|   | 1b8cb742f9 | ||
|   | 3492d180a8 | ||
|   | 021da38373 | ||
|   | ac784759d5 | ||
|   | 36eda2cd62 | ||
|   | 08a4b3013f | ||
|   | 1dd0332e8b | ||
|   | a90877ac31 | ||
|   | 8b7ea27a48 | 
| @@ -10,3 +10,15 @@ exclude_patterns: | |||||||
|   - "examples/" |   - "examples/" | ||||||
|   - "hack/" |   - "hack/" | ||||||
|   - "scripts/" |   - "scripts/" | ||||||
|  |   - "tests/" | ||||||
|  | checks: | ||||||
|  |   argument-count: | ||||||
|  |     enabled: false | ||||||
|  |   file-lines: | ||||||
|  |     config: | ||||||
|  |       threshold: 1000 | ||||||
|  |   method-count: | ||||||
|  |     config: | ||||||
|  |       threshold: 40 | ||||||
|  |   complex-logic: | ||||||
|  |     enabled: false | ||||||
|   | |||||||
							
								
								
									
										62
									
								
								.github/workflows/pr-windows.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/pr-windows.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,34 +1,34 @@ | |||||||
| # name: Run Unit Tests on Windows | name: Run Unit Tests on Windows | ||||||
| # on: | on: | ||||||
| #   pull_request: |   pull_request: | ||||||
| #     branches: |     branches: | ||||||
| #       - main |       - main | ||||||
|  |  | ||||||
| # jobs: | jobs: | ||||||
| #   testsOnWindows: |   testsOnWindows: | ||||||
| #     name: ut-${{ matrix.config.tox-env }} |     name: ut-${{ matrix.config.tox-env }} | ||||||
| #     runs-on: windows-latest |     runs-on: windows-latest | ||||||
| #     strategy: |     strategy: | ||||||
| #       fail-fast: false |       fail-fast: false | ||||||
| #       matrix: |       matrix: | ||||||
| #         config: |         config: | ||||||
| #           - { python-version: 3.7, tox-env: py37-no-ext } |           - { python-version: 3.7, tox-env: py37-no-ext } | ||||||
| #           - { python-version: 3.8, tox-env: py38-no-ext } |           - { python-version: 3.8, tox-env: py38-no-ext } | ||||||
| #           - { python-version: 3.9, tox-env: py39-no-ext } |           - { python-version: 3.9, tox-env: py39-no-ext } | ||||||
| #           - { python-version: pypy-3.7, tox-env: pypy37-no-ext } |           - { python-version: pypy-3.7, tox-env: pypy37-no-ext } | ||||||
|  |  | ||||||
| #     steps: |     steps: | ||||||
| #       - name: Checkout Repository |       - name: Checkout Repository | ||||||
| #         uses: actions/checkout@v2 |         uses: actions/checkout@v2 | ||||||
|  |  | ||||||
| #       - name: Run Unit Tests |       - name: Run Unit Tests | ||||||
| #         uses: ahopkins/custom-actions@pip-extra-args |         uses: ahopkins/custom-actions@pip-extra-args | ||||||
| #         with: |         with: | ||||||
| #           python-version: ${{ matrix.config.python-version }} |           python-version: ${{ matrix.config.python-version }} | ||||||
| #           test-infra-tool: tox |           test-infra-tool: tox | ||||||
| #           test-infra-version: latest |           test-infra-version: latest | ||||||
| #           action: tests |           action: tests | ||||||
| #           test-additional-args: "-e=${{ matrix.config.tox-env }}" |           test-additional-args: "-e=${{ matrix.config.tox-env }}" | ||||||
| #           experimental-ignore-error: "true" |           experimental-ignore-error: "true" | ||||||
| #           command-timeout: "600000" |           command-timeout: "600000" | ||||||
| #           pip-extra-args: "--user" |           pip-extra-args: "--user" | ||||||
|   | |||||||
| @@ -1,3 +1,22 @@ | |||||||
|  | .. note:: | ||||||
|  |  | ||||||
|  |   From v21.9, CHANGELOG files are maintained in ``./docs/sanic/releases`` | ||||||
|  |  | ||||||
|  | Version 21.6.1 | ||||||
|  | -------------- | ||||||
|  |  | ||||||
|  | Bugfixes | ||||||
|  | ******** | ||||||
|  |  | ||||||
|  |   * `#2178 <https://github.com/sanic-org/sanic/pull/2178>`_ | ||||||
|  |     Update sanic-routing to allow for better splitting of complex URI templates | ||||||
|  |   * `#2183 <https://github.com/sanic-org/sanic/pull/2183>`_ | ||||||
|  |     Proper handling of chunked request bodies to resolve phantom 503 in logs | ||||||
|  |   * `#2181 <https://github.com/sanic-org/sanic/pull/2181>`_ | ||||||
|  |     Resolve regression in exception logging | ||||||
|  |   * `#2201 <https://github.com/sanic-org/sanic/pull/2201>`_ | ||||||
|  |     Cleanup request info in pipelined requests | ||||||
|  |  | ||||||
| Version 21.6.0 | Version 21.6.0 | ||||||
| -------------- | -------------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -19,7 +19,7 @@ a virtual environment already set up, then run: | |||||||
|  |  | ||||||
| .. code-block:: bash | .. code-block:: bash | ||||||
|  |  | ||||||
|    pip3 install -e . ".[dev]" |    pip install -e ".[dev]" | ||||||
|  |  | ||||||
| Dependency Changes | Dependency Changes | ||||||
| ------------------ | ------------------ | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								README.rst
									
									
									
									
									
								
							| @@ -77,17 +77,7 @@ The goal of the project is to provide a simple way to get up and running a highl | |||||||
| Sponsor | Sponsor | ||||||
| ------- | ------- | ||||||
|  |  | ||||||
| |Try CodeStream| | Check out `open collective <https://opencollective.com/sanic-org>`_ to learn more about helping to fund Sanic. | ||||||
|  |  | ||||||
| .. |Try CodeStream| image:: https://alt-images.codestream.com/codestream_logo_sanicorg.png |  | ||||||
|    :target: https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner |  | ||||||
|    :alt: Try CodeStream |  | ||||||
|  |  | ||||||
| Manage pull requests and conduct code reviews in your IDE with full source-tree context. Comment on any line, not just the diffs. Use jump-to-definition, your favorite keybindings, and code intelligence with more of your workflow. |  | ||||||
|  |  | ||||||
| `Learn More <https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner>`_ |  | ||||||
|  |  | ||||||
| Thank you to our sponsor. Check out `open collective <https://opencollective.com/sanic-org>`_ to learn more about helping to fund Sanic. |  | ||||||
|  |  | ||||||
| Installation | Installation | ||||||
| ------------ | ------------ | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								docs/conf.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								docs/conf.py
									
									
									
									
									
								
							| @@ -10,10 +10,8 @@ | |||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
|  |  | ||||||
| # Add support for auto-doc |  | ||||||
| import recommonmark |  | ||||||
|  |  | ||||||
| from recommonmark.transform import AutoStructify | # Add support for auto-doc | ||||||
|  |  | ||||||
|  |  | ||||||
| # Ensure that sanic is present in the path, to allow sphinx-apidoc to | # Ensure that sanic is present in the path, to allow sphinx-apidoc to | ||||||
| @@ -26,7 +24,7 @@ import sanic | |||||||
|  |  | ||||||
| # -- General configuration ------------------------------------------------ | # -- General configuration ------------------------------------------------ | ||||||
|  |  | ||||||
| extensions = ["sphinx.ext.autodoc", "recommonmark"] | extensions = ["sphinx.ext.autodoc", "m2r2"] | ||||||
|  |  | ||||||
| templates_path = ["_templates"] | templates_path = ["_templates"] | ||||||
|  |  | ||||||
| @@ -162,20 +160,6 @@ autodoc_default_options = { | |||||||
|     "member-order": "groupwise", |     "member-order": "groupwise", | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| # app setup hook |  | ||||||
| def setup(app): |  | ||||||
|     app.add_config_value( |  | ||||||
|         "recommonmark_config", |  | ||||||
|         { |  | ||||||
|             "enable_eval_rst": True, |  | ||||||
|             "enable_auto_doc_ref": False, |  | ||||||
|         }, |  | ||||||
|         True, |  | ||||||
|     ) |  | ||||||
|     app.add_transform(AutoStructify) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| html_theme_options = { | html_theme_options = { | ||||||
|     "style_external_links": False, |     "style_external_links": False, | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| 📜 Changelog | 📜 Changelog | ||||||
| ============ | ============ | ||||||
|  |  | ||||||
|  | .. mdinclude:: ./releases/21.9.md | ||||||
|  |  | ||||||
| .. include:: ../../CHANGELOG.rst | .. include:: ../../CHANGELOG.rst | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| ♥️ Contributing | ♥️ Contributing | ||||||
| ============== | =============== | ||||||
|  |  | ||||||
| .. include:: ../../CONTRIBUTING.rst | .. include:: ../../CONTRIBUTING.rst | ||||||
|   | |||||||
							
								
								
									
										40
									
								
								docs/sanic/releases/21.9.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								docs/sanic/releases/21.9.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | ## Version 21.9 | ||||||
|  |  | ||||||
|  | ### Features | ||||||
|  | - [#2158](https://github.com/sanic-org/sanic/pull/2158), [#2248](https://github.com/sanic-org/sanic/pull/2248) Complete overhaul of I/O to websockets | ||||||
|  | - [#2160](https://github.com/sanic-org/sanic/pull/2160) Add new 17 signals into server and request lifecycles | ||||||
|  | - [#2162](https://github.com/sanic-org/sanic/pull/2162) Smarter `auto` fallback formatting upon exception | ||||||
|  | - [#2184](https://github.com/sanic-org/sanic/pull/2184) Introduce implementation for copying a Blueprint | ||||||
|  | - [#2200](https://github.com/sanic-org/sanic/pull/2200) Accept header parsing | ||||||
|  | - [#2207](https://github.com/sanic-org/sanic/pull/2207) Log remote address if available | ||||||
|  | - [#2209](https://github.com/sanic-org/sanic/pull/2209) Add convenience methods to BP groups | ||||||
|  | - [#2216](https://github.com/sanic-org/sanic/pull/2216) Add default messages to SanicExceptions | ||||||
|  | - [#2225](https://github.com/sanic-org/sanic/pull/2225) Type annotation convenience for annotated handlers with path parameters | ||||||
|  | - [#2236](https://github.com/sanic-org/sanic/pull/2236) Allow Falsey (but not-None) responses from route handlers | ||||||
|  | - [#2238](https://github.com/sanic-org/sanic/pull/2238) Add `exception` decorator to Blueprint Groups | ||||||
|  | - [#2244](https://github.com/sanic-org/sanic/pull/2244) Explicit static directive for serving file or dir (ex: `static(..., resource_type="file")`) | ||||||
|  | - [#2245](https://github.com/sanic-org/sanic/pull/2245) Close HTTP loop when connection task cancelled | ||||||
|  |  | ||||||
|  | ### Bugfixes | ||||||
|  | - [#2188](https://github.com/sanic-org/sanic/pull/2188) Fix the handling of the end of a chunked request | ||||||
|  | - [#2195](https://github.com/sanic-org/sanic/pull/2195) Resolve unexpected error handling on static requests | ||||||
|  | - [#2208](https://github.com/sanic-org/sanic/pull/2208) Make blueprint-based exceptions attach and trigger in a more intuitive manner | ||||||
|  | - [#2211](https://github.com/sanic-org/sanic/pull/2211) Fixed for handling exceptions of asgi app call | ||||||
|  | - [#2213](https://github.com/sanic-org/sanic/pull/2213) Fix bug where ws exceptions not being logged | ||||||
|  | - [#2231](https://github.com/sanic-org/sanic/pull/2231) Cleaner closing of tasks by using `abort()` in strategic places to avoid dangling sockets | ||||||
|  | - [#2247](https://github.com/sanic-org/sanic/pull/2247) Fix logging of auto-reload status in debug mode | ||||||
|  | - [#2246](https://github.com/sanic-org/sanic/pull/2246) Account for BP with exception handler but no routes | ||||||
|  |  | ||||||
|  | ### Developer infrastructure   | ||||||
|  | - [#2194](https://github.com/sanic-org/sanic/pull/2194) HTTP unit tests with raw client | ||||||
|  | - [#2199](https://github.com/sanic-org/sanic/pull/2199) Switch to codeclimate | ||||||
|  | - [#2214](https://github.com/sanic-org/sanic/pull/2214) Try Reopening Windows Tests | ||||||
|  | - [#2229](https://github.com/sanic-org/sanic/pull/2229) Refactor `HttpProtocol` into a base class | ||||||
|  | - [#2230](https://github.com/sanic-org/sanic/pull/2230) Refactor `server.py` into multi-file module | ||||||
|  |  | ||||||
|  | ### Miscellaneous | ||||||
|  | - [#2173](https://github.com/sanic-org/sanic/pull/2173) Remove Duplicated Dependencies and PEP 517 Support  | ||||||
|  | - [#2193](https://github.com/sanic-org/sanic/pull/2193), [#2196](https://github.com/sanic-org/sanic/pull/2196), [#2217](https://github.com/sanic-org/sanic/pull/2217) Type annotation changes | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -1,29 +1,44 @@ | |||||||
| from sanic import Sanic |  | ||||||
| from sanic import response |  | ||||||
| from signal import signal, SIGINT |  | ||||||
| import asyncio | import asyncio | ||||||
|  |  | ||||||
|  | from signal import SIGINT, signal | ||||||
|  |  | ||||||
| import uvloop | import uvloop | ||||||
|  |  | ||||||
|  | from sanic import Sanic, response | ||||||
|  | from sanic.server import AsyncioServer | ||||||
|  |  | ||||||
|  |  | ||||||
| app = Sanic(__name__) | app = Sanic(__name__) | ||||||
|  |  | ||||||
| @app.listener('after_server_start') |  | ||||||
|  | @app.listener("after_server_start") | ||||||
| async def after_start_test(app, loop): | async def after_start_test(app, loop): | ||||||
|     print("Async Server Started!") |     print("Async Server Started!") | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.route("/") | @app.route("/") | ||||||
| async def test(request): | async def test(request): | ||||||
|     return response.json({"answer": "42"}) |     return response.json({"answer": "42"}) | ||||||
|  |  | ||||||
|  |  | ||||||
| asyncio.set_event_loop(uvloop.new_event_loop()) | asyncio.set_event_loop(uvloop.new_event_loop()) | ||||||
| serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True) | serv_coro = app.create_server( | ||||||
|  |     host="0.0.0.0", port=8000, return_asyncio_server=True | ||||||
|  | ) | ||||||
| loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||||
| serv_task = asyncio.ensure_future(serv_coro, loop=loop) | serv_task = asyncio.ensure_future(serv_coro, loop=loop) | ||||||
| signal(SIGINT, lambda s, f: loop.stop()) | signal(SIGINT, lambda s, f: loop.stop()) | ||||||
| server = loop.run_until_complete(serv_task) | server: AsyncioServer = loop.run_until_complete(serv_task)  # type: ignore | ||||||
|  | server.startup() | ||||||
|  |  | ||||||
|  | # When using app.run(), this actually triggers before the serv_coro. | ||||||
|  | # But, in this example, we are using the convenience method, even if it is | ||||||
|  | # out of order. | ||||||
|  | server.before_start() | ||||||
| server.after_start() | server.after_start() | ||||||
| try: | try: | ||||||
|     loop.run_forever() |     loop.run_forever() | ||||||
| except KeyboardInterrupt as e: | except KeyboardInterrupt: | ||||||
|     loop.stop() |     loop.stop() | ||||||
| finally: | finally: | ||||||
|     server.before_stop() |     server.before_stop() | ||||||
|   | |||||||
| @@ -1,13 +1,14 @@ | |||||||
| from sanic import Sanic | from sanic import Sanic | ||||||
| from sanic.response import file | from sanic.response import redirect | ||||||
|  |  | ||||||
| app = Sanic(__name__) | app = Sanic(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.route('/') | app.static('index.html', "websocket.html") | ||||||
| async def index(request): |  | ||||||
|     return await file('websocket.html') |  | ||||||
|  |  | ||||||
|  | @app.route('/') | ||||||
|  | def index(request): | ||||||
|  |     return redirect("index.html") | ||||||
|  |  | ||||||
| @app.websocket('/feed') | @app.websocket('/feed') | ||||||
| async def feed(request, ws): | async def feed(request, ws): | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | [build-system] | ||||||
|  | requires = ["setuptools", "wheel"] | ||||||
|  | build-backend = "setuptools.build_meta" | ||||||
| @@ -1 +1 @@ | |||||||
| __version__ = "21.6.0" | __version__ = "21.9.1" | ||||||
|   | |||||||
							
								
								
									
										314
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										314
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -1,9 +1,12 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
| import logging | import logging | ||||||
| import logging.config | import logging.config | ||||||
| import os | import os | ||||||
| import re | import re | ||||||
|  |  | ||||||
| from asyncio import ( | from asyncio import ( | ||||||
|  |     AbstractEventLoop, | ||||||
|     CancelledError, |     CancelledError, | ||||||
|     Protocol, |     Protocol, | ||||||
|     ensure_future, |     ensure_future, | ||||||
| @@ -21,6 +24,7 @@ from traceback import format_exc | |||||||
| from types import SimpleNamespace | from types import SimpleNamespace | ||||||
| from typing import ( | from typing import ( | ||||||
|     Any, |     Any, | ||||||
|  |     AnyStr, | ||||||
|     Awaitable, |     Awaitable, | ||||||
|     Callable, |     Callable, | ||||||
|     Coroutine, |     Coroutine, | ||||||
| @@ -30,6 +34,7 @@ from typing import ( | |||||||
|     List, |     List, | ||||||
|     Optional, |     Optional, | ||||||
|     Set, |     Set, | ||||||
|  |     Tuple, | ||||||
|     Type, |     Type, | ||||||
|     Union, |     Union, | ||||||
| ) | ) | ||||||
| @@ -69,20 +74,29 @@ from sanic.router import Router | |||||||
| from sanic.server import AsyncioServer, HttpProtocol | from sanic.server import AsyncioServer, HttpProtocol | ||||||
| from sanic.server import Signal as ServerSignal | from sanic.server import Signal as ServerSignal | ||||||
| from sanic.server import serve, serve_multiple, serve_single | from sanic.server import serve, serve_multiple, serve_single | ||||||
|  | from sanic.server.protocols.websocket_protocol import WebSocketProtocol | ||||||
|  | from sanic.server.websockets.impl import ConnectionClosed | ||||||
| from sanic.signals import Signal, SignalRouter | from sanic.signals import Signal, SignalRouter | ||||||
| from sanic.websocket import ConnectionClosed, WebSocketProtocol | from sanic.touchup import TouchUp, TouchUpMeta | ||||||
|  |  | ||||||
|  |  | ||||||
| class Sanic(BaseSanic): | class Sanic(BaseSanic, metaclass=TouchUpMeta): | ||||||
|     """ |     """ | ||||||
|     The main application instance |     The main application instance | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     __touchup__ = ( | ||||||
|  |         "handle_request", | ||||||
|  |         "handle_exception", | ||||||
|  |         "_run_response_middleware", | ||||||
|  |         "_run_request_middleware", | ||||||
|  |     ) | ||||||
|     __fake_slots__ = ( |     __fake_slots__ = ( | ||||||
|         "_asgi_app", |         "_asgi_app", | ||||||
|         "_app_registry", |         "_app_registry", | ||||||
|         "_asgi_client", |         "_asgi_client", | ||||||
|         "_blueprint_order", |         "_blueprint_order", | ||||||
|  |         "_delayed_tasks", | ||||||
|         "_future_routes", |         "_future_routes", | ||||||
|         "_future_statics", |         "_future_statics", | ||||||
|         "_future_middleware", |         "_future_middleware", | ||||||
| @@ -137,7 +151,7 @@ class Sanic(BaseSanic): | |||||||
|         log_config: Optional[Dict[str, Any]] = None, |         log_config: Optional[Dict[str, Any]] = None, | ||||||
|         configure_logging: bool = True, |         configure_logging: bool = True, | ||||||
|         register: Optional[bool] = None, |         register: Optional[bool] = None, | ||||||
|         dumps: Optional[Callable[..., str]] = None, |         dumps: Optional[Callable[..., AnyStr]] = None, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         super().__init__(name=name) |         super().__init__(name=name) | ||||||
|  |  | ||||||
| @@ -153,6 +167,7 @@ class Sanic(BaseSanic): | |||||||
|  |  | ||||||
|         self._asgi_client = None |         self._asgi_client = None | ||||||
|         self._blueprint_order: List[Blueprint] = [] |         self._blueprint_order: List[Blueprint] = [] | ||||||
|  |         self._delayed_tasks: List[str] = [] | ||||||
|         self._test_client = None |         self._test_client = None | ||||||
|         self._test_manager = None |         self._test_manager = None | ||||||
|         self.asgi = False |         self.asgi = False | ||||||
| @@ -164,7 +179,9 @@ class Sanic(BaseSanic): | |||||||
|         self.configure_logging = configure_logging |         self.configure_logging = configure_logging | ||||||
|         self.ctx = ctx or SimpleNamespace() |         self.ctx = ctx or SimpleNamespace() | ||||||
|         self.debug = None |         self.debug = None | ||||||
|         self.error_handler = error_handler or ErrorHandler() |         self.error_handler = error_handler or ErrorHandler( | ||||||
|  |             fallback=self.config.FALLBACK_ERROR_FORMAT, | ||||||
|  |         ) | ||||||
|         self.is_running = False |         self.is_running = False | ||||||
|         self.is_stopping = False |         self.is_stopping = False | ||||||
|         self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) |         self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) | ||||||
| @@ -190,9 +207,10 @@ class Sanic(BaseSanic): | |||||||
|             self.__class__.register_app(self) |             self.__class__.register_app(self) | ||||||
|  |  | ||||||
|         self.router.ctx.app = self |         self.router.ctx.app = self | ||||||
|  |         self.signal_router.ctx.app = self | ||||||
|  |  | ||||||
|         if dumps: |         if dumps: | ||||||
|             BaseHTTPResponse._dumps = dumps |             BaseHTTPResponse._dumps = dumps  # type: ignore | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def loop(self): |     def loop(self): | ||||||
| @@ -230,9 +248,12 @@ class Sanic(BaseSanic): | |||||||
|             loop = self.loop  # Will raise SanicError if loop is not started |             loop = self.loop  # Will raise SanicError if loop is not started | ||||||
|             self._loop_add_task(task, self, loop) |             self._loop_add_task(task, self, loop) | ||||||
|         except SanicException: |         except SanicException: | ||||||
|             self.listener("before_server_start")( |             task_name = f"sanic.delayed_task.{hash(task)}" | ||||||
|                 partial(self._loop_add_task, task) |             if not self._delayed_tasks: | ||||||
|             ) |                 self.after_server_start(partial(self.dispatch_delayed_tasks)) | ||||||
|  |  | ||||||
|  |             self.signal(task_name)(partial(self.run_delayed_task, task=task)) | ||||||
|  |             self._delayed_tasks.append(task_name) | ||||||
|  |  | ||||||
|     def register_listener(self, listener: Callable, event: str) -> Any: |     def register_listener(self, listener: Callable, event: str) -> Any: | ||||||
|         """ |         """ | ||||||
| @@ -244,12 +265,20 @@ class Sanic(BaseSanic): | |||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             _event = ListenerEvent(event) |             _event = ListenerEvent[event.upper()] | ||||||
|         except ValueError: |         except (ValueError, AttributeError): | ||||||
|             valid = ", ".join(ListenerEvent.__members__.values()) |             valid = ", ".join( | ||||||
|  |                 map(lambda x: x.lower(), ListenerEvent.__members__.keys()) | ||||||
|  |             ) | ||||||
|             raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") |             raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") | ||||||
|  |  | ||||||
|         self.listeners[_event].append(listener) |         if "." in _event: | ||||||
|  |             self.signal(_event.value)( | ||||||
|  |                 partial(self._listener, listener=listener) | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             self.listeners[_event.value].append(listener) | ||||||
|  |  | ||||||
|         return listener |         return listener | ||||||
|  |  | ||||||
|     def register_middleware(self, middleware, attach_to: str = "request"): |     def register_middleware(self, middleware, attach_to: str = "request"): | ||||||
| @@ -308,7 +337,11 @@ class Sanic(BaseSanic): | |||||||
|                     self.named_response_middleware[_rn].appendleft(middleware) |                     self.named_response_middleware[_rn].appendleft(middleware) | ||||||
|         return middleware |         return middleware | ||||||
|  |  | ||||||
|     def _apply_exception_handler(self, handler: FutureException): |     def _apply_exception_handler( | ||||||
|  |         self, | ||||||
|  |         handler: FutureException, | ||||||
|  |         route_names: Optional[List[str]] = None, | ||||||
|  |     ): | ||||||
|         """Decorate a function to be registered as a handler for exceptions |         """Decorate a function to be registered as a handler for exceptions | ||||||
|  |  | ||||||
|         :param exceptions: exceptions |         :param exceptions: exceptions | ||||||
| @@ -318,9 +351,9 @@ class Sanic(BaseSanic): | |||||||
|         for exception in handler.exceptions: |         for exception in handler.exceptions: | ||||||
|             if isinstance(exception, (tuple, list)): |             if isinstance(exception, (tuple, list)): | ||||||
|                 for e in exception: |                 for e in exception: | ||||||
|                     self.error_handler.add(e, handler.handler) |                     self.error_handler.add(e, handler.handler, route_names) | ||||||
|             else: |             else: | ||||||
|                 self.error_handler.add(exception, handler.handler) |                 self.error_handler.add(exception, handler.handler, route_names) | ||||||
|         return handler.handler |         return handler.handler | ||||||
|  |  | ||||||
|     def _apply_listener(self, listener: FutureListener): |     def _apply_listener(self, listener: FutureListener): | ||||||
| @@ -377,11 +410,17 @@ class Sanic(BaseSanic): | |||||||
|         *, |         *, | ||||||
|         condition: Optional[Dict[str, str]] = None, |         condition: Optional[Dict[str, str]] = None, | ||||||
|         context: Optional[Dict[str, Any]] = None, |         context: Optional[Dict[str, Any]] = None, | ||||||
|  |         fail_not_found: bool = True, | ||||||
|  |         inline: bool = False, | ||||||
|  |         reverse: bool = False, | ||||||
|     ) -> Coroutine[Any, Any, Awaitable[Any]]: |     ) -> Coroutine[Any, Any, Awaitable[Any]]: | ||||||
|         return self.signal_router.dispatch( |         return self.signal_router.dispatch( | ||||||
|             event, |             event, | ||||||
|             context=context, |             context=context, | ||||||
|             condition=condition, |             condition=condition, | ||||||
|  |             inline=inline, | ||||||
|  |             reverse=reverse, | ||||||
|  |             fail_not_found=fail_not_found, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     async def event( |     async def event( | ||||||
| @@ -411,7 +450,13 @@ class Sanic(BaseSanic): | |||||||
|  |  | ||||||
|         self.websocket_enabled = enable |         self.websocket_enabled = enable | ||||||
|  |  | ||||||
|     def blueprint(self, blueprint, **options): |     def blueprint( | ||||||
|  |         self, | ||||||
|  |         blueprint: Union[ | ||||||
|  |             Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup | ||||||
|  |         ], | ||||||
|  |         **options: Any, | ||||||
|  |     ): | ||||||
|         """Register a blueprint on the application. |         """Register a blueprint on the application. | ||||||
|  |  | ||||||
|         :param blueprint: Blueprint object or (list, tuple) thereof |         :param blueprint: Blueprint object or (list, tuple) thereof | ||||||
| @@ -651,7 +696,7 @@ class Sanic(BaseSanic): | |||||||
|  |  | ||||||
|     async def handle_exception( |     async def handle_exception( | ||||||
|         self, request: Request, exception: BaseException |         self, request: Request, exception: BaseException | ||||||
|     ): |     ):  # no cov | ||||||
|         """ |         """ | ||||||
|         A handler that catches specific exceptions and outputs a response. |         A handler that catches specific exceptions and outputs a response. | ||||||
|  |  | ||||||
| @@ -661,6 +706,12 @@ class Sanic(BaseSanic): | |||||||
|         :type exception: BaseException |         :type exception: BaseException | ||||||
|         :raises ServerError: response 500 |         :raises ServerError: response 500 | ||||||
|         """ |         """ | ||||||
|  |         await self.dispatch( | ||||||
|  |             "http.lifecycle.exception", | ||||||
|  |             inline=True, | ||||||
|  |             context={"request": request, "exception": exception}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # -------------------------------------------- # |         # -------------------------------------------- # | ||||||
|         # Request Middleware |         # Request Middleware | ||||||
|         # -------------------------------------------- # |         # -------------------------------------------- # | ||||||
| @@ -707,7 +758,7 @@ class Sanic(BaseSanic): | |||||||
|                 f"Invalid response type {response!r} (need HTTPResponse)" |                 f"Invalid response type {response!r} (need HTTPResponse)" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     async def handle_request(self, request: Request): |     async def handle_request(self, request: Request):  # no cov | ||||||
|         """Take a request from the HTTP Server and return a response object |         """Take a request from the HTTP Server and return a response object | ||||||
|         to be sent back The HTTP Server only expects a response object, so |         to be sent back The HTTP Server only expects a response object, so | ||||||
|         exception handling must be done here |         exception handling must be done here | ||||||
| @@ -715,10 +766,22 @@ class Sanic(BaseSanic): | |||||||
|         :param request: HTTP Request object |         :param request: HTTP Request object | ||||||
|         :return: Nothing |         :return: Nothing | ||||||
|         """ |         """ | ||||||
|  |         await self.dispatch( | ||||||
|  |             "http.lifecycle.handle", | ||||||
|  |             inline=True, | ||||||
|  |             context={"request": request}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # Define `response` var here to remove warnings about |         # Define `response` var here to remove warnings about | ||||||
|         # allocation before assignment below. |         # allocation before assignment below. | ||||||
|         response = None |         response = None | ||||||
|         try: |         try: | ||||||
|  |  | ||||||
|  |             await self.dispatch( | ||||||
|  |                 "http.routing.before", | ||||||
|  |                 inline=True, | ||||||
|  |                 context={"request": request}, | ||||||
|  |             ) | ||||||
|             # Fetch handler from router |             # Fetch handler from router | ||||||
|             route, handler, kwargs = self.router.get( |             route, handler, kwargs = self.router.get( | ||||||
|                 request.path, |                 request.path, | ||||||
| @@ -726,19 +789,29 @@ class Sanic(BaseSanic): | |||||||
|                 request.headers.getone("host", None), |                 request.headers.getone("host", None), | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             request._match_info = kwargs |             request._match_info = {**kwargs} | ||||||
|             request.route = route |             request.route = route | ||||||
|  |  | ||||||
|  |             await self.dispatch( | ||||||
|  |                 "http.routing.after", | ||||||
|  |                 inline=True, | ||||||
|  |                 context={ | ||||||
|  |                     "request": request, | ||||||
|  |                     "route": route, | ||||||
|  |                     "kwargs": kwargs, | ||||||
|  |                     "handler": handler, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|             if ( |             if ( | ||||||
|                 request.stream.request_body  # type: ignore |                 request.stream | ||||||
|  |                 and request.stream.request_body | ||||||
|                 and not route.ctx.ignore_body |                 and not route.ctx.ignore_body | ||||||
|             ): |             ): | ||||||
|  |  | ||||||
|                 if hasattr(handler, "is_stream"): |                 if hasattr(handler, "is_stream"): | ||||||
|                     # Streaming handler: lift the size limit |                     # Streaming handler: lift the size limit | ||||||
|                     request.stream.request_max_size = float(  # type: ignore |                     request.stream.request_max_size = float("inf") | ||||||
|                         "inf" |  | ||||||
|                     ) |  | ||||||
|                 else: |                 else: | ||||||
|                     # Non-streaming handler: preload body |                     # Non-streaming handler: preload body | ||||||
|                     await request.receive_body() |                     await request.receive_body() | ||||||
| @@ -765,17 +838,25 @@ class Sanic(BaseSanic): | |||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|                 # Run response handler |                 # Run response handler | ||||||
|                 response = handler(request, **kwargs) |                 response = handler(request, **request.match_info) | ||||||
|                 if isawaitable(response): |                 if isawaitable(response): | ||||||
|                     response = await response |                     response = await response | ||||||
|  |  | ||||||
|             if response: |             if response is not None: | ||||||
|                 response = await request.respond(response) |                 response = await request.respond(response) | ||||||
|             elif not hasattr(handler, "is_websocket"): |             elif not hasattr(handler, "is_websocket"): | ||||||
|                 response = request.stream.response  # type: ignore |                 response = request.stream.response  # type: ignore | ||||||
|  |  | ||||||
|             # Make sure that response is finished / run StreamingHTTP callback |             # Make sure that response is finished / run StreamingHTTP callback | ||||||
|             if isinstance(response, BaseHTTPResponse): |             if isinstance(response, BaseHTTPResponse): | ||||||
|  |                 await self.dispatch( | ||||||
|  |                     "http.lifecycle.response", | ||||||
|  |                     inline=True, | ||||||
|  |                     context={ | ||||||
|  |                         "request": request, | ||||||
|  |                         "response": response, | ||||||
|  |                     }, | ||||||
|  |                 ) | ||||||
|                 await response.send(end_stream=True) |                 await response.send(end_stream=True) | ||||||
|             else: |             else: | ||||||
|                 if not hasattr(handler, "is_websocket"): |                 if not hasattr(handler, "is_websocket"): | ||||||
| @@ -793,23 +874,11 @@ class Sanic(BaseSanic): | |||||||
|     async def _websocket_handler( |     async def _websocket_handler( | ||||||
|         self, handler, request, *args, subprotocols=None, **kwargs |         self, handler, request, *args, subprotocols=None, **kwargs | ||||||
|     ): |     ): | ||||||
|         request.app = self |  | ||||||
|         if not getattr(handler, "__blueprintname__", False): |  | ||||||
|             request._name = handler.__name__ |  | ||||||
|         else: |  | ||||||
|             request._name = ( |  | ||||||
|                 getattr(handler, "__blueprintname__", "") + handler.__name__ |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         if self.asgi: |         if self.asgi: | ||||||
|             ws = request.transport.get_websocket_connection() |             ws = request.transport.get_websocket_connection() | ||||||
|             await ws.accept(subprotocols) |             await ws.accept(subprotocols) | ||||||
|         else: |         else: | ||||||
|             protocol = request.transport.get_protocol() |             protocol = request.transport.get_protocol() | ||||||
|             protocol.app = self |  | ||||||
|  |  | ||||||
|             ws = await protocol.websocket_handshake(request, subprotocols) |             ws = await protocol.websocket_handshake(request, subprotocols) | ||||||
|  |  | ||||||
|         # schedule the application handler |         # schedule the application handler | ||||||
| @@ -817,12 +886,18 @@ class Sanic(BaseSanic): | |||||||
|         # needs to be cancelled due to the server being stopped |         # needs to be cancelled due to the server being stopped | ||||||
|         fut = ensure_future(handler(request, ws, *args, **kwargs)) |         fut = ensure_future(handler(request, ws, *args, **kwargs)) | ||||||
|         self.websocket_tasks.add(fut) |         self.websocket_tasks.add(fut) | ||||||
|  |         cancelled = False | ||||||
|         try: |         try: | ||||||
|             await fut |             await fut | ||||||
|  |         except Exception as e: | ||||||
|  |             self.error_handler.log(request, e) | ||||||
|         except (CancelledError, ConnectionClosed): |         except (CancelledError, ConnectionClosed): | ||||||
|             pass |             cancelled = True | ||||||
|         finally: |         finally: | ||||||
|             self.websocket_tasks.remove(fut) |             self.websocket_tasks.remove(fut) | ||||||
|  |             if cancelled: | ||||||
|  |                 ws.end_connection(1000) | ||||||
|  |             else: | ||||||
|                 await ws.close() |                 await ws.close() | ||||||
|  |  | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
| @@ -869,7 +944,7 @@ class Sanic(BaseSanic): | |||||||
|         *, |         *, | ||||||
|         debug: bool = False, |         debug: bool = False, | ||||||
|         auto_reload: Optional[bool] = None, |         auto_reload: Optional[bool] = None, | ||||||
|         ssl: Union[dict, SSLContext, None] = None, |         ssl: Union[Dict[str, str], SSLContext, None] = None, | ||||||
|         sock: Optional[socket] = None, |         sock: Optional[socket] = None, | ||||||
|         workers: int = 1, |         workers: int = 1, | ||||||
|         protocol: Optional[Type[Protocol]] = None, |         protocol: Optional[Type[Protocol]] = None, | ||||||
| @@ -999,7 +1074,7 @@ class Sanic(BaseSanic): | |||||||
|         port: Optional[int] = None, |         port: Optional[int] = None, | ||||||
|         *, |         *, | ||||||
|         debug: bool = False, |         debug: bool = False, | ||||||
|         ssl: Union[dict, SSLContext, None] = None, |         ssl: Union[Dict[str, str], SSLContext, None] = None, | ||||||
|         sock: Optional[socket] = None, |         sock: Optional[socket] = None, | ||||||
|         protocol: Type[Protocol] = None, |         protocol: Type[Protocol] = None, | ||||||
|         backlog: int = 100, |         backlog: int = 100, | ||||||
| @@ -1071,11 +1146,6 @@ class Sanic(BaseSanic): | |||||||
|             run_async=return_asyncio_server, |             run_async=return_asyncio_server, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # Trigger before_start events |  | ||||||
|         await self.trigger_events( |  | ||||||
|             server_settings.get("before_start", []), |  | ||||||
|             server_settings.get("loop"), |  | ||||||
|         ) |  | ||||||
|         main_start = server_settings.pop("main_start", None) |         main_start = server_settings.pop("main_start", None) | ||||||
|         main_stop = server_settings.pop("main_stop", None) |         main_stop = server_settings.pop("main_stop", None) | ||||||
|         if main_start or main_stop: |         if main_start or main_stop: | ||||||
| @@ -1088,17 +1158,9 @@ class Sanic(BaseSanic): | |||||||
|             asyncio_server_kwargs=asyncio_server_kwargs, **server_settings |             asyncio_server_kwargs=asyncio_server_kwargs, **server_settings | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     async def trigger_events(self, events, loop): |     async def _run_request_middleware( | ||||||
|         """Trigger events (functions or async) |         self, request, request_name=None | ||||||
|         :param events: one or more sync or async functions to execute |     ):  # no cov | ||||||
|         :param loop: event loop |  | ||||||
|         """ |  | ||||||
|         for event in events: |  | ||||||
|             result = event(loop) |  | ||||||
|             if isawaitable(result): |  | ||||||
|                 await result |  | ||||||
|  |  | ||||||
|     async def _run_request_middleware(self, request, request_name=None): |  | ||||||
|         # The if improves speed.  I don't know why |         # The if improves speed.  I don't know why | ||||||
|         named_middleware = self.named_request_middleware.get( |         named_middleware = self.named_request_middleware.get( | ||||||
|             request_name, deque() |             request_name, deque() | ||||||
| @@ -1111,25 +1173,67 @@ class Sanic(BaseSanic): | |||||||
|             request.request_middleware_started = True |             request.request_middleware_started = True | ||||||
|  |  | ||||||
|             for middleware in applicable_middleware: |             for middleware in applicable_middleware: | ||||||
|  |                 await self.dispatch( | ||||||
|  |                     "http.middleware.before", | ||||||
|  |                     inline=True, | ||||||
|  |                     context={ | ||||||
|  |                         "request": request, | ||||||
|  |                         "response": None, | ||||||
|  |                     }, | ||||||
|  |                     condition={"attach_to": "request"}, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|                 response = middleware(request) |                 response = middleware(request) | ||||||
|                 if isawaitable(response): |                 if isawaitable(response): | ||||||
|                     response = await response |                     response = await response | ||||||
|  |  | ||||||
|  |                 await self.dispatch( | ||||||
|  |                     "http.middleware.after", | ||||||
|  |                     inline=True, | ||||||
|  |                     context={ | ||||||
|  |                         "request": request, | ||||||
|  |                         "response": None, | ||||||
|  |                     }, | ||||||
|  |                     condition={"attach_to": "request"}, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|                 if response: |                 if response: | ||||||
|                     return response |                     return response | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     async def _run_response_middleware( |     async def _run_response_middleware( | ||||||
|         self, request, response, request_name=None |         self, request, response, request_name=None | ||||||
|     ): |     ):  # no cov | ||||||
|         named_middleware = self.named_response_middleware.get( |         named_middleware = self.named_response_middleware.get( | ||||||
|             request_name, deque() |             request_name, deque() | ||||||
|         ) |         ) | ||||||
|         applicable_middleware = self.response_middleware + named_middleware |         applicable_middleware = self.response_middleware + named_middleware | ||||||
|         if applicable_middleware: |         if applicable_middleware: | ||||||
|             for middleware in applicable_middleware: |             for middleware in applicable_middleware: | ||||||
|  |                 await self.dispatch( | ||||||
|  |                     "http.middleware.before", | ||||||
|  |                     inline=True, | ||||||
|  |                     context={ | ||||||
|  |                         "request": request, | ||||||
|  |                         "response": response, | ||||||
|  |                     }, | ||||||
|  |                     condition={"attach_to": "response"}, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|                 _response = middleware(request, response) |                 _response = middleware(request, response) | ||||||
|                 if isawaitable(_response): |                 if isawaitable(_response): | ||||||
|                     _response = await _response |                     _response = await _response | ||||||
|  |  | ||||||
|  |                 await self.dispatch( | ||||||
|  |                     "http.middleware.after", | ||||||
|  |                     inline=True, | ||||||
|  |                     context={ | ||||||
|  |                         "request": request, | ||||||
|  |                         "response": _response if _response else response, | ||||||
|  |                     }, | ||||||
|  |                     condition={"attach_to": "response"}, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|                 if _response: |                 if _response: | ||||||
|                     response = _response |                     response = _response | ||||||
|                     if isinstance(response, BaseHTTPResponse): |                     if isinstance(response, BaseHTTPResponse): | ||||||
| @@ -1155,10 +1259,6 @@ class Sanic(BaseSanic): | |||||||
|     ): |     ): | ||||||
|         """Helper function used by `run` and `create_server`.""" |         """Helper function used by `run` and `create_server`.""" | ||||||
|  |  | ||||||
|         self.listeners["before_server_start"] = [ |  | ||||||
|             self.finalize |  | ||||||
|         ] + self.listeners["before_server_start"] |  | ||||||
|  |  | ||||||
|         if isinstance(ssl, dict): |         if isinstance(ssl, dict): | ||||||
|             # try common aliaseses |             # try common aliaseses | ||||||
|             cert = ssl.get("cert") or ssl.get("certificate") |             cert = ssl.get("cert") or ssl.get("certificate") | ||||||
| @@ -1195,10 +1295,6 @@ class Sanic(BaseSanic): | |||||||
|         # Register start/stop events |         # Register start/stop events | ||||||
|  |  | ||||||
|         for event_name, settings_name, reverse in ( |         for event_name, settings_name, reverse in ( | ||||||
|             ("before_server_start", "before_start", False), |  | ||||||
|             ("after_server_start", "after_start", False), |  | ||||||
|             ("before_server_stop", "before_stop", True), |  | ||||||
|             ("after_server_stop", "after_stop", True), |  | ||||||
|             ("main_process_start", "main_start", False), |             ("main_process_start", "main_start", False), | ||||||
|             ("main_process_stop", "main_stop", True), |             ("main_process_stop", "main_stop", True), | ||||||
|         ): |         ): | ||||||
| @@ -1236,7 +1332,8 @@ class Sanic(BaseSanic): | |||||||
|                 logger.info(f"Goin' Fast @ {proto}://{host}:{port}") |                 logger.info(f"Goin' Fast @ {proto}://{host}:{port}") | ||||||
|  |  | ||||||
|         debug_mode = "enabled" if self.debug else "disabled" |         debug_mode = "enabled" if self.debug else "disabled" | ||||||
|         logger.debug("Sanic auto-reload: enabled") |         reload_mode = "enabled" if auto_reload else "disabled" | ||||||
|  |         logger.debug(f"Sanic auto-reload: {reload_mode}") | ||||||
|         logger.debug(f"Sanic debug mode: {debug_mode}") |         logger.debug(f"Sanic debug mode: {debug_mode}") | ||||||
|  |  | ||||||
|         return server_settings |         return server_settings | ||||||
| @@ -1246,20 +1343,44 @@ class Sanic(BaseSanic): | |||||||
|         return ".".join(parts) |         return ".".join(parts) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _loop_add_task(cls, task, app, loop): |     def _prep_task(cls, task, app, loop): | ||||||
|         if callable(task): |         if callable(task): | ||||||
|             try: |             try: | ||||||
|                 loop.create_task(task(app)) |                 task = task(app) | ||||||
|             except TypeError: |             except TypeError: | ||||||
|                 loop.create_task(task()) |                 task = task() | ||||||
|         else: |  | ||||||
|             loop.create_task(task) |         return task | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def _loop_add_task(cls, task, app, loop): | ||||||
|  |         prepped = cls._prep_task(task, app, loop) | ||||||
|  |         loop.create_task(prepped) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _cancel_websocket_tasks(cls, app, loop): |     def _cancel_websocket_tasks(cls, app, loop): | ||||||
|         for task in app.websocket_tasks: |         for task in app.websocket_tasks: | ||||||
|             task.cancel() |             task.cancel() | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     async def dispatch_delayed_tasks(app, loop): | ||||||
|  |         for name in app._delayed_tasks: | ||||||
|  |             await app.dispatch(name, context={"app": app, "loop": loop}) | ||||||
|  |         app._delayed_tasks.clear() | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     async def run_delayed_task(app, loop, task): | ||||||
|  |         prepped = app._prep_task(task, app, loop) | ||||||
|  |         await prepped | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     async def _listener( | ||||||
|  |         app: Sanic, loop: AbstractEventLoop, listener: ListenerType | ||||||
|  |     ): | ||||||
|  |         maybe_coro = listener(app, loop) | ||||||
|  |         if maybe_coro and isawaitable(maybe_coro): | ||||||
|  |             await maybe_coro | ||||||
|  |  | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
|     # ASGI |     # ASGI | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
| @@ -1333,15 +1454,52 @@ class Sanic(BaseSanic): | |||||||
|             raise SanicException(f'Sanic app name "{name}" not found.') |             raise SanicException(f'Sanic app name "{name}" not found.') | ||||||
|  |  | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
|     # Static methods |     # Lifecycle | ||||||
|     # -------------------------------------------------------------------- # |     # -------------------------------------------------------------------- # | ||||||
|  |  | ||||||
|     @staticmethod |     def finalize(self): | ||||||
|     async def finalize(app, _): |  | ||||||
|         try: |         try: | ||||||
|             app.router.finalize() |             self.router.finalize() | ||||||
|             if app.signal_router.routes: |  | ||||||
|                 app.signal_router.finalize()  # noqa |  | ||||||
|         except FinalizationError as e: |         except FinalizationError as e: | ||||||
|             if not Sanic.test_mode: |             if not Sanic.test_mode: | ||||||
|                 raise e  # noqa |                 raise e | ||||||
|  |  | ||||||
|  |     def signalize(self): | ||||||
|  |         try: | ||||||
|  |             self.signal_router.finalize() | ||||||
|  |         except FinalizationError as e: | ||||||
|  |             if not Sanic.test_mode: | ||||||
|  |                 raise e | ||||||
|  |  | ||||||
|  |     async def _startup(self): | ||||||
|  |         self.signalize() | ||||||
|  |         self.finalize() | ||||||
|  |         ErrorHandler.finalize(self.error_handler) | ||||||
|  |         TouchUp.run(self) | ||||||
|  |  | ||||||
|  |     async def _server_event( | ||||||
|  |         self, | ||||||
|  |         concern: str, | ||||||
|  |         action: str, | ||||||
|  |         loop: Optional[AbstractEventLoop] = None, | ||||||
|  |     ) -> None: | ||||||
|  |         event = f"server.{concern}.{action}" | ||||||
|  |         if action not in ("before", "after") or concern not in ( | ||||||
|  |             "init", | ||||||
|  |             "shutdown", | ||||||
|  |         ): | ||||||
|  |             raise SanicException(f"Invalid server event: {event}") | ||||||
|  |         logger.debug(f"Triggering server events: {event}") | ||||||
|  |         reverse = concern == "shutdown" | ||||||
|  |         if loop is None: | ||||||
|  |             loop = self.loop | ||||||
|  |         await self.dispatch( | ||||||
|  |             event, | ||||||
|  |             fail_not_found=False, | ||||||
|  |             reverse=reverse, | ||||||
|  |             inline=True, | ||||||
|  |             context={ | ||||||
|  |                 "app": self, | ||||||
|  |                 "loop": loop, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|   | |||||||
| @@ -1,6 +1,5 @@ | |||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| from inspect import isawaitable |  | ||||||
| from typing import Optional | from typing import Optional | ||||||
| from urllib.parse import quote | from urllib.parse import quote | ||||||
|  |  | ||||||
| @@ -11,21 +10,27 @@ from sanic.exceptions import ServerError | |||||||
| from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport | from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport | ||||||
| from sanic.request import Request | from sanic.request import Request | ||||||
| from sanic.server import ConnInfo | from sanic.server import ConnInfo | ||||||
| from sanic.websocket import WebSocketConnection | from sanic.server.websockets.connection import WebSocketConnection | ||||||
|  |  | ||||||
|  |  | ||||||
| class Lifespan: | class Lifespan: | ||||||
|     def __init__(self, asgi_app: "ASGIApp") -> None: |     def __init__(self, asgi_app: "ASGIApp") -> None: | ||||||
|         self.asgi_app = asgi_app |         self.asgi_app = asgi_app | ||||||
|  |  | ||||||
|         if "before_server_start" in self.asgi_app.sanic_app.listeners: |         if ( | ||||||
|  |             "server.init.before" | ||||||
|  |             in self.asgi_app.sanic_app.signal_router.name_index | ||||||
|  |         ): | ||||||
|             warnings.warn( |             warnings.warn( | ||||||
|                 'You have set a listener for "before_server_start" ' |                 'You have set a listener for "before_server_start" ' | ||||||
|                 "in ASGI mode. " |                 "in ASGI mode. " | ||||||
|                 "It will be executed as early as possible, but not before " |                 "It will be executed as early as possible, but not before " | ||||||
|                 "the ASGI server is started." |                 "the ASGI server is started." | ||||||
|             ) |             ) | ||||||
|         if "after_server_stop" in self.asgi_app.sanic_app.listeners: |         if ( | ||||||
|  |             "server.shutdown.after" | ||||||
|  |             in self.asgi_app.sanic_app.signal_router.name_index | ||||||
|  |         ): | ||||||
|             warnings.warn( |             warnings.warn( | ||||||
|                 'You have set a listener for "after_server_stop" ' |                 'You have set a listener for "after_server_stop" ' | ||||||
|                 "in ASGI mode. " |                 "in ASGI mode. " | ||||||
| @@ -42,19 +47,9 @@ class Lifespan: | |||||||
|         in sequence since the ASGI lifespan protocol only supports a single |         in sequence since the ASGI lifespan protocol only supports a single | ||||||
|         startup event. |         startup event. | ||||||
|         """ |         """ | ||||||
|         self.asgi_app.sanic_app.router.finalize() |         await self.asgi_app.sanic_app._startup() | ||||||
|         if self.asgi_app.sanic_app.signal_router.routes: |         await self.asgi_app.sanic_app._server_event("init", "before") | ||||||
|             self.asgi_app.sanic_app.signal_router.finalize() |         await self.asgi_app.sanic_app._server_event("init", "after") | ||||||
|         listeners = self.asgi_app.sanic_app.listeners.get( |  | ||||||
|             "before_server_start", [] |  | ||||||
|         ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) |  | ||||||
|  |  | ||||||
|         for handler in listeners: |  | ||||||
|             response = handler( |  | ||||||
|                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop |  | ||||||
|             ) |  | ||||||
|             if response and isawaitable(response): |  | ||||||
|                 await response |  | ||||||
|  |  | ||||||
|     async def shutdown(self) -> None: |     async def shutdown(self) -> None: | ||||||
|         """ |         """ | ||||||
| @@ -65,16 +60,8 @@ class Lifespan: | |||||||
|         in sequence since the ASGI lifespan protocol only supports a single |         in sequence since the ASGI lifespan protocol only supports a single | ||||||
|         shutdown event. |         shutdown event. | ||||||
|         """ |         """ | ||||||
|         listeners = self.asgi_app.sanic_app.listeners.get( |         await self.asgi_app.sanic_app._server_event("shutdown", "before") | ||||||
|             "before_server_stop", [] |         await self.asgi_app.sanic_app._server_event("shutdown", "after") | ||||||
|         ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) |  | ||||||
|  |  | ||||||
|         for handler in listeners: |  | ||||||
|             response = handler( |  | ||||||
|                 self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop |  | ||||||
|             ) |  | ||||||
|             if response and isawaitable(response): |  | ||||||
|                 await response |  | ||||||
|  |  | ||||||
|     async def __call__( |     async def __call__( | ||||||
|         self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend |         self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||||
| @@ -207,4 +194,7 @@ class ASGIApp: | |||||||
|         """ |         """ | ||||||
|         Handle the incoming request. |         Handle the incoming request. | ||||||
|         """ |         """ | ||||||
|  |         try: | ||||||
|             await self.sanic_app.handle_request(self.request) |             await self.sanic_app.handle_request(self.request) | ||||||
|  |         except Exception as e: | ||||||
|  |             await self.sanic_app.handle_exception(self.request, e) | ||||||
|   | |||||||
| @@ -58,7 +58,7 @@ class BaseSanic( | |||||||
|         if name not in self.__fake_slots__: |         if name not in self.__fake_slots__: | ||||||
|             warn( |             warn( | ||||||
|                 f"Setting variables on {self.__class__.__name__} instances is " |                 f"Setting variables on {self.__class__.__name__} instances is " | ||||||
|                 "deprecated and will be removed in version 21.9. You should " |                 "deprecated and will be removed in version 21.12. You should " | ||||||
|                 f"change your {self.__class__.__name__} instance to use " |                 f"change your {self.__class__.__name__} instance to use " | ||||||
|                 f"instance.ctx.{name} instead.", |                 f"instance.ctx.{name} instead.", | ||||||
|                 DeprecationWarning, |                 DeprecationWarning, | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| from collections.abc import MutableSequence | from collections.abc import MutableSequence | ||||||
|  | from functools import partial | ||||||
| from typing import TYPE_CHECKING, List, Optional, Union | from typing import TYPE_CHECKING, List, Optional, Union | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -196,6 +197,27 @@ class BlueprintGroup(MutableSequence): | |||||||
|         """ |         """ | ||||||
|         self._blueprints.append(value) |         self._blueprints.append(value) | ||||||
|  |  | ||||||
|  |     def exception(self, *exceptions, **kwargs): | ||||||
|  |         """ | ||||||
|  |         A decorator that can be used to implement a global exception handler | ||||||
|  |         for all the Blueprints that belong to this Blueprint Group. | ||||||
|  |  | ||||||
|  |         In case of nested Blueprint Groups, the same handler is applied | ||||||
|  |         across each of the Blueprints recursively. | ||||||
|  |  | ||||||
|  |         :param args: List of Python exceptions to be caught by the handler | ||||||
|  |         :param kwargs: Additional optional arguments to be passed to the | ||||||
|  |             exception handler | ||||||
|  |         :return a decorated method to handle global exceptions for any | ||||||
|  |             blueprint registered under this group. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         def register_exception_handler_for_blueprints(fn): | ||||||
|  |             for blueprint in self.blueprints: | ||||||
|  |                 blueprint.exception(*exceptions, **kwargs)(fn) | ||||||
|  |  | ||||||
|  |         return register_exception_handler_for_blueprints | ||||||
|  |  | ||||||
|     def insert(self, index: int, item: Blueprint) -> None: |     def insert(self, index: int, item: Blueprint) -> None: | ||||||
|         """ |         """ | ||||||
|         The Abstract class `MutableSequence` leverages this insert method to |         The Abstract class `MutableSequence` leverages this insert method to | ||||||
| @@ -229,3 +251,15 @@ class BlueprintGroup(MutableSequence): | |||||||
|             args = list(args)[1:] |             args = list(args)[1:] | ||||||
|             return register_middleware_for_blueprints(fn) |             return register_middleware_for_blueprints(fn) | ||||||
|         return register_middleware_for_blueprints |         return register_middleware_for_blueprints | ||||||
|  |  | ||||||
|  |     def on_request(self, middleware=None): | ||||||
|  |         if callable(middleware): | ||||||
|  |             return self.middleware(middleware, "request") | ||||||
|  |         else: | ||||||
|  |             return partial(self.middleware, attach_to="request") | ||||||
|  |  | ||||||
|  |     def on_response(self, middleware=None): | ||||||
|  |         if callable(middleware): | ||||||
|  |             return self.middleware(middleware, "response") | ||||||
|  |         else: | ||||||
|  |             return partial(self.middleware, attach_to="response") | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ from __future__ import annotations | |||||||
| import asyncio | import asyncio | ||||||
|  |  | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  | from copy import deepcopy | ||||||
| from types import SimpleNamespace | from types import SimpleNamespace | ||||||
| from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union | from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union | ||||||
|  |  | ||||||
| @@ -12,6 +13,7 @@ from sanic_routing.route import Route  # type: ignore | |||||||
| from sanic.base import BaseSanic | from sanic.base import BaseSanic | ||||||
| from sanic.blueprint_group import BlueprintGroup | from sanic.blueprint_group import BlueprintGroup | ||||||
| from sanic.exceptions import SanicException | from sanic.exceptions import SanicException | ||||||
|  | from sanic.helpers import Default, _default | ||||||
| from sanic.models.futures import FutureRoute, FutureStatic | from sanic.models.futures import FutureRoute, FutureStatic | ||||||
| from sanic.models.handler_types import ( | from sanic.models.handler_types import ( | ||||||
|     ListenerType, |     ListenerType, | ||||||
| @@ -40,7 +42,7 @@ class Blueprint(BaseSanic): | |||||||
|     :param host: IP Address of FQDN for the sanic server to use. |     :param host: IP Address of FQDN for the sanic server to use. | ||||||
|     :param version: Blueprint Version |     :param version: Blueprint Version | ||||||
|     :param strict_slashes: Enforce the API urls are requested with a |     :param strict_slashes: Enforce the API urls are requested with a | ||||||
|         training */* |         trailing */* | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     __fake_slots__ = ( |     __fake_slots__ = ( | ||||||
| @@ -76,15 +78,9 @@ class Blueprint(BaseSanic): | |||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |     ): | ||||||
|         super().__init__(name=name) |         super().__init__(name=name) | ||||||
|  |         self.reset() | ||||||
|         self._apps: Set[Sanic] = set() |  | ||||||
|         self.ctx = SimpleNamespace() |         self.ctx = SimpleNamespace() | ||||||
|         self.exceptions: List[RouteHandler] = [] |  | ||||||
|         self.host = host |         self.host = host | ||||||
|         self.listeners: Dict[str, List[ListenerType]] = {} |  | ||||||
|         self.middlewares: List[MiddlewareType] = [] |  | ||||||
|         self.routes: List[Route] = [] |  | ||||||
|         self.statics: List[RouteHandler] = [] |  | ||||||
|         self.strict_slashes = strict_slashes |         self.strict_slashes = strict_slashes | ||||||
|         self.url_prefix = ( |         self.url_prefix = ( | ||||||
|             url_prefix[:-1] |             url_prefix[:-1] | ||||||
| @@ -93,7 +89,6 @@ class Blueprint(BaseSanic): | |||||||
|         ) |         ) | ||||||
|         self.version = version |         self.version = version | ||||||
|         self.version_prefix = version_prefix |         self.version_prefix = version_prefix | ||||||
|         self.websocket_routes: List[Route] = [] |  | ||||||
|  |  | ||||||
|     def __repr__(self) -> str: |     def __repr__(self) -> str: | ||||||
|         args = ", ".join( |         args = ", ".join( | ||||||
| @@ -144,12 +139,87 @@ class Blueprint(BaseSanic): | |||||||
|         kwargs["apply"] = False |         kwargs["apply"] = False | ||||||
|         return super().signal(event, *args, **kwargs) |         return super().signal(event, *args, **kwargs) | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._apps: Set[Sanic] = set() | ||||||
|  |         self.exceptions: List[RouteHandler] = [] | ||||||
|  |         self.listeners: Dict[str, List[ListenerType]] = {} | ||||||
|  |         self.middlewares: List[MiddlewareType] = [] | ||||||
|  |         self.routes: List[Route] = [] | ||||||
|  |         self.statics: List[RouteHandler] = [] | ||||||
|  |         self.websocket_routes: List[Route] = [] | ||||||
|  |  | ||||||
|  |     def copy( | ||||||
|  |         self, | ||||||
|  |         name: str, | ||||||
|  |         url_prefix: Optional[Union[str, Default]] = _default, | ||||||
|  |         version: Optional[Union[int, str, float, Default]] = _default, | ||||||
|  |         version_prefix: Union[str, Default] = _default, | ||||||
|  |         strict_slashes: Optional[Union[bool, Default]] = _default, | ||||||
|  |         with_registration: bool = True, | ||||||
|  |         with_ctx: bool = False, | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Copy a blueprint instance with some optional parameters to | ||||||
|  |         override the values of attributes in the old instance. | ||||||
|  |  | ||||||
|  |         :param name: unique name of the blueprint | ||||||
|  |         :param url_prefix: URL to be prefixed before all route URLs | ||||||
|  |         :param version: Blueprint Version | ||||||
|  |         :param version_prefix: the prefix of the version number shown in the | ||||||
|  |             URL. | ||||||
|  |         :param strict_slashes: Enforce the API urls are requested with a | ||||||
|  |             trailing */* | ||||||
|  |         :param with_registration: whether register new blueprint instance with | ||||||
|  |             sanic apps that were registered with the old instance or not. | ||||||
|  |         :param with_ctx: whether ``ctx`` will be copied or not. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         attrs_backup = { | ||||||
|  |             "_apps": self._apps, | ||||||
|  |             "routes": self.routes, | ||||||
|  |             "websocket_routes": self.websocket_routes, | ||||||
|  |             "middlewares": self.middlewares, | ||||||
|  |             "exceptions": self.exceptions, | ||||||
|  |             "listeners": self.listeners, | ||||||
|  |             "statics": self.statics, | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         self.reset() | ||||||
|  |         new_bp = deepcopy(self) | ||||||
|  |         new_bp.name = name | ||||||
|  |  | ||||||
|  |         if not isinstance(url_prefix, Default): | ||||||
|  |             new_bp.url_prefix = url_prefix | ||||||
|  |         if not isinstance(version, Default): | ||||||
|  |             new_bp.version = version | ||||||
|  |         if not isinstance(strict_slashes, Default): | ||||||
|  |             new_bp.strict_slashes = strict_slashes | ||||||
|  |         if not isinstance(version_prefix, Default): | ||||||
|  |             new_bp.version_prefix = version_prefix | ||||||
|  |  | ||||||
|  |         for key, value in attrs_backup.items(): | ||||||
|  |             setattr(self, key, value) | ||||||
|  |  | ||||||
|  |         if with_registration and self._apps: | ||||||
|  |             if new_bp._future_statics: | ||||||
|  |                 raise SanicException( | ||||||
|  |                     "Static routes registered with the old blueprint instance," | ||||||
|  |                     " cannot be registered again." | ||||||
|  |                 ) | ||||||
|  |             for app in self._apps: | ||||||
|  |                 app.blueprint(new_bp) | ||||||
|  |  | ||||||
|  |         if not with_ctx: | ||||||
|  |             new_bp.ctx = SimpleNamespace() | ||||||
|  |  | ||||||
|  |         return new_bp | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def group( |     def group( | ||||||
|         *blueprints, |         *blueprints: Union[Blueprint, BlueprintGroup], | ||||||
|         url_prefix="", |         url_prefix: Optional[str] = None, | ||||||
|         version=None, |         version: Optional[Union[int, str, float]] = None, | ||||||
|         strict_slashes=None, |         strict_slashes: Optional[bool] = None, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
| @@ -196,6 +266,9 @@ class Blueprint(BaseSanic): | |||||||
|         opt_version = options.get("version", None) |         opt_version = options.get("version", None) | ||||||
|         opt_strict_slashes = options.get("strict_slashes", None) |         opt_strict_slashes = options.get("strict_slashes", None) | ||||||
|         opt_version_prefix = options.get("version_prefix", self.version_prefix) |         opt_version_prefix = options.get("version_prefix", self.version_prefix) | ||||||
|  |         error_format = options.get( | ||||||
|  |             "error_format", app.config.FALLBACK_ERROR_FORMAT | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         routes = [] |         routes = [] | ||||||
|         middleware = [] |         middleware = [] | ||||||
| @@ -243,6 +316,7 @@ class Blueprint(BaseSanic): | |||||||
|                 future.unquote, |                 future.unquote, | ||||||
|                 future.static, |                 future.static, | ||||||
|                 version_prefix, |                 version_prefix, | ||||||
|  |                 error_format, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             route = app._apply_route(apply_route) |             route = app._apply_route(apply_route) | ||||||
| @@ -261,19 +335,22 @@ class Blueprint(BaseSanic): | |||||||
|  |  | ||||||
|         route_names = [route.name for route in routes if route] |         route_names = [route.name for route in routes if route] | ||||||
|  |  | ||||||
|         # Middleware |  | ||||||
|         if route_names: |         if route_names: | ||||||
|  |             # Middleware | ||||||
|             for future in self._future_middleware: |             for future in self._future_middleware: | ||||||
|                 middleware.append(app._apply_middleware(future, route_names)) |                 middleware.append(app._apply_middleware(future, route_names)) | ||||||
|  |  | ||||||
|             # Exceptions |             # Exceptions | ||||||
|             for future in self._future_exceptions: |             for future in self._future_exceptions: | ||||||
|             exception_handlers.append(app._apply_exception_handler(future)) |                 exception_handlers.append( | ||||||
|  |                     app._apply_exception_handler(future, route_names) | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|         # Event listeners |         # Event listeners | ||||||
|         for listener in self._future_listeners: |         for listener in self._future_listeners: | ||||||
|             listeners[listener.event].append(app._apply_listener(listener)) |             listeners[listener.event].append(app._apply_listener(listener)) | ||||||
|  |  | ||||||
|  |         # Signals | ||||||
|         for signal in self._future_signals: |         for signal in self._future_signals: | ||||||
|             signal.condition.update({"blueprint": self.name}) |             signal.condition.update({"blueprint": self.name}) | ||||||
|             app._apply_signal(signal) |             app._apply_signal(signal) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ from pathlib import Path | |||||||
| from typing import Any, Dict, Optional, Union | from typing import Any, Dict, Optional, Union | ||||||
| from warnings import warn | from warnings import warn | ||||||
|  |  | ||||||
|  | from sanic.errorpages import check_error_format | ||||||
| from sanic.http import Http | from sanic.http import Http | ||||||
|  |  | ||||||
| from .utils import load_module_from_file_location, str_to_bool | from .utils import load_module_from_file_location, str_to_bool | ||||||
| @@ -20,7 +21,7 @@ BASE_LOGO = """ | |||||||
| DEFAULT_CONFIG = { | DEFAULT_CONFIG = { | ||||||
|     "ACCESS_LOG": True, |     "ACCESS_LOG": True, | ||||||
|     "EVENT_AUTOREGISTER": False, |     "EVENT_AUTOREGISTER": False, | ||||||
|     "FALLBACK_ERROR_FORMAT": "html", |     "FALLBACK_ERROR_FORMAT": "auto", | ||||||
|     "FORWARDED_FOR_HEADER": "X-Forwarded-For", |     "FORWARDED_FOR_HEADER": "X-Forwarded-For", | ||||||
|     "FORWARDED_SECRET": None, |     "FORWARDED_SECRET": None, | ||||||
|     "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0,  # 15 sec |     "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0,  # 15 sec | ||||||
| @@ -35,12 +36,9 @@ DEFAULT_CONFIG = { | |||||||
|     "REQUEST_MAX_SIZE": 100000000,  # 100 megabytes |     "REQUEST_MAX_SIZE": 100000000,  # 100 megabytes | ||||||
|     "REQUEST_TIMEOUT": 60,  # 60 seconds |     "REQUEST_TIMEOUT": 60,  # 60 seconds | ||||||
|     "RESPONSE_TIMEOUT": 60,  # 60 seconds |     "RESPONSE_TIMEOUT": 60,  # 60 seconds | ||||||
|     "WEBSOCKET_MAX_QUEUE": 32, |  | ||||||
|     "WEBSOCKET_MAX_SIZE": 2 ** 20,  # 1 megabyte |     "WEBSOCKET_MAX_SIZE": 2 ** 20,  # 1 megabyte | ||||||
|     "WEBSOCKET_PING_INTERVAL": 20, |     "WEBSOCKET_PING_INTERVAL": 20, | ||||||
|     "WEBSOCKET_PING_TIMEOUT": 20, |     "WEBSOCKET_PING_TIMEOUT": 20, | ||||||
|     "WEBSOCKET_READ_LIMIT": 2 ** 16, |  | ||||||
|     "WEBSOCKET_WRITE_LIMIT": 2 ** 16, |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -62,12 +60,10 @@ class Config(dict): | |||||||
|     REQUEST_MAX_SIZE: int |     REQUEST_MAX_SIZE: int | ||||||
|     REQUEST_TIMEOUT: int |     REQUEST_TIMEOUT: int | ||||||
|     RESPONSE_TIMEOUT: int |     RESPONSE_TIMEOUT: int | ||||||
|     WEBSOCKET_MAX_QUEUE: int |     SERVER_NAME: str | ||||||
|     WEBSOCKET_MAX_SIZE: int |     WEBSOCKET_MAX_SIZE: int | ||||||
|     WEBSOCKET_PING_INTERVAL: int |     WEBSOCKET_PING_INTERVAL: int | ||||||
|     WEBSOCKET_PING_TIMEOUT: int |     WEBSOCKET_PING_TIMEOUT: int | ||||||
|     WEBSOCKET_READ_LIMIT: int |  | ||||||
|     WEBSOCKET_WRITE_LIMIT: int |  | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
| @@ -100,6 +96,7 @@ class Config(dict): | |||||||
|             self.load_environment_vars(SANIC_PREFIX) |             self.load_environment_vars(SANIC_PREFIX) | ||||||
|  |  | ||||||
|         self._configure_header_size() |         self._configure_header_size() | ||||||
|  |         self._check_error_format() | ||||||
|  |  | ||||||
|     def __getattr__(self, attr): |     def __getattr__(self, attr): | ||||||
|         try: |         try: | ||||||
| @@ -115,6 +112,8 @@ class Config(dict): | |||||||
|             "REQUEST_MAX_SIZE", |             "REQUEST_MAX_SIZE", | ||||||
|         ): |         ): | ||||||
|             self._configure_header_size() |             self._configure_header_size() | ||||||
|  |         elif attr == "FALLBACK_ERROR_FORMAT": | ||||||
|  |             self._check_error_format() | ||||||
|  |  | ||||||
|     def _configure_header_size(self): |     def _configure_header_size(self): | ||||||
|         Http.set_header_max_size( |         Http.set_header_max_size( | ||||||
| @@ -123,6 +122,9 @@ class Config(dict): | |||||||
|             self.REQUEST_MAX_SIZE, |             self.REQUEST_MAX_SIZE, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def _check_error_format(self): | ||||||
|  |         check_error_format(self.FALLBACK_ERROR_FORMAT) | ||||||
|  |  | ||||||
|     def load_environment_vars(self, prefix=SANIC_PREFIX): |     def load_environment_vars(self, prefix=SANIC_PREFIX): | ||||||
|         """ |         """ | ||||||
|         Looks for prefixed environment variables and applies |         Looks for prefixed environment variables and applies | ||||||
|   | |||||||
| @@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = { | |||||||
| } | } | ||||||
|  |  | ||||||
| RENDERERS_BY_CONTENT_TYPE = { | RENDERERS_BY_CONTENT_TYPE = { | ||||||
|     "multipart/form-data": HTMLRenderer, |  | ||||||
|     "application/json": JSONRenderer, |  | ||||||
|     "text/plain": TextRenderer, |     "text/plain": TextRenderer, | ||||||
|  |     "application/json": JSONRenderer, | ||||||
|  |     "multipart/form-data": HTMLRenderer, | ||||||
|  |     "text/html": HTMLRenderer, | ||||||
| } | } | ||||||
|  | CONTENT_TYPE_BY_RENDERERS = { | ||||||
|  |     v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | RESPONSE_MAPPING = { | ||||||
|  |     "empty": "html", | ||||||
|  |     "json": "json", | ||||||
|  |     "text": "text", | ||||||
|  |     "raw": "text", | ||||||
|  |     "html": "html", | ||||||
|  |     "file": "html", | ||||||
|  |     "file_stream": "text", | ||||||
|  |     "stream": "text", | ||||||
|  |     "redirect": "html", | ||||||
|  |     "text/plain": "text", | ||||||
|  |     "text/html": "html", | ||||||
|  |     "application/json": "json", | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_error_format(format): | ||||||
|  |     if format not in RENDERERS_BY_CONFIG and format != "auto": | ||||||
|  |         raise SanicException(f"Unknown format: {format}") | ||||||
|  |  | ||||||
|  |  | ||||||
| def exception_response( | def exception_response( | ||||||
|     request: Request, |     request: Request, | ||||||
|     exception: Exception, |     exception: Exception, | ||||||
|     debug: bool, |     debug: bool, | ||||||
|  |     fallback: str, | ||||||
|  |     base: t.Type[BaseRenderer], | ||||||
|     renderer: t.Type[t.Optional[BaseRenderer]] = None, |     renderer: t.Type[t.Optional[BaseRenderer]] = None, | ||||||
| ) -> HTTPResponse: | ) -> HTTPResponse: | ||||||
|     """ |     """ | ||||||
|     Render a response for the default FALLBACK exception handler. |     Render a response for the default FALLBACK exception handler. | ||||||
|     """ |     """ | ||||||
|  |     content_type = None | ||||||
|  |  | ||||||
|     if not renderer: |     if not renderer: | ||||||
|         renderer = HTMLRenderer |         # Make sure we have something set | ||||||
|  |         renderer = base | ||||||
|  |         render_format = fallback | ||||||
|  |  | ||||||
|         if request: |         if request: | ||||||
|             if request.app.config.FALLBACK_ERROR_FORMAT == "auto": |             # If there is a request, try and get the format | ||||||
|  |             # from the route | ||||||
|  |             if request.route: | ||||||
|                 try: |                 try: | ||||||
|                     renderer = JSONRenderer if request.json else HTMLRenderer |                     render_format = request.route.ctx.error_format | ||||||
|                 except InvalidUsage: |                 except AttributeError: | ||||||
|  |                     ... | ||||||
|  |  | ||||||
|  |             content_type = request.headers.getone("content-type", "").split( | ||||||
|  |                 ";" | ||||||
|  |             )[0] | ||||||
|  |  | ||||||
|  |             acceptable = request.accept | ||||||
|  |  | ||||||
|  |             # If the format is auto still, make a guess | ||||||
|  |             if render_format == "auto": | ||||||
|  |                 # First, if there is an Accept header, check if text/html | ||||||
|  |                 # is the first option | ||||||
|  |                 # According to MDN Web Docs, all major browsers use text/html | ||||||
|  |                 # as the primary value in Accept (with the exception of IE 8, | ||||||
|  |                 # and, well, if you are supporting IE 8, then you have bigger | ||||||
|  |                 # problems to concern yourself with than what default exception | ||||||
|  |                 # renderer is used) | ||||||
|  |                 # Source: | ||||||
|  |                 # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values | ||||||
|  |  | ||||||
|  |                 if acceptable and acceptable[0].match( | ||||||
|  |                     "text/html", | ||||||
|  |                     allow_type_wildcard=False, | ||||||
|  |                     allow_subtype_wildcard=False, | ||||||
|  |                 ): | ||||||
|                     renderer = HTMLRenderer |                     renderer = HTMLRenderer | ||||||
|  |  | ||||||
|                 content_type, *_ = request.headers.getone( |                 # Second, if there is an Accept header, check if | ||||||
|                     "content-type", "" |                 # application/json is an option, or if the content-type | ||||||
|                 ).split(";") |                 # is application/json | ||||||
|                 renderer = RENDERERS_BY_CONTENT_TYPE.get( |                 elif ( | ||||||
|                     content_type, renderer |                     acceptable | ||||||
|  |                     and acceptable.match( | ||||||
|  |                         "application/json", | ||||||
|  |                         allow_type_wildcard=False, | ||||||
|  |                         allow_subtype_wildcard=False, | ||||||
|                     ) |                     ) | ||||||
|  |                     or content_type == "application/json" | ||||||
|  |                 ): | ||||||
|  |                     renderer = JSONRenderer | ||||||
|  |  | ||||||
|  |                 # Third, if there is no Accept header, assume we want text. | ||||||
|  |                 # The likely use case here is a raw socket. | ||||||
|  |                 elif not acceptable: | ||||||
|  |                     renderer = TextRenderer | ||||||
|  |                 else: | ||||||
|  |                     # Fourth, look to see if there was a JSON body | ||||||
|  |                     # When in this situation, the request is probably coming | ||||||
|  |                     # from curl, an API client like Postman or Insomnia, or a | ||||||
|  |                     # package like requests or httpx | ||||||
|  |                     try: | ||||||
|  |                         # Give them the benefit of the doubt if they did: | ||||||
|  |                         # $ curl localhost:8000 -d '{"foo": "bar"}' | ||||||
|  |                         # And provide them with JSONRenderer | ||||||
|  |                         renderer = JSONRenderer if request.json else base | ||||||
|  |                     except InvalidUsage: | ||||||
|  |                         renderer = base | ||||||
|             else: |             else: | ||||||
|                 render_format = request.app.config.FALLBACK_ERROR_FORMAT |  | ||||||
|                 renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) |                 renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) | ||||||
|  |  | ||||||
|  |             # Lastly, if there is an Accept header, make sure | ||||||
|  |             # our choice is okay | ||||||
|  |             if acceptable: | ||||||
|  |                 type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer)  # type: ignore | ||||||
|  |                 if type_ and type_ not in acceptable: | ||||||
|  |                     # If the renderer selected is not in the Accept header | ||||||
|  |                     # look through what is in the Accept header, and select | ||||||
|  |                     # the first option that matches. Otherwise, just drop back | ||||||
|  |                     # to the original default | ||||||
|  |                     for accept in acceptable: | ||||||
|  |                         mtype = f"{accept.type_}/{accept.subtype}" | ||||||
|  |                         maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) | ||||||
|  |                         if maybe: | ||||||
|  |                             renderer = maybe | ||||||
|  |                             break | ||||||
|  |                     else: | ||||||
|  |                         renderer = base | ||||||
|  |  | ||||||
|     renderer = t.cast(t.Type[BaseRenderer], renderer) |     renderer = t.cast(t.Type[BaseRenderer], renderer) | ||||||
|     return renderer(request, exception, debug).render() |     return renderer(request, exception, debug).render() | ||||||
|   | |||||||
| @@ -4,14 +4,18 @@ from sanic.helpers import STATUS_CODES | |||||||
|  |  | ||||||
|  |  | ||||||
| class SanicException(Exception): | class SanicException(Exception): | ||||||
|  |     message: str = "" | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         message: Optional[Union[str, bytes]] = None, |         message: Optional[Union[str, bytes]] = None, | ||||||
|         status_code: Optional[int] = None, |         status_code: Optional[int] = None, | ||||||
|         quiet: Optional[bool] = None, |         quiet: Optional[bool] = None, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  |         if message is None: | ||||||
|         if message is None and status_code is not None: |             if self.message: | ||||||
|  |                 message = self.message | ||||||
|  |             elif status_code is not None: | ||||||
|                 msg: bytes = STATUS_CODES.get(status_code, b"") |                 msg: bytes = STATUS_CODES.get(status_code, b"") | ||||||
|                 message = msg.decode("utf8") |                 message = msg.decode("utf8") | ||||||
|  |  | ||||||
| @@ -31,6 +35,7 @@ class NotFound(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 404 |     status_code = 404 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class InvalidUsage(SanicException): | class InvalidUsage(SanicException): | ||||||
| @@ -39,6 +44,7 @@ class InvalidUsage(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 400 |     status_code = 400 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class MethodNotSupported(SanicException): | class MethodNotSupported(SanicException): | ||||||
| @@ -47,6 +53,7 @@ class MethodNotSupported(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 405 |     status_code = 405 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|     def __init__(self, message, method, allowed_methods): |     def __init__(self, message, method, allowed_methods): | ||||||
|         super().__init__(message) |         super().__init__(message) | ||||||
| @@ -70,6 +77,7 @@ class ServiceUnavailable(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 503 |     status_code = 503 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class URLBuildError(ServerError): | class URLBuildError(ServerError): | ||||||
| @@ -101,6 +109,7 @@ class RequestTimeout(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 408 |     status_code = 408 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class PayloadTooLarge(SanicException): | class PayloadTooLarge(SanicException): | ||||||
| @@ -109,6 +118,7 @@ class PayloadTooLarge(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 413 |     status_code = 413 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class HeaderNotFound(InvalidUsage): | class HeaderNotFound(InvalidUsage): | ||||||
| @@ -116,7 +126,11 @@ class HeaderNotFound(InvalidUsage): | |||||||
|     **Status**: 400 Bad Request |     **Status**: 400 Bad Request | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 400 |  | ||||||
|  | class InvalidHeader(InvalidUsage): | ||||||
|  |     """ | ||||||
|  |     **Status**: 400 Bad Request | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
| class ContentRangeError(SanicException): | class ContentRangeError(SanicException): | ||||||
| @@ -125,6 +139,7 @@ class ContentRangeError(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 416 |     status_code = 416 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|     def __init__(self, message, content_range): |     def __init__(self, message, content_range): | ||||||
|         super().__init__(message) |         super().__init__(message) | ||||||
| @@ -137,6 +152,7 @@ class HeaderExpectationFailed(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 417 |     status_code = 417 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class Forbidden(SanicException): | class Forbidden(SanicException): | ||||||
| @@ -145,6 +161,7 @@ class Forbidden(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 403 |     status_code = 403 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class InvalidRangeType(ContentRangeError): | class InvalidRangeType(ContentRangeError): | ||||||
| @@ -153,6 +170,7 @@ class InvalidRangeType(ContentRangeError): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 416 |     status_code = 416 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|  |  | ||||||
| class PyFileError(Exception): | class PyFileError(Exception): | ||||||
| @@ -196,6 +214,7 @@ class Unauthorized(SanicException): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     status_code = 401 |     status_code = 401 | ||||||
|  |     quiet = True | ||||||
|  |  | ||||||
|     def __init__(self, message, status_code=None, scheme=None, **kwargs): |     def __init__(self, message, status_code=None, scheme=None, **kwargs): | ||||||
|         super().__init__(message, status_code) |         super().__init__(message, status_code) | ||||||
| @@ -218,6 +237,11 @@ class InvalidSignal(SanicException): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebsocketClosed(SanicException): | ||||||
|  |     quiet = True | ||||||
|  |     message = "Client has closed the websocket connection" | ||||||
|  |  | ||||||
|  |  | ||||||
| def abort(status_code: int, message: Optional[Union[str, bytes]] = None): | def abort(status_code: int, message: Optional[Union[str, bytes]] = None): | ||||||
|     """ |     """ | ||||||
|     Raise an exception based on SanicException. Returns the HTTP response |     Raise an exception based on SanicException. Returns the HTTP response | ||||||
|   | |||||||
| @@ -1,12 +1,14 @@ | |||||||
| from traceback import format_exc | from inspect import signature | ||||||
|  | from typing import Dict, List, Optional, Tuple, Type | ||||||
|  |  | ||||||
| from sanic.errorpages import exception_response | from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response | ||||||
| from sanic.exceptions import ( | from sanic.exceptions import ( | ||||||
|     ContentRangeError, |     ContentRangeError, | ||||||
|     HeaderNotFound, |     HeaderNotFound, | ||||||
|     InvalidRangeType, |     InvalidRangeType, | ||||||
| ) | ) | ||||||
| from sanic.log import error_logger | from sanic.log import error_logger | ||||||
|  | from sanic.models.handler_types import RouteHandler | ||||||
| from sanic.response import text | from sanic.response import text | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -23,15 +25,47 @@ class ErrorHandler: | |||||||
|  |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     handlers = None |     # Beginning in v22.3, the base renderer will be TextRenderer | ||||||
|     cached_handlers = None |     def __init__( | ||||||
|  |         self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer | ||||||
|     def __init__(self): |     ): | ||||||
|         self.handlers = [] |         self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] | ||||||
|         self.cached_handlers = {} |         self.cached_handlers: Dict[ | ||||||
|  |             Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] | ||||||
|  |         ] = {} | ||||||
|         self.debug = False |         self.debug = False | ||||||
|  |         self.fallback = fallback | ||||||
|  |         self.base = base | ||||||
|  |  | ||||||
|     def add(self, exception, handler): |     @classmethod | ||||||
|  |     def finalize(cls, error_handler): | ||||||
|  |         if not isinstance(error_handler, cls): | ||||||
|  |             error_logger.warning( | ||||||
|  |                 f"Error handler is non-conforming: {type(error_handler)}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         sig = signature(error_handler.lookup) | ||||||
|  |         if len(sig.parameters) == 1: | ||||||
|  |             error_logger.warning( | ||||||
|  |                 DeprecationWarning( | ||||||
|  |                     "You are using a deprecated error handler. The lookup " | ||||||
|  |                     "method should accept two positional parameters: " | ||||||
|  |                     "(exception, route_name: Optional[str]). " | ||||||
|  |                     "Until you upgrade your ErrorHandler.lookup, Blueprint " | ||||||
|  |                     "specific exceptions will not work properly. Beginning " | ||||||
|  |                     "in v22.3, the legacy style lookup method will not " | ||||||
|  |                     "work at all." | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |             error_handler._lookup = error_handler._legacy_lookup | ||||||
|  |  | ||||||
|  |     def _full_lookup(self, exception, route_name: Optional[str] = None): | ||||||
|  |         return self.lookup(exception, route_name) | ||||||
|  |  | ||||||
|  |     def _legacy_lookup(self, exception, route_name: Optional[str] = None): | ||||||
|  |         return self.lookup(exception) | ||||||
|  |  | ||||||
|  |     def add(self, exception, handler, route_names: Optional[List[str]] = None): | ||||||
|         """ |         """ | ||||||
|         Add a new exception handler to an already existing handler object. |         Add a new exception handler to an already existing handler object. | ||||||
|  |  | ||||||
| @@ -44,11 +78,16 @@ class ErrorHandler: | |||||||
|  |  | ||||||
|         :return: None |         :return: None | ||||||
|         """ |         """ | ||||||
|         # self.handlers to be deprecated and removed in version 21.12 |         # self.handlers is deprecated and will be removed in version 22.3 | ||||||
|         self.handlers.append((exception, handler)) |         self.handlers.append((exception, handler)) | ||||||
|         self.cached_handlers[exception] = handler |  | ||||||
|  |  | ||||||
|     def lookup(self, exception): |         if route_names: | ||||||
|  |             for route in route_names: | ||||||
|  |                 self.cached_handlers[(exception, route)] = handler | ||||||
|  |         else: | ||||||
|  |             self.cached_handlers[(exception, None)] = handler | ||||||
|  |  | ||||||
|  |     def lookup(self, exception, route_name: Optional[str] = None): | ||||||
|         """ |         """ | ||||||
|         Lookup the existing instance of :class:`ErrorHandler` and fetch the |         Lookup the existing instance of :class:`ErrorHandler` and fetch the | ||||||
|         registered handler for a specific type of exception. |         registered handler for a specific type of exception. | ||||||
| @@ -63,20 +102,31 @@ class ErrorHandler: | |||||||
|         :return: Registered function if found ``None`` otherwise |         :return: Registered function if found ``None`` otherwise | ||||||
|         """ |         """ | ||||||
|         exception_class = type(exception) |         exception_class = type(exception) | ||||||
|         if exception_class in self.cached_handlers: |  | ||||||
|             return self.cached_handlers[exception_class] |  | ||||||
|  |  | ||||||
|         for ancestor in type.mro(exception_class): |         for name in (route_name, None): | ||||||
|             if ancestor in self.cached_handlers: |             exception_key = (exception_class, name) | ||||||
|                 handler = self.cached_handlers[ancestor] |             handler = self.cached_handlers.get(exception_key) | ||||||
|                 self.cached_handlers[exception_class] = handler |             if handler: | ||||||
|                 return handler |                 return handler | ||||||
|  |  | ||||||
|  |         for name in (route_name, None): | ||||||
|  |             for ancestor in type.mro(exception_class): | ||||||
|  |                 exception_key = (ancestor, name) | ||||||
|  |                 if exception_key in self.cached_handlers: | ||||||
|  |                     handler = self.cached_handlers[exception_key] | ||||||
|  |                     self.cached_handlers[ | ||||||
|  |                         (exception_class, route_name) | ||||||
|  |                     ] = handler | ||||||
|  |                     return handler | ||||||
|  |  | ||||||
|                 if ancestor is BaseException: |                 if ancestor is BaseException: | ||||||
|                     break |                     break | ||||||
|         self.cached_handlers[exception_class] = None |         self.cached_handlers[(exception_class, route_name)] = None | ||||||
|         handler = None |         handler = None | ||||||
|         return handler |         return handler | ||||||
|  |  | ||||||
|  |     _lookup = _full_lookup | ||||||
|  |  | ||||||
|     def response(self, request, exception): |     def response(self, request, exception): | ||||||
|         """Fetches and executes an exception handler and returns a response |         """Fetches and executes an exception handler and returns a response | ||||||
|         object |         object | ||||||
| @@ -91,7 +141,8 @@ class ErrorHandler: | |||||||
|         :return: Wrap the return value obtained from :func:`default` |         :return: Wrap the return value obtained from :func:`default` | ||||||
|             or registered handler for that type of exception. |             or registered handler for that type of exception. | ||||||
|         """ |         """ | ||||||
|         handler = self.lookup(exception) |         route_name = request.name if request else None | ||||||
|  |         handler = self._lookup(exception, route_name) | ||||||
|         response = None |         response = None | ||||||
|         try: |         try: | ||||||
|             if handler: |             if handler: | ||||||
| @@ -99,7 +150,6 @@ class ErrorHandler: | |||||||
|             if response is None: |             if response is None: | ||||||
|                 response = self.default(request, exception) |                 response = self.default(request, exception) | ||||||
|         except Exception: |         except Exception: | ||||||
|             self.log(format_exc()) |  | ||||||
|             try: |             try: | ||||||
|                 url = repr(request.url) |                 url = repr(request.url) | ||||||
|             except AttributeError: |             except AttributeError: | ||||||
| @@ -115,11 +165,6 @@ class ErrorHandler: | |||||||
|                 return text("An error occurred while handling an error", 500) |                 return text("An error occurred while handling an error", 500) | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|     def log(self, message, level="error"): |  | ||||||
|         """ |  | ||||||
|         Deprecated, do not use. |  | ||||||
|         """ |  | ||||||
|  |  | ||||||
|     def default(self, request, exception): |     def default(self, request, exception): | ||||||
|         """ |         """ | ||||||
|         Provide a default behavior for the objects of :class:`ErrorHandler`. |         Provide a default behavior for the objects of :class:`ErrorHandler`. | ||||||
| @@ -135,6 +180,17 @@ class ErrorHandler: | |||||||
|             :class:`Exception` |             :class:`Exception` | ||||||
|         :return: |         :return: | ||||||
|         """ |         """ | ||||||
|  |         self.log(request, exception) | ||||||
|  |         return exception_response( | ||||||
|  |             request, | ||||||
|  |             exception, | ||||||
|  |             debug=self.debug, | ||||||
|  |             base=self.base, | ||||||
|  |             fallback=self.fallback, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def log(request, exception): | ||||||
|         quiet = getattr(exception, "quiet", False) |         quiet = getattr(exception, "quiet", False) | ||||||
|         if quiet is False: |         if quiet is False: | ||||||
|             try: |             try: | ||||||
| @@ -142,13 +198,10 @@ class ErrorHandler: | |||||||
|             except AttributeError: |             except AttributeError: | ||||||
|                 url = "unknown" |                 url = "unknown" | ||||||
|  |  | ||||||
|             self.log(format_exc()) |  | ||||||
|             error_logger.exception( |             error_logger.exception( | ||||||
|                 "Exception occurred while handling uri: %s", url |                 "Exception occurred while handling uri: %s", url | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         return exception_response(request, exception, self.debug) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ContentRangeHandler: | class ContentRangeHandler: | ||||||
|     """ |     """ | ||||||
|   | |||||||
							
								
								
									
										200
									
								
								sanic/headers.py
									
									
									
									
									
								
							
							
						
						
									
										200
									
								
								sanic/headers.py
									
									
									
									
									
								
							| @@ -1,8 +1,11 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
| import re | import re | ||||||
|  |  | ||||||
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | ||||||
| from urllib.parse import unquote | from urllib.parse import unquote | ||||||
|  |  | ||||||
|  | from sanic.exceptions import InvalidHeader | ||||||
| from sanic.helpers import STATUS_CODES | from sanic.helpers import STATUS_CODES | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -30,6 +33,175 @@ _host_re = re.compile( | |||||||
| # For more information, consult ../tests/test_requests.py | # For more information, consult ../tests/test_requests.py | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def parse_arg_as_accept(f): | ||||||
|  |     def func(self, other, *args, **kwargs): | ||||||
|  |         if not isinstance(other, Accept) and other: | ||||||
|  |             other = Accept.parse(other) | ||||||
|  |         return f(self, other, *args, **kwargs) | ||||||
|  |  | ||||||
|  |     return func | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MediaType(str): | ||||||
|  |     def __new__(cls, value: str): | ||||||
|  |         return str.__new__(cls, value) | ||||||
|  |  | ||||||
|  |     def __init__(self, value: str) -> None: | ||||||
|  |         self.value = value | ||||||
|  |         self.is_wildcard = self.check_if_wildcard(value) | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         if self.is_wildcard: | ||||||
|  |             return True | ||||||
|  |  | ||||||
|  |         if self.match(other): | ||||||
|  |             return True | ||||||
|  |  | ||||||
|  |         other_is_wildcard = ( | ||||||
|  |             other.is_wildcard | ||||||
|  |             if isinstance(other, MediaType) | ||||||
|  |             else self.check_if_wildcard(other) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return other_is_wildcard | ||||||
|  |  | ||||||
|  |     def match(self, other): | ||||||
|  |         other_value = other.value if isinstance(other, MediaType) else other | ||||||
|  |         return self.value == other_value | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def check_if_wildcard(value): | ||||||
|  |         return value == "*" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Accept(str): | ||||||
|  |     def __new__(cls, value: str, *args, **kwargs): | ||||||
|  |         return str.__new__(cls, value) | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         value: str, | ||||||
|  |         type_: MediaType, | ||||||
|  |         subtype: MediaType, | ||||||
|  |         *, | ||||||
|  |         q: str = "1.0", | ||||||
|  |         **kwargs: str, | ||||||
|  |     ): | ||||||
|  |         qvalue = float(q) | ||||||
|  |         if qvalue > 1 or qvalue < 0: | ||||||
|  |             raise InvalidHeader( | ||||||
|  |                 f"Accept header qvalue must be between 0 and 1, not: {qvalue}" | ||||||
|  |             ) | ||||||
|  |         self.value = value | ||||||
|  |         self.type_ = type_ | ||||||
|  |         self.subtype = subtype | ||||||
|  |         self.qvalue = qvalue | ||||||
|  |         self.params = kwargs | ||||||
|  |  | ||||||
|  |     def _compare(self, other, method): | ||||||
|  |         try: | ||||||
|  |             return method(self.qvalue, other.qvalue) | ||||||
|  |         except (AttributeError, TypeError): | ||||||
|  |             return NotImplemented | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def __lt__(self, other: Union[str, Accept]): | ||||||
|  |         return self._compare(other, lambda s, o: s < o) | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def __le__(self, other: Union[str, Accept]): | ||||||
|  |         return self._compare(other, lambda s, o: s <= o) | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def __eq__(self, other: Union[str, Accept]):  # type: ignore | ||||||
|  |         return self._compare(other, lambda s, o: s == o) | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def __ge__(self, other: Union[str, Accept]): | ||||||
|  |         return self._compare(other, lambda s, o: s >= o) | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def __gt__(self, other: Union[str, Accept]): | ||||||
|  |         return self._compare(other, lambda s, o: s > o) | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def __ne__(self, other: Union[str, Accept]):  # type: ignore | ||||||
|  |         return self._compare(other, lambda s, o: s != o) | ||||||
|  |  | ||||||
|  |     @parse_arg_as_accept | ||||||
|  |     def match( | ||||||
|  |         self, | ||||||
|  |         other, | ||||||
|  |         *, | ||||||
|  |         allow_type_wildcard: bool = True, | ||||||
|  |         allow_subtype_wildcard: bool = True, | ||||||
|  |     ) -> bool: | ||||||
|  |         type_match = ( | ||||||
|  |             self.type_ == other.type_ | ||||||
|  |             if allow_type_wildcard | ||||||
|  |             else ( | ||||||
|  |                 self.type_.match(other.type_) | ||||||
|  |                 and not self.type_.is_wildcard | ||||||
|  |                 and not other.type_.is_wildcard | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         subtype_match = ( | ||||||
|  |             self.subtype == other.subtype | ||||||
|  |             if allow_subtype_wildcard | ||||||
|  |             else ( | ||||||
|  |                 self.subtype.match(other.subtype) | ||||||
|  |                 and not self.subtype.is_wildcard | ||||||
|  |                 and not other.subtype.is_wildcard | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return type_match and subtype_match | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def parse(cls, raw: str) -> Accept: | ||||||
|  |         invalid = False | ||||||
|  |         mtype = raw.strip() | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             media, *raw_params = mtype.split(";") | ||||||
|  |             type_, subtype = media.split("/") | ||||||
|  |         except ValueError: | ||||||
|  |             invalid = True | ||||||
|  |  | ||||||
|  |         if invalid or not type_ or not subtype: | ||||||
|  |             raise InvalidHeader(f"Header contains invalid Accept value: {raw}") | ||||||
|  |  | ||||||
|  |         params = dict( | ||||||
|  |             [ | ||||||
|  |                 (key.strip(), value.strip()) | ||||||
|  |                 for key, value in (param.split("=", 1) for param in raw_params) | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return cls(mtype, MediaType(type_), MediaType(subtype), **params) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AcceptContainer(list): | ||||||
|  |     def __contains__(self, o: object) -> bool: | ||||||
|  |         return any(item.match(o) for item in self) | ||||||
|  |  | ||||||
|  |     def match( | ||||||
|  |         self, | ||||||
|  |         o: object, | ||||||
|  |         *, | ||||||
|  |         allow_type_wildcard: bool = True, | ||||||
|  |         allow_subtype_wildcard: bool = True, | ||||||
|  |     ) -> bool: | ||||||
|  |         return any( | ||||||
|  |             item.match( | ||||||
|  |                 o, | ||||||
|  |                 allow_type_wildcard=allow_type_wildcard, | ||||||
|  |                 allow_subtype_wildcard=allow_subtype_wildcard, | ||||||
|  |             ) | ||||||
|  |             for item in self | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_content_header(value: str) -> Tuple[str, Options]: | def parse_content_header(value: str) -> Tuple[str, Options]: | ||||||
|     """Parse content-type and content-disposition header values. |     """Parse content-type and content-disposition header values. | ||||||
|  |  | ||||||
| @@ -194,3 +366,31 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: | |||||||
|         ret += b"%b: %b\r\n" % h |         ret += b"%b: %b\r\n" % h | ||||||
|     ret += b"\r\n" |     ret += b"\r\n" | ||||||
|     return ret |     return ret | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _sort_accept_value(accept: Accept): | ||||||
|  |     return ( | ||||||
|  |         accept.qvalue, | ||||||
|  |         len(accept.params), | ||||||
|  |         accept.subtype != "*", | ||||||
|  |         accept.type_ != "*", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def parse_accept(accept: str) -> AcceptContainer: | ||||||
|  |     """Parse an Accept header and order the acceptable media types in | ||||||
|  |     accorsing to RFC 7231, s. 5.3.2 | ||||||
|  |     https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 | ||||||
|  |     """ | ||||||
|  |     media_types = accept.split(",") | ||||||
|  |     accept_list: List[Accept] = [] | ||||||
|  |  | ||||||
|  |     for mtype in media_types: | ||||||
|  |         if not mtype: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         accept_list.append(Accept.parse(mtype)) | ||||||
|  |  | ||||||
|  |     return AcceptContainer( | ||||||
|  |         sorted(accept_list, key=_sort_accept_value, reverse=True) | ||||||
|  |     ) | ||||||
|   | |||||||
| @@ -155,3 +155,17 @@ def import_string(module_name, package=None): | |||||||
|     if ismodule(obj): |     if ismodule(obj): | ||||||
|         return obj |         return obj | ||||||
|     return obj() |     return obj() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Default: | ||||||
|  |     """ | ||||||
|  |     It is used to replace `None` or `object()` as a sentinel | ||||||
|  |     that represents a default value. Sometimes we want to set | ||||||
|  |     a value to `None` so we cannot use `None` to represent the | ||||||
|  |     default value, and `object()` is hard to be typed. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _default = Default() | ||||||
|   | |||||||
| @@ -21,6 +21,7 @@ from sanic.exceptions import ( | |||||||
| from sanic.headers import format_http1_response | from sanic.headers import format_http1_response | ||||||
| from sanic.helpers import has_message_body | from sanic.helpers import has_message_body | ||||||
| from sanic.log import access_logger, error_logger, logger | from sanic.log import access_logger, error_logger, logger | ||||||
|  | from sanic.touchup import TouchUpMeta | ||||||
|  |  | ||||||
|  |  | ||||||
| class Stage(Enum): | class Stage(Enum): | ||||||
| @@ -45,7 +46,7 @@ class Stage(Enum): | |||||||
| HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" | HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" | ||||||
|  |  | ||||||
|  |  | ||||||
| class Http: | class Http(metaclass=TouchUpMeta): | ||||||
|     """ |     """ | ||||||
|     Internal helper for managing the HTTP request/response cycle |     Internal helper for managing the HTTP request/response cycle | ||||||
|  |  | ||||||
| @@ -67,9 +68,15 @@ class Http: | |||||||
|     HEADER_CEILING = 16_384 |     HEADER_CEILING = 16_384 | ||||||
|     HEADER_MAX_SIZE = 0 |     HEADER_MAX_SIZE = 0 | ||||||
|  |  | ||||||
|  |     __touchup__ = ( | ||||||
|  |         "http1_request_header", | ||||||
|  |         "http1_response_header", | ||||||
|  |         "read", | ||||||
|  |     ) | ||||||
|     __slots__ = [ |     __slots__ = [ | ||||||
|         "_send", |         "_send", | ||||||
|         "_receive_more", |         "_receive_more", | ||||||
|  |         "dispatch", | ||||||
|         "recv_buffer", |         "recv_buffer", | ||||||
|         "protocol", |         "protocol", | ||||||
|         "expecting_continue", |         "expecting_continue", | ||||||
| @@ -95,19 +102,24 @@ class Http: | |||||||
|         self._receive_more = protocol.receive_more |         self._receive_more = protocol.receive_more | ||||||
|         self.recv_buffer = protocol.recv_buffer |         self.recv_buffer = protocol.recv_buffer | ||||||
|         self.protocol = protocol |         self.protocol = protocol | ||||||
|         self.expecting_continue: bool = False |         self.keep_alive = True | ||||||
|         self.stage: Stage = Stage.IDLE |         self.stage: Stage = Stage.IDLE | ||||||
|  |         self.dispatch = self.protocol.app.dispatch | ||||||
|  |         self.init_for_request() | ||||||
|  |  | ||||||
|  |     def init_for_request(self): | ||||||
|  |         """Init/reset all per-request variables.""" | ||||||
|  |         self.exception = None | ||||||
|  |         self.expecting_continue: bool = False | ||||||
|  |         self.head_only = None | ||||||
|         self.request_body = None |         self.request_body = None | ||||||
|         self.request_bytes = None |         self.request_bytes = None | ||||||
|         self.request_bytes_left = None |         self.request_bytes_left = None | ||||||
|         self.request_max_size = protocol.request_max_size |         self.request_max_size = self.protocol.request_max_size | ||||||
|         self.keep_alive = True |  | ||||||
|         self.head_only = None |  | ||||||
|         self.request: Request = None |         self.request: Request = None | ||||||
|         self.response: BaseHTTPResponse = None |         self.response: BaseHTTPResponse = None | ||||||
|         self.exception = None |  | ||||||
|         self.url = None |  | ||||||
|         self.upgrade_websocket = False |         self.upgrade_websocket = False | ||||||
|  |         self.url = None | ||||||
|  |  | ||||||
|     def __bool__(self): |     def __bool__(self): | ||||||
|         """Test if request handling is in progress""" |         """Test if request handling is in progress""" | ||||||
| @@ -136,6 +148,12 @@ class Http: | |||||||
|                     await self.response.send(end_stream=True) |                     await self.response.send(end_stream=True) | ||||||
|             except CancelledError: |             except CancelledError: | ||||||
|                 # Write an appropriate response before exiting |                 # Write an appropriate response before exiting | ||||||
|  |                 if not self.protocol.transport: | ||||||
|  |                     logger.info( | ||||||
|  |                         f"Request: {self.request.method} {self.request.url} " | ||||||
|  |                         "stopped. Transport is closed." | ||||||
|  |                     ) | ||||||
|  |                     return | ||||||
|                 e = self.exception or ServiceUnavailable("Cancelled") |                 e = self.exception or ServiceUnavailable("Cancelled") | ||||||
|                 self.exception = None |                 self.exception = None | ||||||
|                 self.keep_alive = False |                 self.keep_alive = False | ||||||
| @@ -148,7 +166,10 @@ class Http: | |||||||
|             if self.request_body: |             if self.request_body: | ||||||
|                 if self.response and 200 <= self.response.status < 300: |                 if self.response and 200 <= self.response.status < 300: | ||||||
|                     error_logger.error(f"{self.request} body not consumed.") |                     error_logger.error(f"{self.request} body not consumed.") | ||||||
|  |                 # Limit the size because the handler may have set it infinite | ||||||
|  |                 self.request_max_size = min( | ||||||
|  |                     self.request_max_size, self.protocol.request_max_size | ||||||
|  |                 ) | ||||||
|                 try: |                 try: | ||||||
|                     async for _ in self: |                     async for _ in self: | ||||||
|                         pass |                         pass | ||||||
| @@ -160,15 +181,23 @@ class Http: | |||||||
|                     await sleep(0.001) |                     await sleep(0.001) | ||||||
|                     self.keep_alive = False |                     self.keep_alive = False | ||||||
|  |  | ||||||
|  |             # Clean up to free memory and for the next request | ||||||
|  |             if self.request: | ||||||
|  |                 self.request.stream = None | ||||||
|  |                 if self.response: | ||||||
|  |                     self.response.stream = None | ||||||
|  |  | ||||||
|             # Exit and disconnect if no more requests can be taken |             # Exit and disconnect if no more requests can be taken | ||||||
|             if self.stage is not Stage.IDLE or not self.keep_alive: |             if self.stage is not Stage.IDLE or not self.keep_alive: | ||||||
|                 break |                 break | ||||||
|  |  | ||||||
|             # Wait for next request |             self.init_for_request() | ||||||
|  |  | ||||||
|  |             # Wait for the next request | ||||||
|             if not self.recv_buffer: |             if not self.recv_buffer: | ||||||
|                 await self._receive_more() |                 await self._receive_more() | ||||||
|  |  | ||||||
|     async def http1_request_header(self): |     async def http1_request_header(self):  # no cov | ||||||
|         """ |         """ | ||||||
|         Receive and parse request header into self.request. |         Receive and parse request header into self.request. | ||||||
|         """ |         """ | ||||||
| @@ -197,6 +226,12 @@ class Http: | |||||||
|             reqline, *split_headers = raw_headers.split("\r\n") |             reqline, *split_headers = raw_headers.split("\r\n") | ||||||
|             method, self.url, protocol = reqline.split(" ") |             method, self.url, protocol = reqline.split(" ") | ||||||
|  |  | ||||||
|  |             await self.dispatch( | ||||||
|  |                 "http.lifecycle.read_head", | ||||||
|  |                 inline=True, | ||||||
|  |                 context={"head": bytes(head)}, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|             if protocol == "HTTP/1.1": |             if protocol == "HTTP/1.1": | ||||||
|                 self.keep_alive = True |                 self.keep_alive = True | ||||||
|             elif protocol == "HTTP/1.0": |             elif protocol == "HTTP/1.0": | ||||||
| @@ -235,6 +270,11 @@ class Http: | |||||||
|             transport=self.protocol.transport, |             transport=self.protocol.transport, | ||||||
|             app=self.protocol.app, |             app=self.protocol.app, | ||||||
|         ) |         ) | ||||||
|  |         await self.dispatch( | ||||||
|  |             "http.lifecycle.request", | ||||||
|  |             inline=True, | ||||||
|  |             context={"request": request}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # Prepare for request body |         # Prepare for request body | ||||||
|         self.request_bytes_left = self.request_bytes = 0 |         self.request_bytes_left = self.request_bytes = 0 | ||||||
| @@ -265,7 +305,7 @@ class Http: | |||||||
|  |  | ||||||
|     async def http1_response_header( |     async def http1_response_header( | ||||||
|         self, data: bytes, end_stream: bool |         self, data: bytes, end_stream: bool | ||||||
|     ) -> None: |     ) -> None:  # no cov | ||||||
|         res = self.response |         res = self.response | ||||||
|  |  | ||||||
|         # Compatibility with simple response body |         # Compatibility with simple response body | ||||||
| @@ -437,8 +477,8 @@ class Http: | |||||||
|             "request": "nil", |             "request": "nil", | ||||||
|         } |         } | ||||||
|         if req is not None: |         if req is not None: | ||||||
|             if req.ip: |             if req.remote_addr or req.ip: | ||||||
|                 extra["host"] = f"{req.ip}:{req.port}" |                 extra["host"] = f"{req.remote_addr or req.ip}:{req.port}" | ||||||
|             extra["request"] = f"{req.method} {req.url}" |             extra["request"] = f"{req.method} {req.url}" | ||||||
|         access_logger.info("", extra=extra) |         access_logger.info("", extra=extra) | ||||||
|  |  | ||||||
| @@ -454,7 +494,7 @@ class Http: | |||||||
|             if data: |             if data: | ||||||
|                 yield data |                 yield data | ||||||
|  |  | ||||||
|     async def read(self) -> Optional[bytes]: |     async def read(self) -> Optional[bytes]:  # no cov | ||||||
|         """ |         """ | ||||||
|         Read some bytes of request body. |         Read some bytes of request body. | ||||||
|         """ |         """ | ||||||
| @@ -486,8 +526,6 @@ class Http: | |||||||
|                 self.keep_alive = False |                 self.keep_alive = False | ||||||
|                 raise InvalidUsage("Bad chunked encoding") |                 raise InvalidUsage("Bad chunked encoding") | ||||||
|  |  | ||||||
|             del buf[: pos + 2] |  | ||||||
|  |  | ||||||
|             if size <= 0: |             if size <= 0: | ||||||
|                 self.request_body = None |                 self.request_body = None | ||||||
|  |  | ||||||
| @@ -495,8 +533,17 @@ class Http: | |||||||
|                     self.keep_alive = False |                     self.keep_alive = False | ||||||
|                     raise InvalidUsage("Bad chunked encoding") |                     raise InvalidUsage("Bad chunked encoding") | ||||||
|  |  | ||||||
|  |                 # Consume CRLF, chunk size 0 and the two CRLF that follow | ||||||
|  |                 pos += 4 | ||||||
|  |                 # Might need to wait for the final CRLF | ||||||
|  |                 while len(buf) < pos: | ||||||
|  |                     await self._receive_more() | ||||||
|  |                 del buf[:pos] | ||||||
|                 return None |                 return None | ||||||
|  |  | ||||||
|  |             # Remove CRLF, chunk size and the CRLF that follows | ||||||
|  |             del buf[: pos + 2] | ||||||
|  |  | ||||||
|             self.request_bytes_left = size |             self.request_bytes_left = size | ||||||
|             self.request_bytes += size |             self.request_bytes += size | ||||||
|  |  | ||||||
| @@ -521,6 +568,12 @@ class Http: | |||||||
|  |  | ||||||
|         self.request_bytes_left -= size |         self.request_bytes_left -= size | ||||||
|  |  | ||||||
|  |         await self.dispatch( | ||||||
|  |             "http.lifecycle.read_body", | ||||||
|  |             inline=True, | ||||||
|  |             context={"body": data}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         return data |         return data | ||||||
|  |  | ||||||
|     # Response methods |     # Response methods | ||||||
|   | |||||||
| @@ -1,18 +1,19 @@ | |||||||
| from enum import Enum, auto | from enum import Enum, auto | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Any, Callable, Coroutine, List, Optional, Union | from typing import List, Optional, Union | ||||||
|  |  | ||||||
| from sanic.models.futures import FutureListener | from sanic.models.futures import FutureListener | ||||||
|  | from sanic.models.handler_types import ListenerType | ||||||
|  |  | ||||||
|  |  | ||||||
| class ListenerEvent(str, Enum): | class ListenerEvent(str, Enum): | ||||||
|     def _generate_next_value_(name: str, *args) -> str:  # type: ignore |     def _generate_next_value_(name: str, *args) -> str:  # type: ignore | ||||||
|         return name.lower() |         return name.lower() | ||||||
|  |  | ||||||
|     BEFORE_SERVER_START = auto() |     BEFORE_SERVER_START = "server.init.before" | ||||||
|     AFTER_SERVER_START = auto() |     AFTER_SERVER_START = "server.init.after" | ||||||
|     BEFORE_SERVER_STOP = auto() |     BEFORE_SERVER_STOP = "server.shutdown.before" | ||||||
|     AFTER_SERVER_STOP = auto() |     AFTER_SERVER_STOP = "server.shutdown.after" | ||||||
|     MAIN_PROCESS_START = auto() |     MAIN_PROCESS_START = auto() | ||||||
|     MAIN_PROCESS_STOP = auto() |     MAIN_PROCESS_STOP = auto() | ||||||
|  |  | ||||||
| @@ -26,9 +27,7 @@ class ListenerMixin: | |||||||
|  |  | ||||||
|     def listener( |     def listener( | ||||||
|         self, |         self, | ||||||
|         listener_or_event: Union[ |         listener_or_event: Union[ListenerType, str], | ||||||
|             Callable[..., Coroutine[Any, Any, None]], str |  | ||||||
|         ], |  | ||||||
|         event_or_none: Optional[str] = None, |         event_or_none: Optional[str] = None, | ||||||
|         apply: bool = True, |         apply: bool = True, | ||||||
|     ): |     ): | ||||||
| @@ -63,20 +62,20 @@ class ListenerMixin: | |||||||
|         else: |         else: | ||||||
|             return partial(register_listener, event=listener_or_event) |             return partial(register_listener, event=listener_or_event) | ||||||
|  |  | ||||||
|     def main_process_start(self, listener): |     def main_process_start(self, listener: ListenerType) -> ListenerType: | ||||||
|         return self.listener(listener, "main_process_start") |         return self.listener(listener, "main_process_start") | ||||||
|  |  | ||||||
|     def main_process_stop(self, listener): |     def main_process_stop(self, listener: ListenerType) -> ListenerType: | ||||||
|         return self.listener(listener, "main_process_stop") |         return self.listener(listener, "main_process_stop") | ||||||
|  |  | ||||||
|     def before_server_start(self, listener): |     def before_server_start(self, listener: ListenerType) -> ListenerType: | ||||||
|         return self.listener(listener, "before_server_start") |         return self.listener(listener, "before_server_start") | ||||||
|  |  | ||||||
|     def after_server_start(self, listener): |     def after_server_start(self, listener: ListenerType) -> ListenerType: | ||||||
|         return self.listener(listener, "after_server_start") |         return self.listener(listener, "after_server_start") | ||||||
|  |  | ||||||
|     def before_server_stop(self, listener): |     def before_server_stop(self, listener: ListenerType) -> ListenerType: | ||||||
|         return self.listener(listener, "before_server_stop") |         return self.listener(listener, "before_server_stop") | ||||||
|  |  | ||||||
|     def after_server_stop(self, listener): |     def after_server_stop(self, listener: ListenerType) -> ListenerType: | ||||||
|         return self.listener(listener, "after_server_stop") |         return self.listener(listener, "after_server_stop") | ||||||
|   | |||||||
| @@ -1,17 +1,20 @@ | |||||||
|  | from ast import NodeVisitor, Return, parse | ||||||
| from functools import partial, wraps | from functools import partial, wraps | ||||||
| from inspect import signature | from inspect import getsource, signature | ||||||
| from mimetypes import guess_type | from mimetypes import guess_type | ||||||
| from os import path | from os import path | ||||||
| from pathlib import PurePath | from pathlib import PurePath | ||||||
| from re import sub | from re import sub | ||||||
|  | from textwrap import dedent | ||||||
| from time import gmtime, strftime | from time import gmtime, strftime | ||||||
| from typing import Iterable, List, Optional, Set, Union | from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union | ||||||
| from urllib.parse import unquote | from urllib.parse import unquote | ||||||
|  |  | ||||||
| from sanic_routing.route import Route  # type: ignore | from sanic_routing.route import Route  # type: ignore | ||||||
|  |  | ||||||
| from sanic.compat import stat_async | from sanic.compat import stat_async | ||||||
| from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS | from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS | ||||||
|  | from sanic.errorpages import RESPONSE_MAPPING | ||||||
| from sanic.exceptions import ( | from sanic.exceptions import ( | ||||||
|     ContentRangeError, |     ContentRangeError, | ||||||
|     FileNotFound, |     FileNotFound, | ||||||
| @@ -21,10 +24,16 @@ from sanic.exceptions import ( | |||||||
| from sanic.handlers import ContentRangeHandler | from sanic.handlers import ContentRangeHandler | ||||||
| from sanic.log import error_logger | from sanic.log import error_logger | ||||||
| from sanic.models.futures import FutureRoute, FutureStatic | from sanic.models.futures import FutureRoute, FutureStatic | ||||||
|  | from sanic.models.handler_types import RouteHandler | ||||||
| from sanic.response import HTTPResponse, file, file_stream | from sanic.response import HTTPResponse, file, file_stream | ||||||
| from sanic.views import CompositionView | from sanic.views import CompositionView | ||||||
|  |  | ||||||
|  |  | ||||||
|  | RouteWrapper = Callable[ | ||||||
|  |     [RouteHandler], Union[RouteHandler, Tuple[Route, RouteHandler]] | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
| class RouteMixin: | class RouteMixin: | ||||||
|     name: str |     name: str | ||||||
|  |  | ||||||
| @@ -55,7 +64,8 @@ class RouteMixin: | |||||||
|         unquote: bool = False, |         unquote: bool = False, | ||||||
|         static: bool = False, |         static: bool = False, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Decorate a function to be registered as a route |         Decorate a function to be registered as a route | ||||||
|  |  | ||||||
| @@ -97,6 +107,7 @@ class RouteMixin: | |||||||
|             nonlocal websocket |             nonlocal websocket | ||||||
|             nonlocal static |             nonlocal static | ||||||
|             nonlocal version_prefix |             nonlocal version_prefix | ||||||
|  |             nonlocal error_format | ||||||
|  |  | ||||||
|             if isinstance(handler, tuple): |             if isinstance(handler, tuple): | ||||||
|                 # if a handler fn is already wrapped in a route, the handler |                 # if a handler fn is already wrapped in a route, the handler | ||||||
| @@ -115,10 +126,16 @@ class RouteMixin: | |||||||
|                         "Expected either string or Iterable of host strings, " |                         "Expected either string or Iterable of host strings, " | ||||||
|                         "not %s" % host |                         "not %s" % host | ||||||
|                     ) |                     ) | ||||||
|  |             if isinstance(subprotocols, list): | ||||||
|             if isinstance(subprotocols, (list, tuple, set)): |                 # Ordered subprotocols, maintain order | ||||||
|  |                 subprotocols = tuple(subprotocols) | ||||||
|  |             elif isinstance(subprotocols, set): | ||||||
|  |                 # subprotocol is unordered, keep it unordered | ||||||
|                 subprotocols = frozenset(subprotocols) |                 subprotocols = frozenset(subprotocols) | ||||||
|  |  | ||||||
|  |             if not error_format or error_format == "auto": | ||||||
|  |                 error_format = self._determine_error_format(handler) | ||||||
|  |  | ||||||
|             route = FutureRoute( |             route = FutureRoute( | ||||||
|                 handler, |                 handler, | ||||||
|                 uri, |                 uri, | ||||||
| @@ -134,6 +151,7 @@ class RouteMixin: | |||||||
|                 unquote, |                 unquote, | ||||||
|                 static, |                 static, | ||||||
|                 version_prefix, |                 version_prefix, | ||||||
|  |                 error_format, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self._future_routes.add(route) |             self._future_routes.add(route) | ||||||
| @@ -168,7 +186,7 @@ class RouteMixin: | |||||||
|  |  | ||||||
|     def add_route( |     def add_route( | ||||||
|         self, |         self, | ||||||
|         handler, |         handler: RouteHandler, | ||||||
|         uri: str, |         uri: str, | ||||||
|         methods: Iterable[str] = frozenset({"GET"}), |         methods: Iterable[str] = frozenset({"GET"}), | ||||||
|         host: Optional[str] = None, |         host: Optional[str] = None, | ||||||
| @@ -177,7 +195,8 @@ class RouteMixin: | |||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         stream: bool = False, |         stream: bool = False, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteHandler: | ||||||
|         """A helper method to register class instance or |         """A helper method to register class instance or | ||||||
|         functions as a handler to the application url |         functions as a handler to the application url | ||||||
|         routes. |         routes. | ||||||
| @@ -200,7 +219,8 @@ class RouteMixin: | |||||||
|             methods = set() |             methods = set() | ||||||
|  |  | ||||||
|             for method in HTTP_METHODS: |             for method in HTTP_METHODS: | ||||||
|                 _handler = getattr(handler.view_class, method.lower(), None) |                 view_class = getattr(handler, "view_class") | ||||||
|  |                 _handler = getattr(view_class, method.lower(), None) | ||||||
|                 if _handler: |                 if _handler: | ||||||
|                     methods.add(method) |                     methods.add(method) | ||||||
|                     if hasattr(_handler, "is_stream"): |                     if hasattr(_handler, "is_stream"): | ||||||
| @@ -226,6 +246,7 @@ class RouteMixin: | |||||||
|             version=version, |             version=version, | ||||||
|             name=name, |             name=name, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         )(handler) |         )(handler) | ||||||
|         return handler |         return handler | ||||||
|  |  | ||||||
| @@ -239,7 +260,8 @@ class RouteMixin: | |||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         ignore_body: bool = True, |         ignore_body: bool = True, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **GET** *HTTP* method |         Add an API URL under the **GET** *HTTP* method | ||||||
|  |  | ||||||
| @@ -262,6 +284,7 @@ class RouteMixin: | |||||||
|             name=name, |             name=name, | ||||||
|             ignore_body=ignore_body, |             ignore_body=ignore_body, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def post( |     def post( | ||||||
| @@ -273,7 +296,8 @@ class RouteMixin: | |||||||
|         version: Optional[int] = None, |         version: Optional[int] = None, | ||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **POST** *HTTP* method |         Add an API URL under the **POST** *HTTP* method | ||||||
|  |  | ||||||
| @@ -296,6 +320,7 @@ class RouteMixin: | |||||||
|             version=version, |             version=version, | ||||||
|             name=name, |             name=name, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def put( |     def put( | ||||||
| @@ -307,7 +332,8 @@ class RouteMixin: | |||||||
|         version: Optional[int] = None, |         version: Optional[int] = None, | ||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **PUT** *HTTP* method |         Add an API URL under the **PUT** *HTTP* method | ||||||
|  |  | ||||||
| @@ -330,6 +356,7 @@ class RouteMixin: | |||||||
|             version=version, |             version=version, | ||||||
|             name=name, |             name=name, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def head( |     def head( | ||||||
| @@ -341,7 +368,8 @@ class RouteMixin: | |||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         ignore_body: bool = True, |         ignore_body: bool = True, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **HEAD** *HTTP* method |         Add an API URL under the **HEAD** *HTTP* method | ||||||
|  |  | ||||||
| @@ -372,6 +400,7 @@ class RouteMixin: | |||||||
|             name=name, |             name=name, | ||||||
|             ignore_body=ignore_body, |             ignore_body=ignore_body, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def options( |     def options( | ||||||
| @@ -383,7 +412,8 @@ class RouteMixin: | |||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         ignore_body: bool = True, |         ignore_body: bool = True, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **OPTIONS** *HTTP* method |         Add an API URL under the **OPTIONS** *HTTP* method | ||||||
|  |  | ||||||
| @@ -414,6 +444,7 @@ class RouteMixin: | |||||||
|             name=name, |             name=name, | ||||||
|             ignore_body=ignore_body, |             ignore_body=ignore_body, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def patch( |     def patch( | ||||||
| @@ -425,7 +456,8 @@ class RouteMixin: | |||||||
|         version: Optional[int] = None, |         version: Optional[int] = None, | ||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **PATCH** *HTTP* method |         Add an API URL under the **PATCH** *HTTP* method | ||||||
|  |  | ||||||
| @@ -458,6 +490,7 @@ class RouteMixin: | |||||||
|             version=version, |             version=version, | ||||||
|             name=name, |             name=name, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def delete( |     def delete( | ||||||
| @@ -469,7 +502,8 @@ class RouteMixin: | |||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         ignore_body: bool = True, |         ignore_body: bool = True, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|     ): |         error_format: Optional[str] = None, | ||||||
|  |     ) -> RouteWrapper: | ||||||
|         """ |         """ | ||||||
|         Add an API URL under the **DELETE** *HTTP* method |         Add an API URL under the **DELETE** *HTTP* method | ||||||
|  |  | ||||||
| @@ -492,6 +526,7 @@ class RouteMixin: | |||||||
|             name=name, |             name=name, | ||||||
|             ignore_body=ignore_body, |             ignore_body=ignore_body, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def websocket( |     def websocket( | ||||||
| @@ -504,6 +539,7 @@ class RouteMixin: | |||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         apply: bool = True, |         apply: bool = True, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|  |         error_format: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Decorate a function to be registered as a websocket route |         Decorate a function to be registered as a websocket route | ||||||
| @@ -530,6 +566,7 @@ class RouteMixin: | |||||||
|             subprotocols=subprotocols, |             subprotocols=subprotocols, | ||||||
|             websocket=True, |             websocket=True, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def add_websocket_route( |     def add_websocket_route( | ||||||
| @@ -542,6 +579,7 @@ class RouteMixin: | |||||||
|         version: Optional[int] = None, |         version: Optional[int] = None, | ||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|  |         error_format: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         A helper method to register a function as a websocket route. |         A helper method to register a function as a websocket route. | ||||||
| @@ -570,6 +608,7 @@ class RouteMixin: | |||||||
|             version=version, |             version=version, | ||||||
|             name=name, |             name=name, | ||||||
|             version_prefix=version_prefix, |             version_prefix=version_prefix, | ||||||
|  |             error_format=error_format, | ||||||
|         )(handler) |         )(handler) | ||||||
|  |  | ||||||
|     def static( |     def static( | ||||||
| @@ -585,6 +624,7 @@ class RouteMixin: | |||||||
|         strict_slashes=None, |         strict_slashes=None, | ||||||
|         content_type=None, |         content_type=None, | ||||||
|         apply=True, |         apply=True, | ||||||
|  |         resource_type=None, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Register a root to serve files from. The input can either be a |         Register a root to serve files from. The input can either be a | ||||||
| @@ -634,6 +674,7 @@ class RouteMixin: | |||||||
|             host, |             host, | ||||||
|             strict_slashes, |             strict_slashes, | ||||||
|             content_type, |             content_type, | ||||||
|  |             resource_type, | ||||||
|         ) |         ) | ||||||
|         self._future_statics.add(static) |         self._future_statics.add(static) | ||||||
|  |  | ||||||
| @@ -777,10 +818,11 @@ class RouteMixin: | |||||||
|             ) |             ) | ||||||
|         except Exception: |         except Exception: | ||||||
|             error_logger.exception( |             error_logger.exception( | ||||||
|                 f"Exception in static request handler:\ |                 f"Exception in static request handler: " | ||||||
|  path={file_or_directory}, " |                 f"path={file_or_directory}, " | ||||||
|                 f"relative_url={__file_uri__}" |                 f"relative_url={__file_uri__}" | ||||||
|             ) |             ) | ||||||
|  |             raise | ||||||
|  |  | ||||||
|     def _register_static( |     def _register_static( | ||||||
|         self, |         self, | ||||||
| @@ -828,8 +870,27 @@ class RouteMixin: | |||||||
|         name = static.name |         name = static.name | ||||||
|         # If we're not trying to match a file directly, |         # If we're not trying to match a file directly, | ||||||
|         # serve from the folder |         # serve from the folder | ||||||
|  |         if not static.resource_type: | ||||||
|             if not path.isfile(file_or_directory): |             if not path.isfile(file_or_directory): | ||||||
|                 uri += "/<__file_uri__:path>" |                 uri += "/<__file_uri__:path>" | ||||||
|  |         elif static.resource_type == "dir": | ||||||
|  |             if path.isfile(file_or_directory): | ||||||
|  |                 raise TypeError( | ||||||
|  |                     "Resource type improperly identified as directory. " | ||||||
|  |                     f"'{file_or_directory}'" | ||||||
|  |                 ) | ||||||
|  |             uri += "/<__file_uri__:path>" | ||||||
|  |         elif static.resource_type == "file" and not path.isfile( | ||||||
|  |             file_or_directory | ||||||
|  |         ): | ||||||
|  |             raise TypeError( | ||||||
|  |                 "Resource type improperly identified as file. " | ||||||
|  |                 f"'{file_or_directory}'" | ||||||
|  |             ) | ||||||
|  |         elif static.resource_type != "file": | ||||||
|  |             raise ValueError( | ||||||
|  |                 "The resource_type should be set to 'file' or 'dir'" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         # special prefix for static files |         # special prefix for static files | ||||||
|         # if not static.name.startswith("_static_"): |         # if not static.name.startswith("_static_"): | ||||||
| @@ -846,7 +907,7 @@ class RouteMixin: | |||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         route, _ = self.route( |         route, _ = self.route(  # type: ignore | ||||||
|             uri=uri, |             uri=uri, | ||||||
|             methods=["GET", "HEAD"], |             methods=["GET", "HEAD"], | ||||||
|             name=name, |             name=name, | ||||||
| @@ -856,3 +917,43 @@ class RouteMixin: | |||||||
|         )(_handler) |         )(_handler) | ||||||
|  |  | ||||||
|         return route |         return route | ||||||
|  |  | ||||||
|  |     def _determine_error_format(self, handler) -> str: | ||||||
|  |         if not isinstance(handler, CompositionView): | ||||||
|  |             try: | ||||||
|  |                 src = dedent(getsource(handler)) | ||||||
|  |                 tree = parse(src) | ||||||
|  |                 http_response_types = self._get_response_types(tree) | ||||||
|  |  | ||||||
|  |                 if len(http_response_types) == 1: | ||||||
|  |                     return next(iter(http_response_types)) | ||||||
|  |             except (OSError, TypeError): | ||||||
|  |                 ... | ||||||
|  |  | ||||||
|  |         return "auto" | ||||||
|  |  | ||||||
|  |     def _get_response_types(self, node): | ||||||
|  |         types = set() | ||||||
|  |  | ||||||
|  |         class HttpResponseVisitor(NodeVisitor): | ||||||
|  |             def visit_Return(self, node: Return) -> Any: | ||||||
|  |                 nonlocal types | ||||||
|  |  | ||||||
|  |                 try: | ||||||
|  |                     checks = [node.value.func.id]  # type: ignore | ||||||
|  |                     if node.value.keywords:  # type: ignore | ||||||
|  |                         checks += [ | ||||||
|  |                             k.value | ||||||
|  |                             for k in node.value.keywords  # type: ignore | ||||||
|  |                             if k.arg == "content_type" | ||||||
|  |                         ] | ||||||
|  |  | ||||||
|  |                     for check in checks: | ||||||
|  |                         if check in RESPONSE_MAPPING: | ||||||
|  |                             types.add(RESPONSE_MAPPING[check]) | ||||||
|  |                 except AttributeError: | ||||||
|  |                     ... | ||||||
|  |  | ||||||
|  |         HttpResponseVisitor().visit(node) | ||||||
|  |  | ||||||
|  |         return types | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ class SignalMixin: | |||||||
|         *, |         *, | ||||||
|         apply: bool = True, |         apply: bool = True, | ||||||
|         condition: Dict[str, Any] = None, |         condition: Dict[str, Any] = None, | ||||||
|     ) -> Callable[[SignalHandler], FutureSignal]: |     ) -> Callable[[SignalHandler], SignalHandler]: | ||||||
|         """ |         """ | ||||||
|         For creating a signal handler, used similar to a route handler: |         For creating a signal handler, used similar to a route handler: | ||||||
|  |  | ||||||
| @@ -54,7 +54,7 @@ class SignalMixin: | |||||||
|             if apply: |             if apply: | ||||||
|                 self._apply_signal(future_signal) |                 self._apply_signal(future_signal) | ||||||
|  |  | ||||||
|             return future_signal |             return handler | ||||||
|  |  | ||||||
|         return decorator |         return decorator | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ import asyncio | |||||||
| from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union | from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union | ||||||
|  |  | ||||||
| from sanic.exceptions import InvalidUsage | from sanic.exceptions import InvalidUsage | ||||||
| from sanic.websocket import WebSocketConnection | from sanic.server.websockets.connection import WebSocketConnection | ||||||
|  |  | ||||||
|  |  | ||||||
| ASGIScope = MutableMapping[str, Any] | ASGIScope = MutableMapping[str, Any] | ||||||
|   | |||||||
| @@ -24,6 +24,7 @@ class FutureRoute(NamedTuple): | |||||||
|     unquote: bool |     unquote: bool | ||||||
|     static: bool |     static: bool | ||||||
|     version_prefix: str |     version_prefix: str | ||||||
|  |     error_format: Optional[str] | ||||||
|  |  | ||||||
|  |  | ||||||
| class FutureListener(NamedTuple): | class FutureListener(NamedTuple): | ||||||
| @@ -52,6 +53,7 @@ class FutureStatic(NamedTuple): | |||||||
|     host: Optional[str] |     host: Optional[str] | ||||||
|     strict_slashes: Optional[bool] |     strict_slashes: Optional[bool] | ||||||
|     content_type: Optional[bool] |     content_type: Optional[bool] | ||||||
|  |     resource_type: Optional[str] | ||||||
|  |  | ||||||
|  |  | ||||||
| class FutureSignal(NamedTuple): | class FutureSignal(NamedTuple): | ||||||
|   | |||||||
| @@ -21,5 +21,5 @@ MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType] | |||||||
| ListenerType = Callable[ | ListenerType = Callable[ | ||||||
|     [Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]] |     [Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]] | ||||||
| ] | ] | ||||||
| RouteHandler = Callable[..., Coroutine[Any, Any, HTTPResponse]] | RouteHandler = Callable[..., Coroutine[Any, Any, Optional[HTTPResponse]]] | ||||||
| SignalHandler = Callable[..., Coroutine[Any, Any, None]] | SignalHandler = Callable[..., Coroutine[Any, Any, None]] | ||||||
|   | |||||||
							
								
								
									
										52
									
								
								sanic/models/server_types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								sanic/models/server_types.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | from types import SimpleNamespace | ||||||
|  |  | ||||||
|  | from sanic.models.protocol_types import TransportProtocol | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Signal: | ||||||
|  |     stopped = False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConnInfo: | ||||||
|  |     """ | ||||||
|  |     Local and remote addresses and SSL status info. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     __slots__ = ( | ||||||
|  |         "client_port", | ||||||
|  |         "client", | ||||||
|  |         "client_ip", | ||||||
|  |         "ctx", | ||||||
|  |         "peername", | ||||||
|  |         "server_port", | ||||||
|  |         "server", | ||||||
|  |         "sockname", | ||||||
|  |         "ssl", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     def __init__(self, transport: TransportProtocol, unix=None): | ||||||
|  |         self.ctx = SimpleNamespace() | ||||||
|  |         self.peername = None | ||||||
|  |         self.server = self.client = "" | ||||||
|  |         self.server_port = self.client_port = 0 | ||||||
|  |         self.client_ip = "" | ||||||
|  |         self.sockname = addr = transport.get_extra_info("sockname") | ||||||
|  |         self.ssl: bool = bool(transport.get_extra_info("sslcontext")) | ||||||
|  |  | ||||||
|  |         if isinstance(addr, str):  # UNIX socket | ||||||
|  |             self.server = unix or addr | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) | ||||||
|  |         if isinstance(addr, tuple): | ||||||
|  |             self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" | ||||||
|  |             self.server_port = addr[1] | ||||||
|  |             # self.server gets non-standard port appended | ||||||
|  |             if addr[1] != (443 if self.ssl else 80): | ||||||
|  |                 self.server = f"{self.server}:{addr[1]}" | ||||||
|  |         self.peername = addr = transport.get_extra_info("peername") | ||||||
|  |  | ||||||
|  |         if isinstance(addr, tuple): | ||||||
|  |             self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" | ||||||
|  |             self.client_ip = addr[0] | ||||||
|  |             self.client_port = addr[1] | ||||||
| @@ -34,7 +34,9 @@ from sanic.compat import CancelledErrors, Header | |||||||
| from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE | from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE | ||||||
| from sanic.exceptions import InvalidUsage | from sanic.exceptions import InvalidUsage | ||||||
| from sanic.headers import ( | from sanic.headers import ( | ||||||
|  |     AcceptContainer, | ||||||
|     Options, |     Options, | ||||||
|  |     parse_accept, | ||||||
|     parse_content_header, |     parse_content_header, | ||||||
|     parse_forwarded, |     parse_forwarded, | ||||||
|     parse_host, |     parse_host, | ||||||
| @@ -94,6 +96,7 @@ class Request: | |||||||
|         "head", |         "head", | ||||||
|         "headers", |         "headers", | ||||||
|         "method", |         "method", | ||||||
|  |         "parsed_accept", | ||||||
|         "parsed_args", |         "parsed_args", | ||||||
|         "parsed_not_grouped_args", |         "parsed_not_grouped_args", | ||||||
|         "parsed_files", |         "parsed_files", | ||||||
| @@ -136,6 +139,7 @@ class Request: | |||||||
|         self.conn_info: Optional[ConnInfo] = None |         self.conn_info: Optional[ConnInfo] = None | ||||||
|         self.ctx = SimpleNamespace() |         self.ctx = SimpleNamespace() | ||||||
|         self.parsed_forwarded: Optional[Options] = None |         self.parsed_forwarded: Optional[Options] = None | ||||||
|  |         self.parsed_accept: Optional[AcceptContainer] = None | ||||||
|         self.parsed_json = None |         self.parsed_json = None | ||||||
|         self.parsed_form = None |         self.parsed_form = None | ||||||
|         self.parsed_files = None |         self.parsed_files = None | ||||||
| @@ -296,6 +300,13 @@ class Request: | |||||||
|  |  | ||||||
|         return self.parsed_json |         return self.parsed_json | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def accept(self) -> AcceptContainer: | ||||||
|  |         if self.parsed_accept is None: | ||||||
|  |             accept_header = self.headers.getone("accept", "") | ||||||
|  |             self.parsed_accept = parse_accept(accept_header) | ||||||
|  |         return self.parsed_accept | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def token(self): |     def token(self): | ||||||
|         """Attempt to return the auth header token. |         """Attempt to return the auth header token. | ||||||
| @@ -497,6 +508,10 @@ class Request: | |||||||
|         """ |         """ | ||||||
|         return self._match_info |         return self._match_info | ||||||
|  |  | ||||||
|  |     @match_info.setter | ||||||
|  |     def match_info(self, value): | ||||||
|  |         self._match_info = value | ||||||
|  |  | ||||||
|     # Transport properties (obtained from local interface only) |     # Transport properties (obtained from local interface only) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|   | |||||||
| @@ -1,5 +1,9 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
|  | from inspect import signature | ||||||
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | ||||||
|  | from uuid import UUID | ||||||
|  |  | ||||||
| from sanic_routing import BaseRouter  # type: ignore | from sanic_routing import BaseRouter  # type: ignore | ||||||
| from sanic_routing.exceptions import NoMethod  # type: ignore | from sanic_routing.exceptions import NoMethod  # type: ignore | ||||||
| @@ -9,6 +13,7 @@ from sanic_routing.exceptions import ( | |||||||
| from sanic_routing.route import Route  # type: ignore | from sanic_routing.route import Route  # type: ignore | ||||||
|  |  | ||||||
| from sanic.constants import HTTP_METHODS | from sanic.constants import HTTP_METHODS | ||||||
|  | from sanic.errorpages import check_error_format | ||||||
| from sanic.exceptions import MethodNotSupported, NotFound, SanicException | from sanic.exceptions import MethodNotSupported, NotFound, SanicException | ||||||
| from sanic.models.handler_types import RouteHandler | from sanic.models.handler_types import RouteHandler | ||||||
|  |  | ||||||
| @@ -74,6 +79,7 @@ class Router(BaseRouter): | |||||||
|         unquote: bool = False, |         unquote: bool = False, | ||||||
|         static: bool = False, |         static: bool = False, | ||||||
|         version_prefix: str = "/v", |         version_prefix: str = "/v", | ||||||
|  |         error_format: Optional[str] = None, | ||||||
|     ) -> Union[Route, List[Route]]: |     ) -> Union[Route, List[Route]]: | ||||||
|         """ |         """ | ||||||
|         Add a handler to the router |         Add a handler to the router | ||||||
| @@ -106,6 +112,8 @@ class Router(BaseRouter): | |||||||
|             version = str(version).strip("/").lstrip("v") |             version = str(version).strip("/").lstrip("v") | ||||||
|             uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")]) |             uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")]) | ||||||
|  |  | ||||||
|  |         uri = self._normalize(uri, handler) | ||||||
|  |  | ||||||
|         params = dict( |         params = dict( | ||||||
|             path=uri, |             path=uri, | ||||||
|             handler=handler, |             handler=handler, | ||||||
| @@ -131,6 +139,11 @@ class Router(BaseRouter): | |||||||
|             route.ctx.stream = stream |             route.ctx.stream = stream | ||||||
|             route.ctx.hosts = hosts |             route.ctx.hosts = hosts | ||||||
|             route.ctx.static = static |             route.ctx.static = static | ||||||
|  |             route.ctx.error_format = ( | ||||||
|  |                 error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             check_error_format(route.ctx.error_format) | ||||||
|  |  | ||||||
|             routes.append(route) |             routes.append(route) | ||||||
|  |  | ||||||
| @@ -187,3 +200,24 @@ class Router(BaseRouter): | |||||||
|                 raise SanicException( |                 raise SanicException( | ||||||
|                     f"Invalid route: {route}. Parameter names cannot use '__'." |                     f"Invalid route: {route}. Parameter names cannot use '__'." | ||||||
|                 ) |                 ) | ||||||
|  |  | ||||||
|  |     def _normalize(self, uri: str, handler: RouteHandler) -> str: | ||||||
|  |         if "<" not in uri: | ||||||
|  |             return uri | ||||||
|  |  | ||||||
|  |         sig = signature(handler) | ||||||
|  |         mapping = { | ||||||
|  |             param.name: param.annotation.__name__.lower() | ||||||
|  |             for param in sig.parameters.values() | ||||||
|  |             if param.annotation in (str, int, float, UUID) | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         reconstruction = [] | ||||||
|  |         for part in uri.split("/"): | ||||||
|  |             if part.startswith("<") and ":" not in part: | ||||||
|  |                 name = part[1:-1] | ||||||
|  |                 annotation = mapping.get(name) | ||||||
|  |                 if annotation: | ||||||
|  |                     part = f"<{name}:{annotation}>" | ||||||
|  |             reconstruction.append(part) | ||||||
|  |         return "/".join(reconstruction) | ||||||
|   | |||||||
							
								
								
									
										793
									
								
								sanic/server.py
									
									
									
									
									
								
							
							
						
						
									
										793
									
								
								sanic/server.py
									
									
									
									
									
								
							| @@ -1,793 +0,0 @@ | |||||||
| from __future__ import annotations |  | ||||||
|  |  | ||||||
| from ssl import SSLContext |  | ||||||
| from types import SimpleNamespace |  | ||||||
| from typing import ( |  | ||||||
|     TYPE_CHECKING, |  | ||||||
|     Any, |  | ||||||
|     Callable, |  | ||||||
|     Dict, |  | ||||||
|     Iterable, |  | ||||||
|     Optional, |  | ||||||
|     Type, |  | ||||||
|     Union, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| from sanic.models.handler_types import ListenerType |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: |  | ||||||
|     from sanic.app import Sanic |  | ||||||
|  |  | ||||||
| import asyncio |  | ||||||
| import multiprocessing |  | ||||||
| import os |  | ||||||
| import secrets |  | ||||||
| import socket |  | ||||||
| import stat |  | ||||||
|  |  | ||||||
| from asyncio import CancelledError |  | ||||||
| from asyncio.transports import Transport |  | ||||||
| from functools import partial |  | ||||||
| from inspect import isawaitable |  | ||||||
| from ipaddress import ip_address |  | ||||||
| from signal import SIG_IGN, SIGINT, SIGTERM, Signals |  | ||||||
| from signal import signal as signal_func |  | ||||||
| from time import monotonic as current_time |  | ||||||
|  |  | ||||||
| from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows |  | ||||||
| from sanic.config import Config |  | ||||||
| from sanic.exceptions import RequestTimeout, ServiceUnavailable |  | ||||||
| from sanic.http import Http, Stage |  | ||||||
| from sanic.log import error_logger, logger |  | ||||||
| from sanic.models.protocol_types import TransportProtocol |  | ||||||
| from sanic.request import Request |  | ||||||
|  |  | ||||||
|  |  | ||||||
| try: |  | ||||||
|     import uvloop  # type: ignore |  | ||||||
|  |  | ||||||
|     if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): |  | ||||||
|         asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) |  | ||||||
| except ImportError: |  | ||||||
|     pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Signal: |  | ||||||
|     stopped = False |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ConnInfo: |  | ||||||
|     """ |  | ||||||
|     Local and remote addresses and SSL status info. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     __slots__ = ( |  | ||||||
|         "client_port", |  | ||||||
|         "client", |  | ||||||
|         "client_ip", |  | ||||||
|         "ctx", |  | ||||||
|         "peername", |  | ||||||
|         "server_port", |  | ||||||
|         "server", |  | ||||||
|         "sockname", |  | ||||||
|         "ssl", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     def __init__(self, transport: TransportProtocol, unix=None): |  | ||||||
|         self.ctx = SimpleNamespace() |  | ||||||
|         self.peername = None |  | ||||||
|         self.server = self.client = "" |  | ||||||
|         self.server_port = self.client_port = 0 |  | ||||||
|         self.client_ip = "" |  | ||||||
|         self.sockname = addr = transport.get_extra_info("sockname") |  | ||||||
|         self.ssl: bool = bool(transport.get_extra_info("sslcontext")) |  | ||||||
|  |  | ||||||
|         if isinstance(addr, str):  # UNIX socket |  | ||||||
|             self.server = unix or addr |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) |  | ||||||
|         if isinstance(addr, tuple): |  | ||||||
|             self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" |  | ||||||
|             self.server_port = addr[1] |  | ||||||
|             # self.server gets non-standard port appended |  | ||||||
|             if addr[1] != (443 if self.ssl else 80): |  | ||||||
|                 self.server = f"{self.server}:{addr[1]}" |  | ||||||
|         self.peername = addr = transport.get_extra_info("peername") |  | ||||||
|  |  | ||||||
|         if isinstance(addr, tuple): |  | ||||||
|             self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" |  | ||||||
|             self.client_ip = addr[0] |  | ||||||
|             self.client_port = addr[1] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class HttpProtocol(asyncio.Protocol): |  | ||||||
|     """ |  | ||||||
|     This class provides a basic HTTP implementation of the sanic framework. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     __slots__ = ( |  | ||||||
|         # app |  | ||||||
|         "app", |  | ||||||
|         # event loop, connection |  | ||||||
|         "loop", |  | ||||||
|         "transport", |  | ||||||
|         "connections", |  | ||||||
|         "signal", |  | ||||||
|         "conn_info", |  | ||||||
|         "ctx", |  | ||||||
|         # request params |  | ||||||
|         "request", |  | ||||||
|         # request config |  | ||||||
|         "request_handler", |  | ||||||
|         "request_timeout", |  | ||||||
|         "response_timeout", |  | ||||||
|         "keep_alive_timeout", |  | ||||||
|         "request_max_size", |  | ||||||
|         "request_class", |  | ||||||
|         "error_handler", |  | ||||||
|         # enable or disable access log purpose |  | ||||||
|         "access_log", |  | ||||||
|         # connection management |  | ||||||
|         "state", |  | ||||||
|         "url", |  | ||||||
|         "_handler_task", |  | ||||||
|         "_can_write", |  | ||||||
|         "_data_received", |  | ||||||
|         "_time", |  | ||||||
|         "_task", |  | ||||||
|         "_http", |  | ||||||
|         "_exception", |  | ||||||
|         "recv_buffer", |  | ||||||
|         "_unix", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         *, |  | ||||||
|         loop, |  | ||||||
|         app: Sanic, |  | ||||||
|         signal=None, |  | ||||||
|         connections=None, |  | ||||||
|         state=None, |  | ||||||
|         unix=None, |  | ||||||
|         **kwargs, |  | ||||||
|     ): |  | ||||||
|         asyncio.set_event_loop(loop) |  | ||||||
|         self.loop = loop |  | ||||||
|         self.app: Sanic = app |  | ||||||
|         self.url = None |  | ||||||
|         self.transport: Optional[Transport] = None |  | ||||||
|         self.conn_info: Optional[ConnInfo] = None |  | ||||||
|         self.request: Optional[Request] = None |  | ||||||
|         self.signal = signal or Signal() |  | ||||||
|         self.access_log = self.app.config.ACCESS_LOG |  | ||||||
|         self.connections = connections if connections is not None else set() |  | ||||||
|         self.request_handler = self.app.handle_request |  | ||||||
|         self.error_handler = self.app.error_handler |  | ||||||
|         self.request_timeout = self.app.config.REQUEST_TIMEOUT |  | ||||||
|         self.response_timeout = self.app.config.RESPONSE_TIMEOUT |  | ||||||
|         self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT |  | ||||||
|         self.request_max_size = self.app.config.REQUEST_MAX_SIZE |  | ||||||
|         self.request_class = self.app.request_class or Request |  | ||||||
|         self.state = state if state else {} |  | ||||||
|         if "requests_count" not in self.state: |  | ||||||
|             self.state["requests_count"] = 0 |  | ||||||
|         self._data_received = asyncio.Event() |  | ||||||
|         self._can_write = asyncio.Event() |  | ||||||
|         self._can_write.set() |  | ||||||
|         self._exception = None |  | ||||||
|         self._unix = unix |  | ||||||
|  |  | ||||||
|     def _setup_connection(self): |  | ||||||
|         self._http = Http(self) |  | ||||||
|         self._time = current_time() |  | ||||||
|         self.check_timeouts() |  | ||||||
|  |  | ||||||
|     async def connection_task(self): |  | ||||||
|         """ |  | ||||||
|         Run a HTTP connection. |  | ||||||
|  |  | ||||||
|         Timeouts and some additional error handling occur here, while most of |  | ||||||
|         everything else happens in class Http or in code called from there. |  | ||||||
|         """ |  | ||||||
|         try: |  | ||||||
|             self._setup_connection() |  | ||||||
|             await self._http.http1() |  | ||||||
|         except CancelledError: |  | ||||||
|             pass |  | ||||||
|         except Exception: |  | ||||||
|             error_logger.exception("protocol.connection_task uncaught") |  | ||||||
|         finally: |  | ||||||
|             if self.app.debug and self._http: |  | ||||||
|                 ip = self.transport.get_extra_info("peername") |  | ||||||
|                 error_logger.error( |  | ||||||
|                     "Connection lost before response written" |  | ||||||
|                     f" @ {ip} {self._http.request}" |  | ||||||
|                 ) |  | ||||||
|             self._http = None |  | ||||||
|             self._task = None |  | ||||||
|             try: |  | ||||||
|                 self.close() |  | ||||||
|             except BaseException: |  | ||||||
|                 error_logger.exception("Closing failed") |  | ||||||
|  |  | ||||||
|     async def receive_more(self): |  | ||||||
|         """ |  | ||||||
|         Wait until more data is received into the Server protocol's buffer |  | ||||||
|         """ |  | ||||||
|         self.transport.resume_reading() |  | ||||||
|         self._data_received.clear() |  | ||||||
|         await self._data_received.wait() |  | ||||||
|  |  | ||||||
|     def check_timeouts(self): |  | ||||||
|         """ |  | ||||||
|         Runs itself periodically to enforce any expired timeouts. |  | ||||||
|         """ |  | ||||||
|         try: |  | ||||||
|             if not self._task: |  | ||||||
|                 return |  | ||||||
|             duration = current_time() - self._time |  | ||||||
|             stage = self._http.stage |  | ||||||
|             if stage is Stage.IDLE and duration > self.keep_alive_timeout: |  | ||||||
|                 logger.debug("KeepAlive Timeout. Closing connection.") |  | ||||||
|             elif stage is Stage.REQUEST and duration > self.request_timeout: |  | ||||||
|                 logger.debug("Request Timeout. Closing connection.") |  | ||||||
|                 self._http.exception = RequestTimeout("Request Timeout") |  | ||||||
|             elif stage is Stage.HANDLER and self._http.upgrade_websocket: |  | ||||||
|                 logger.debug("Handling websocket. Timeouts disabled.") |  | ||||||
|                 return |  | ||||||
|             elif ( |  | ||||||
|                 stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) |  | ||||||
|                 and duration > self.response_timeout |  | ||||||
|             ): |  | ||||||
|                 logger.debug("Response Timeout. Closing connection.") |  | ||||||
|                 self._http.exception = ServiceUnavailable("Response Timeout") |  | ||||||
|             else: |  | ||||||
|                 interval = ( |  | ||||||
|                     min( |  | ||||||
|                         self.keep_alive_timeout, |  | ||||||
|                         self.request_timeout, |  | ||||||
|                         self.response_timeout, |  | ||||||
|                     ) |  | ||||||
|                     / 2 |  | ||||||
|                 ) |  | ||||||
|                 self.loop.call_later(max(0.1, interval), self.check_timeouts) |  | ||||||
|                 return |  | ||||||
|             self._task.cancel() |  | ||||||
|         except Exception: |  | ||||||
|             error_logger.exception("protocol.check_timeouts") |  | ||||||
|  |  | ||||||
|     async def send(self, data): |  | ||||||
|         """ |  | ||||||
|         Writes data with backpressure control. |  | ||||||
|         """ |  | ||||||
|         await self._can_write.wait() |  | ||||||
|         if self.transport.is_closing(): |  | ||||||
|             raise CancelledError |  | ||||||
|         self.transport.write(data) |  | ||||||
|         self._time = current_time() |  | ||||||
|  |  | ||||||
|     def close_if_idle(self) -> bool: |  | ||||||
|         """ |  | ||||||
|         Close the connection if a request is not being sent or received |  | ||||||
|  |  | ||||||
|         :return: boolean - True if closed, false if staying open |  | ||||||
|         """ |  | ||||||
|         if self._http is None or self._http.stage is Stage.IDLE: |  | ||||||
|             self.close() |  | ||||||
|             return True |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     def close(self): |  | ||||||
|         """ |  | ||||||
|         Force close the connection. |  | ||||||
|         """ |  | ||||||
|         # Cause a call to connection_lost where further cleanup occurs |  | ||||||
|         if self.transport: |  | ||||||
|             self.transport.close() |  | ||||||
|             self.transport = None |  | ||||||
|  |  | ||||||
|     # -------------------------------------------- # |  | ||||||
|     # Only asyncio.Protocol callbacks below this |  | ||||||
|     # -------------------------------------------- # |  | ||||||
|  |  | ||||||
|     def connection_made(self, transport): |  | ||||||
|         try: |  | ||||||
|             # TODO: Benchmark to find suitable write buffer limits |  | ||||||
|             transport.set_write_buffer_limits(low=16384, high=65536) |  | ||||||
|             self.connections.add(self) |  | ||||||
|             self.transport = transport |  | ||||||
|             self._task = self.loop.create_task(self.connection_task()) |  | ||||||
|             self.recv_buffer = bytearray() |  | ||||||
|             self.conn_info = ConnInfo(self.transport, unix=self._unix) |  | ||||||
|         except Exception: |  | ||||||
|             error_logger.exception("protocol.connect_made") |  | ||||||
|  |  | ||||||
|     def connection_lost(self, exc): |  | ||||||
|         try: |  | ||||||
|             self.connections.discard(self) |  | ||||||
|             self.resume_writing() |  | ||||||
|             if self._task: |  | ||||||
|                 self._task.cancel() |  | ||||||
|         except Exception: |  | ||||||
|             error_logger.exception("protocol.connection_lost") |  | ||||||
|  |  | ||||||
|     def pause_writing(self): |  | ||||||
|         self._can_write.clear() |  | ||||||
|  |  | ||||||
|     def resume_writing(self): |  | ||||||
|         self._can_write.set() |  | ||||||
|  |  | ||||||
|     def data_received(self, data: bytes): |  | ||||||
|         try: |  | ||||||
|             self._time = current_time() |  | ||||||
|             if not data: |  | ||||||
|                 return self.close() |  | ||||||
|             self.recv_buffer += data |  | ||||||
|  |  | ||||||
|             if ( |  | ||||||
|                 len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE |  | ||||||
|                 and self.transport |  | ||||||
|             ): |  | ||||||
|                 self.transport.pause_reading() |  | ||||||
|  |  | ||||||
|             if self._data_received: |  | ||||||
|                 self._data_received.set() |  | ||||||
|         except Exception: |  | ||||||
|             error_logger.exception("protocol.data_received") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): |  | ||||||
|     """ |  | ||||||
|     Trigger event callbacks (functions or async) |  | ||||||
|  |  | ||||||
|     :param events: one or more sync or async functions to execute |  | ||||||
|     :param loop: event loop |  | ||||||
|     """ |  | ||||||
|     if events: |  | ||||||
|         for event in events: |  | ||||||
|             result = event(loop) |  | ||||||
|             if isawaitable(result): |  | ||||||
|                 loop.run_until_complete(result) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AsyncioServer: |  | ||||||
|     """ |  | ||||||
|     Wraps an asyncio server with functionality that might be useful to |  | ||||||
|     a user who needs to manage the server lifecycle manually. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     __slots__ = ( |  | ||||||
|         "loop", |  | ||||||
|         "serve_coro", |  | ||||||
|         "_after_start", |  | ||||||
|         "_before_stop", |  | ||||||
|         "_after_stop", |  | ||||||
|         "server", |  | ||||||
|         "connections", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         loop, |  | ||||||
|         serve_coro, |  | ||||||
|         connections, |  | ||||||
|         after_start: Optional[Iterable[ListenerType]], |  | ||||||
|         before_stop: Optional[Iterable[ListenerType]], |  | ||||||
|         after_stop: Optional[Iterable[ListenerType]], |  | ||||||
|     ): |  | ||||||
|         # Note, Sanic already called "before_server_start" events |  | ||||||
|         # before this helper was even created. So we don't need it here. |  | ||||||
|         self.loop = loop |  | ||||||
|         self.serve_coro = serve_coro |  | ||||||
|         self._after_start = after_start |  | ||||||
|         self._before_stop = before_stop |  | ||||||
|         self._after_stop = after_stop |  | ||||||
|         self.server = None |  | ||||||
|         self.connections = connections |  | ||||||
|  |  | ||||||
|     def after_start(self): |  | ||||||
|         """ |  | ||||||
|         Trigger "after_server_start" events |  | ||||||
|         """ |  | ||||||
|         trigger_events(self._after_start, self.loop) |  | ||||||
|  |  | ||||||
|     def before_stop(self): |  | ||||||
|         """ |  | ||||||
|         Trigger "before_server_stop" events |  | ||||||
|         """ |  | ||||||
|         trigger_events(self._before_stop, self.loop) |  | ||||||
|  |  | ||||||
|     def after_stop(self): |  | ||||||
|         """ |  | ||||||
|         Trigger "after_server_stop" events |  | ||||||
|         """ |  | ||||||
|         trigger_events(self._after_stop, self.loop) |  | ||||||
|  |  | ||||||
|     def is_serving(self) -> bool: |  | ||||||
|         if self.server: |  | ||||||
|             return self.server.is_serving() |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     def wait_closed(self): |  | ||||||
|         if self.server: |  | ||||||
|             return self.server.wait_closed() |  | ||||||
|  |  | ||||||
|     def close(self): |  | ||||||
|         if self.server: |  | ||||||
|             self.server.close() |  | ||||||
|             coro = self.wait_closed() |  | ||||||
|             task = asyncio.ensure_future(coro, loop=self.loop) |  | ||||||
|             return task |  | ||||||
|  |  | ||||||
|     def start_serving(self): |  | ||||||
|         if self.server: |  | ||||||
|             try: |  | ||||||
|                 return self.server.start_serving() |  | ||||||
|             except AttributeError: |  | ||||||
|                 raise NotImplementedError( |  | ||||||
|                     "server.start_serving not available in this version " |  | ||||||
|                     "of asyncio or uvloop." |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|     def serve_forever(self): |  | ||||||
|         if self.server: |  | ||||||
|             try: |  | ||||||
|                 return self.server.serve_forever() |  | ||||||
|             except AttributeError: |  | ||||||
|                 raise NotImplementedError( |  | ||||||
|                     "server.serve_forever not available in this version " |  | ||||||
|                     "of asyncio or uvloop." |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|     def __await__(self): |  | ||||||
|         """ |  | ||||||
|         Starts the asyncio server, returns AsyncServerCoro |  | ||||||
|         """ |  | ||||||
|         task = asyncio.ensure_future(self.serve_coro) |  | ||||||
|         while not task.done(): |  | ||||||
|             yield |  | ||||||
|         self.server = task.result() |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def serve( |  | ||||||
|     host, |  | ||||||
|     port, |  | ||||||
|     app, |  | ||||||
|     before_start: Optional[Iterable[ListenerType]] = None, |  | ||||||
|     after_start: Optional[Iterable[ListenerType]] = None, |  | ||||||
|     before_stop: Optional[Iterable[ListenerType]] = None, |  | ||||||
|     after_stop: Optional[Iterable[ListenerType]] = None, |  | ||||||
|     ssl: Optional[SSLContext] = None, |  | ||||||
|     sock: Optional[socket.socket] = None, |  | ||||||
|     unix: Optional[str] = None, |  | ||||||
|     reuse_port: bool = False, |  | ||||||
|     loop=None, |  | ||||||
|     protocol: Type[asyncio.Protocol] = HttpProtocol, |  | ||||||
|     backlog: int = 100, |  | ||||||
|     register_sys_signals: bool = True, |  | ||||||
|     run_multiple: bool = False, |  | ||||||
|     run_async: bool = False, |  | ||||||
|     connections=None, |  | ||||||
|     signal=Signal(), |  | ||||||
|     state=None, |  | ||||||
|     asyncio_server_kwargs=None, |  | ||||||
| ): |  | ||||||
|     """Start asynchronous HTTP Server on an individual process. |  | ||||||
|  |  | ||||||
|     :param host: Address to host on |  | ||||||
|     :param port: Port to host on |  | ||||||
|     :param before_start: function to be executed before the server starts |  | ||||||
|                          listening. Takes arguments `app` instance and `loop` |  | ||||||
|     :param after_start: function to be executed after the server starts |  | ||||||
|                         listening. Takes  arguments `app` instance and `loop` |  | ||||||
|     :param before_stop: function to be executed when a stop signal is |  | ||||||
|                         received before it is respected. Takes arguments |  | ||||||
|                         `app` instance and `loop` |  | ||||||
|     :param after_stop: function to be executed when a stop signal is |  | ||||||
|                        received after it is respected. Takes arguments |  | ||||||
|                        `app` instance and `loop` |  | ||||||
|     :param ssl: SSLContext |  | ||||||
|     :param sock: Socket for the server to accept connections from |  | ||||||
|     :param unix: Unix socket to listen on instead of TCP port |  | ||||||
|     :param reuse_port: `True` for multiple workers |  | ||||||
|     :param loop: asyncio compatible event loop |  | ||||||
|     :param run_async: bool: Do not create a new event loop for the server, |  | ||||||
|                       and return an AsyncServer object rather than running it |  | ||||||
|     :param asyncio_server_kwargs: key-value args for asyncio/uvloop |  | ||||||
|                                   create_server method |  | ||||||
|     :return: Nothing |  | ||||||
|     """ |  | ||||||
|     if not run_async and not loop: |  | ||||||
|         # create new event_loop after fork |  | ||||||
|         loop = asyncio.new_event_loop() |  | ||||||
|         asyncio.set_event_loop(loop) |  | ||||||
|  |  | ||||||
|     if app.debug: |  | ||||||
|         loop.set_debug(app.debug) |  | ||||||
|  |  | ||||||
|     app.asgi = False |  | ||||||
|  |  | ||||||
|     connections = connections if connections is not None else set() |  | ||||||
|     protocol_kwargs = _build_protocol_kwargs(protocol, app.config) |  | ||||||
|     server = partial( |  | ||||||
|         protocol, |  | ||||||
|         loop=loop, |  | ||||||
|         connections=connections, |  | ||||||
|         signal=signal, |  | ||||||
|         app=app, |  | ||||||
|         state=state, |  | ||||||
|         unix=unix, |  | ||||||
|         **protocol_kwargs, |  | ||||||
|     ) |  | ||||||
|     asyncio_server_kwargs = ( |  | ||||||
|         asyncio_server_kwargs if asyncio_server_kwargs else {} |  | ||||||
|     ) |  | ||||||
|     # UNIX sockets are always bound by us (to preserve semantics between modes) |  | ||||||
|     if unix: |  | ||||||
|         sock = bind_unix_socket(unix, backlog=backlog) |  | ||||||
|     server_coroutine = loop.create_server( |  | ||||||
|         server, |  | ||||||
|         None if sock else host, |  | ||||||
|         None if sock else port, |  | ||||||
|         ssl=ssl, |  | ||||||
|         reuse_port=reuse_port, |  | ||||||
|         sock=sock, |  | ||||||
|         backlog=backlog, |  | ||||||
|         **asyncio_server_kwargs, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     if run_async: |  | ||||||
|         return AsyncioServer( |  | ||||||
|             loop=loop, |  | ||||||
|             serve_coro=server_coroutine, |  | ||||||
|             connections=connections, |  | ||||||
|             after_start=after_start, |  | ||||||
|             before_stop=before_stop, |  | ||||||
|             after_stop=after_stop, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     trigger_events(before_start, loop) |  | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         http_server = loop.run_until_complete(server_coroutine) |  | ||||||
|     except BaseException: |  | ||||||
|         error_logger.exception("Unable to start server") |  | ||||||
|         return |  | ||||||
|  |  | ||||||
|     trigger_events(after_start, loop) |  | ||||||
|  |  | ||||||
|     # Ignore SIGINT when run_multiple |  | ||||||
|     if run_multiple: |  | ||||||
|         signal_func(SIGINT, SIG_IGN) |  | ||||||
|  |  | ||||||
|     # Register signals for graceful termination |  | ||||||
|     if register_sys_signals: |  | ||||||
|         if OS_IS_WINDOWS: |  | ||||||
|             ctrlc_workaround_for_windows(app) |  | ||||||
|         else: |  | ||||||
|             for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: |  | ||||||
|                 loop.add_signal_handler(_signal, app.stop) |  | ||||||
|     pid = os.getpid() |  | ||||||
|     try: |  | ||||||
|         logger.info("Starting worker [%s]", pid) |  | ||||||
|         loop.run_forever() |  | ||||||
|     finally: |  | ||||||
|         logger.info("Stopping worker [%s]", pid) |  | ||||||
|  |  | ||||||
|         # Run the on_stop function if provided |  | ||||||
|         trigger_events(before_stop, loop) |  | ||||||
|  |  | ||||||
|         # Wait for event loop to finish and all connections to drain |  | ||||||
|         http_server.close() |  | ||||||
|         loop.run_until_complete(http_server.wait_closed()) |  | ||||||
|  |  | ||||||
|         # Complete all tasks on the loop |  | ||||||
|         signal.stopped = True |  | ||||||
|         for connection in connections: |  | ||||||
|             connection.close_if_idle() |  | ||||||
|  |  | ||||||
|         # Gracefully shutdown timeout. |  | ||||||
|         # We should provide graceful_shutdown_timeout, |  | ||||||
|         # instead of letting connection hangs forever. |  | ||||||
|         # Let's roughly calcucate time. |  | ||||||
|         graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT |  | ||||||
|         start_shutdown: float = 0 |  | ||||||
|         while connections and (start_shutdown < graceful): |  | ||||||
|             loop.run_until_complete(asyncio.sleep(0.1)) |  | ||||||
|             start_shutdown = start_shutdown + 0.1 |  | ||||||
|  |  | ||||||
|         # Force close non-idle connection after waiting for |  | ||||||
|         # graceful_shutdown_timeout |  | ||||||
|         coros = [] |  | ||||||
|         for conn in connections: |  | ||||||
|             if hasattr(conn, "websocket") and conn.websocket: |  | ||||||
|                 coros.append(conn.websocket.close_connection()) |  | ||||||
|             else: |  | ||||||
|                 conn.close() |  | ||||||
|  |  | ||||||
|         _shutdown = asyncio.gather(*coros) |  | ||||||
|         loop.run_until_complete(_shutdown) |  | ||||||
|  |  | ||||||
|         trigger_events(after_stop, loop) |  | ||||||
|  |  | ||||||
|         remove_unix_socket(unix) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _build_protocol_kwargs( |  | ||||||
|     protocol: Type[asyncio.Protocol], config: Config |  | ||||||
| ) -> Dict[str, Union[int, float]]: |  | ||||||
|     if hasattr(protocol, "websocket_handshake"): |  | ||||||
|         return { |  | ||||||
|             "websocket_max_size": config.WEBSOCKET_MAX_SIZE, |  | ||||||
|             "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, |  | ||||||
|             "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, |  | ||||||
|             "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, |  | ||||||
|             "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, |  | ||||||
|             "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, |  | ||||||
|         } |  | ||||||
|     return {} |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: |  | ||||||
|     """Create TCP server socket. |  | ||||||
|     :param host: IPv4, IPv6 or hostname may be specified |  | ||||||
|     :param port: TCP port number |  | ||||||
|     :param backlog: Maximum number of connections to queue |  | ||||||
|     :return: socket.socket object |  | ||||||
|     """ |  | ||||||
|     try:  # IP address: family must be specified for IPv6 at least |  | ||||||
|         ip = ip_address(host) |  | ||||||
|         host = str(ip) |  | ||||||
|         sock = socket.socket( |  | ||||||
|             socket.AF_INET6 if ip.version == 6 else socket.AF_INET |  | ||||||
|         ) |  | ||||||
|     except ValueError:  # Hostname, may become AF_INET or AF_INET6 |  | ||||||
|         sock = socket.socket() |  | ||||||
|     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |  | ||||||
|     sock.bind((host, port)) |  | ||||||
|     sock.listen(backlog) |  | ||||||
|     return sock |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: |  | ||||||
|     """Create unix socket. |  | ||||||
|     :param path: filesystem path |  | ||||||
|     :param backlog: Maximum number of connections to queue |  | ||||||
|     :return: socket.socket object |  | ||||||
|     """ |  | ||||||
|     """Open or atomically replace existing socket with zero downtime.""" |  | ||||||
|     # Sanitise and pre-verify socket path |  | ||||||
|     path = os.path.abspath(path) |  | ||||||
|     folder = os.path.dirname(path) |  | ||||||
|     if not os.path.isdir(folder): |  | ||||||
|         raise FileNotFoundError(f"Socket folder does not exist: {folder}") |  | ||||||
|     try: |  | ||||||
|         if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): |  | ||||||
|             raise FileExistsError(f"Existing file is not a socket: {path}") |  | ||||||
|     except FileNotFoundError: |  | ||||||
|         pass |  | ||||||
|     # Create new socket with a random temporary name |  | ||||||
|     tmp_path = f"{path}.{secrets.token_urlsafe()}" |  | ||||||
|     sock = socket.socket(socket.AF_UNIX) |  | ||||||
|     try: |  | ||||||
|         # Critical section begins (filename races) |  | ||||||
|         sock.bind(tmp_path) |  | ||||||
|         try: |  | ||||||
|             os.chmod(tmp_path, mode) |  | ||||||
|             # Start listening before rename to avoid connection failures |  | ||||||
|             sock.listen(backlog) |  | ||||||
|             os.rename(tmp_path, path) |  | ||||||
|         except:  # noqa: E722 |  | ||||||
|             try: |  | ||||||
|                 os.unlink(tmp_path) |  | ||||||
|             finally: |  | ||||||
|                 raise |  | ||||||
|     except:  # noqa: E722 |  | ||||||
|         try: |  | ||||||
|             sock.close() |  | ||||||
|         finally: |  | ||||||
|             raise |  | ||||||
|     return sock |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def remove_unix_socket(path: Optional[str]) -> None: |  | ||||||
|     """Remove dead unix socket during server exit.""" |  | ||||||
|     if not path: |  | ||||||
|         return |  | ||||||
|     try: |  | ||||||
|         if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): |  | ||||||
|             # Is it actually dead (doesn't belong to a new server instance)? |  | ||||||
|             with socket.socket(socket.AF_UNIX) as testsock: |  | ||||||
|                 try: |  | ||||||
|                     testsock.connect(path) |  | ||||||
|                 except ConnectionRefusedError: |  | ||||||
|                     os.unlink(path) |  | ||||||
|     except FileNotFoundError: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def serve_single(server_settings): |  | ||||||
|     main_start = server_settings.pop("main_start", None) |  | ||||||
|     main_stop = server_settings.pop("main_stop", None) |  | ||||||
|  |  | ||||||
|     if not server_settings.get("run_async"): |  | ||||||
|         # create new event_loop after fork |  | ||||||
|         loop = asyncio.new_event_loop() |  | ||||||
|         asyncio.set_event_loop(loop) |  | ||||||
|         server_settings["loop"] = loop |  | ||||||
|  |  | ||||||
|     trigger_events(main_start, server_settings["loop"]) |  | ||||||
|     serve(**server_settings) |  | ||||||
|     trigger_events(main_stop, server_settings["loop"]) |  | ||||||
|  |  | ||||||
|     server_settings["loop"].close() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def serve_multiple(server_settings, workers): |  | ||||||
|     """Start multiple server processes simultaneously.  Stop on interrupt |  | ||||||
|     and terminate signals, and drain connections when complete. |  | ||||||
|  |  | ||||||
|     :param server_settings: kw arguments to be passed to the serve function |  | ||||||
|     :param workers: number of workers to launch |  | ||||||
|     :param stop_event: if provided, is used as a stop signal |  | ||||||
|     :return: |  | ||||||
|     """ |  | ||||||
|     server_settings["reuse_port"] = True |  | ||||||
|     server_settings["run_multiple"] = True |  | ||||||
|  |  | ||||||
|     main_start = server_settings.pop("main_start", None) |  | ||||||
|     main_stop = server_settings.pop("main_stop", None) |  | ||||||
|     loop = asyncio.new_event_loop() |  | ||||||
|     asyncio.set_event_loop(loop) |  | ||||||
|  |  | ||||||
|     trigger_events(main_start, loop) |  | ||||||
|  |  | ||||||
|     # Create a listening socket or use the one in settings |  | ||||||
|     sock = server_settings.get("sock") |  | ||||||
|     unix = server_settings["unix"] |  | ||||||
|     backlog = server_settings["backlog"] |  | ||||||
|     if unix: |  | ||||||
|         sock = bind_unix_socket(unix, backlog=backlog) |  | ||||||
|         server_settings["unix"] = unix |  | ||||||
|     if sock is None: |  | ||||||
|         sock = bind_socket( |  | ||||||
|             server_settings["host"], server_settings["port"], backlog=backlog |  | ||||||
|         ) |  | ||||||
|         sock.set_inheritable(True) |  | ||||||
|         server_settings["sock"] = sock |  | ||||||
|         server_settings["host"] = None |  | ||||||
|         server_settings["port"] = None |  | ||||||
|  |  | ||||||
|     processes = [] |  | ||||||
|  |  | ||||||
|     def sig_handler(signal, frame): |  | ||||||
|         logger.info("Received signal %s. Shutting down.", Signals(signal).name) |  | ||||||
|         for process in processes: |  | ||||||
|             os.kill(process.pid, SIGTERM) |  | ||||||
|  |  | ||||||
|     signal_func(SIGINT, lambda s, f: sig_handler(s, f)) |  | ||||||
|     signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) |  | ||||||
|     mp = multiprocessing.get_context("fork") |  | ||||||
|  |  | ||||||
|     for _ in range(workers): |  | ||||||
|         process = mp.Process(target=serve, kwargs=server_settings) |  | ||||||
|         process.daemon = True |  | ||||||
|         process.start() |  | ||||||
|         processes.append(process) |  | ||||||
|  |  | ||||||
|     for process in processes: |  | ||||||
|         process.join() |  | ||||||
|  |  | ||||||
|     # the above processes will block this until they're stopped |  | ||||||
|     for process in processes: |  | ||||||
|         process.terminate() |  | ||||||
|  |  | ||||||
|     trigger_events(main_stop, loop) |  | ||||||
|  |  | ||||||
|     sock.close() |  | ||||||
|     loop.close() |  | ||||||
|     remove_unix_socket(unix) |  | ||||||
							
								
								
									
										26
									
								
								sanic/server/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								sanic/server/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | import asyncio | ||||||
|  |  | ||||||
|  | from sanic.models.server_types import ConnInfo, Signal | ||||||
|  | from sanic.server.async_server import AsyncioServer | ||||||
|  | from sanic.server.protocols.http_protocol import HttpProtocol | ||||||
|  | from sanic.server.runners import serve, serve_multiple, serve_single | ||||||
|  |  | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     import uvloop  # type: ignore | ||||||
|  |  | ||||||
|  |     if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): | ||||||
|  |         asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | ||||||
|  | except ImportError: | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | __all__ = ( | ||||||
|  |     "AsyncioServer", | ||||||
|  |     "ConnInfo", | ||||||
|  |     "HttpProtocol", | ||||||
|  |     "Signal", | ||||||
|  |     "serve", | ||||||
|  |     "serve_multiple", | ||||||
|  |     "serve_single", | ||||||
|  | ) | ||||||
							
								
								
									
										115
									
								
								sanic/server/async_server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								sanic/server/async_server.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,115 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | import asyncio | ||||||
|  |  | ||||||
|  | from sanic.exceptions import SanicException | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AsyncioServer: | ||||||
|  |     """ | ||||||
|  |     Wraps an asyncio server with functionality that might be useful to | ||||||
|  |     a user who needs to manage the server lifecycle manually. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init") | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         app, | ||||||
|  |         loop, | ||||||
|  |         serve_coro, | ||||||
|  |         connections, | ||||||
|  |     ): | ||||||
|  |         # Note, Sanic already called "before_server_start" events | ||||||
|  |         # before this helper was even created. So we don't need it here. | ||||||
|  |         self.app = app | ||||||
|  |         self.connections = connections | ||||||
|  |         self.loop = loop | ||||||
|  |         self.serve_coro = serve_coro | ||||||
|  |         self.server = None | ||||||
|  |         self.init = False | ||||||
|  |  | ||||||
|  |     def startup(self): | ||||||
|  |         """ | ||||||
|  |         Trigger "before_server_start" events | ||||||
|  |         """ | ||||||
|  |         self.init = True | ||||||
|  |         return self.app._startup() | ||||||
|  |  | ||||||
|  |     def before_start(self): | ||||||
|  |         """ | ||||||
|  |         Trigger "before_server_start" events | ||||||
|  |         """ | ||||||
|  |         return self._server_event("init", "before") | ||||||
|  |  | ||||||
|  |     def after_start(self): | ||||||
|  |         """ | ||||||
|  |         Trigger "after_server_start" events | ||||||
|  |         """ | ||||||
|  |         return self._server_event("init", "after") | ||||||
|  |  | ||||||
|  |     def before_stop(self): | ||||||
|  |         """ | ||||||
|  |         Trigger "before_server_stop" events | ||||||
|  |         """ | ||||||
|  |         return self._server_event("shutdown", "before") | ||||||
|  |  | ||||||
|  |     def after_stop(self): | ||||||
|  |         """ | ||||||
|  |         Trigger "after_server_stop" events | ||||||
|  |         """ | ||||||
|  |         return self._server_event("shutdown", "after") | ||||||
|  |  | ||||||
|  |     def is_serving(self) -> bool: | ||||||
|  |         if self.server: | ||||||
|  |             return self.server.is_serving() | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def wait_closed(self): | ||||||
|  |         if self.server: | ||||||
|  |             return self.server.wait_closed() | ||||||
|  |  | ||||||
|  |     def close(self): | ||||||
|  |         if self.server: | ||||||
|  |             self.server.close() | ||||||
|  |             coro = self.wait_closed() | ||||||
|  |             task = asyncio.ensure_future(coro, loop=self.loop) | ||||||
|  |             return task | ||||||
|  |  | ||||||
|  |     def start_serving(self): | ||||||
|  |         if self.server: | ||||||
|  |             try: | ||||||
|  |                 return self.server.start_serving() | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise NotImplementedError( | ||||||
|  |                     "server.start_serving not available in this version " | ||||||
|  |                     "of asyncio or uvloop." | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def serve_forever(self): | ||||||
|  |         if self.server: | ||||||
|  |             try: | ||||||
|  |                 return self.server.serve_forever() | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise NotImplementedError( | ||||||
|  |                     "server.serve_forever not available in this version " | ||||||
|  |                     "of asyncio or uvloop." | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def _server_event(self, concern: str, action: str): | ||||||
|  |         if not self.init: | ||||||
|  |             raise SanicException( | ||||||
|  |                 "Cannot dispatch server event without " | ||||||
|  |                 "first running server.startup()" | ||||||
|  |             ) | ||||||
|  |         return self.app._server_event(concern, action, loop=self.loop) | ||||||
|  |  | ||||||
|  |     def __await__(self): | ||||||
|  |         """ | ||||||
|  |         Starts the asyncio server, returns AsyncServerCoro | ||||||
|  |         """ | ||||||
|  |         task = asyncio.ensure_future(self.serve_coro) | ||||||
|  |         while not task.done(): | ||||||
|  |             yield | ||||||
|  |         self.server = task.result() | ||||||
|  |         return self | ||||||
							
								
								
									
										16
									
								
								sanic/server/events.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								sanic/server/events.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | |||||||
|  | from inspect import isawaitable | ||||||
|  | from typing import Any, Callable, Iterable, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): | ||||||
|  |     """ | ||||||
|  |     Trigger event callbacks (functions or async) | ||||||
|  |  | ||||||
|  |     :param events: one or more sync or async functions to execute | ||||||
|  |     :param loop: event loop | ||||||
|  |     """ | ||||||
|  |     if events: | ||||||
|  |         for event in events: | ||||||
|  |             result = event(loop) | ||||||
|  |             if isawaitable(result): | ||||||
|  |                 loop.run_until_complete(result) | ||||||
							
								
								
									
										0
									
								
								sanic/server/protocols/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sanic/server/protocols/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										143
									
								
								sanic/server/protocols/base_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								sanic/server/protocols/base_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from sanic.app import Sanic | ||||||
|  |  | ||||||
|  | import asyncio | ||||||
|  |  | ||||||
|  | from asyncio import CancelledError | ||||||
|  | from asyncio.transports import Transport | ||||||
|  | from time import monotonic as current_time | ||||||
|  |  | ||||||
|  | from sanic.log import error_logger | ||||||
|  | from sanic.models.server_types import ConnInfo, Signal | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SanicProtocol(asyncio.Protocol): | ||||||
|  |     __slots__ = ( | ||||||
|  |         "app", | ||||||
|  |         # event loop, connection | ||||||
|  |         "loop", | ||||||
|  |         "transport", | ||||||
|  |         "connections", | ||||||
|  |         "conn_info", | ||||||
|  |         "signal", | ||||||
|  |         "_can_write", | ||||||
|  |         "_time", | ||||||
|  |         "_task", | ||||||
|  |         "_unix", | ||||||
|  |         "_data_received", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         loop, | ||||||
|  |         app: Sanic, | ||||||
|  |         signal=None, | ||||||
|  |         connections=None, | ||||||
|  |         unix=None, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         asyncio.set_event_loop(loop) | ||||||
|  |         self.loop = loop | ||||||
|  |         self.app: Sanic = app | ||||||
|  |         self.signal = signal or Signal() | ||||||
|  |         self.transport: Optional[Transport] = None | ||||||
|  |         self.connections = connections if connections is not None else set() | ||||||
|  |         self.conn_info: Optional[ConnInfo] = None | ||||||
|  |         self._can_write = asyncio.Event() | ||||||
|  |         self._can_write.set() | ||||||
|  |         self._unix = unix | ||||||
|  |         self._time = 0.0  # type: float | ||||||
|  |         self._task = None  # type: Optional[asyncio.Task] | ||||||
|  |         self._data_received = asyncio.Event() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def ctx(self): | ||||||
|  |         if self.conn_info is not None: | ||||||
|  |             return self.conn_info.ctx | ||||||
|  |         else: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     async def send(self, data): | ||||||
|  |         """ | ||||||
|  |         Generic data write implementation with backpressure control. | ||||||
|  |         """ | ||||||
|  |         await self._can_write.wait() | ||||||
|  |         if self.transport.is_closing(): | ||||||
|  |             raise CancelledError | ||||||
|  |         self.transport.write(data) | ||||||
|  |         self._time = current_time() | ||||||
|  |  | ||||||
|  |     async def receive_more(self): | ||||||
|  |         """ | ||||||
|  |         Wait until more data is received into the Server protocol's buffer | ||||||
|  |         """ | ||||||
|  |         self.transport.resume_reading() | ||||||
|  |         self._data_received.clear() | ||||||
|  |         await self._data_received.wait() | ||||||
|  |  | ||||||
|  |     def close(self, timeout: Optional[float] = None): | ||||||
|  |         """ | ||||||
|  |         Attempt close the connection. | ||||||
|  |         """ | ||||||
|  |         # Cause a call to connection_lost where further cleanup occurs | ||||||
|  |         if self.transport: | ||||||
|  |             self.transport.close() | ||||||
|  |             if timeout is None: | ||||||
|  |                 timeout = self.app.config.GRACEFUL_SHUTDOWN_TIMEOUT | ||||||
|  |             self.loop.call_later(timeout, self.abort) | ||||||
|  |  | ||||||
|  |     def abort(self): | ||||||
|  |         """ | ||||||
|  |         Force close the connection. | ||||||
|  |         """ | ||||||
|  |         # Cause a call to connection_lost where further cleanup occurs | ||||||
|  |         if self.transport: | ||||||
|  |             self.transport.abort() | ||||||
|  |             self.transport = None | ||||||
|  |  | ||||||
|  |     # asyncio.Protocol API Callbacks # | ||||||
|  |     # ------------------------------ # | ||||||
|  |     def connection_made(self, transport): | ||||||
|  |         """ | ||||||
|  |         Generic connection-made, with no connection_task, and no recv_buffer. | ||||||
|  |         Override this for protocol-specific connection implementations. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             transport.set_write_buffer_limits(low=16384, high=65536) | ||||||
|  |             self.connections.add(self) | ||||||
|  |             self.transport = transport | ||||||
|  |             self.conn_info = ConnInfo(self.transport, unix=self._unix) | ||||||
|  |         except Exception: | ||||||
|  |             error_logger.exception("protocol.connect_made") | ||||||
|  |  | ||||||
|  |     def connection_lost(self, exc): | ||||||
|  |         try: | ||||||
|  |             self.connections.discard(self) | ||||||
|  |             self.resume_writing() | ||||||
|  |             if self._task: | ||||||
|  |                 self._task.cancel() | ||||||
|  |         except BaseException: | ||||||
|  |             error_logger.exception("protocol.connection_lost") | ||||||
|  |  | ||||||
|  |     def pause_writing(self): | ||||||
|  |         self._can_write.clear() | ||||||
|  |  | ||||||
|  |     def resume_writing(self): | ||||||
|  |         self._can_write.set() | ||||||
|  |  | ||||||
|  |     def data_received(self, data: bytes): | ||||||
|  |         try: | ||||||
|  |             self._time = current_time() | ||||||
|  |             if not data: | ||||||
|  |                 return self.close() | ||||||
|  |  | ||||||
|  |             if self._data_received: | ||||||
|  |                 self._data_received.set() | ||||||
|  |         except BaseException: | ||||||
|  |             error_logger.exception("protocol.data_received") | ||||||
							
								
								
									
										238
									
								
								sanic/server/protocols/http_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								sanic/server/protocols/http_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,238 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
|  | from sanic.touchup.meta import TouchUpMeta | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from sanic.app import Sanic | ||||||
|  |  | ||||||
|  | from asyncio import CancelledError | ||||||
|  | from time import monotonic as current_time | ||||||
|  |  | ||||||
|  | from sanic.exceptions import RequestTimeout, ServiceUnavailable | ||||||
|  | from sanic.http import Http, Stage | ||||||
|  | from sanic.log import error_logger, logger | ||||||
|  | from sanic.models.server_types import ConnInfo | ||||||
|  | from sanic.request import Request | ||||||
|  | from sanic.server.protocols.base_protocol import SanicProtocol | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): | ||||||
|  |     """ | ||||||
|  |     This class provides implements the HTTP 1.1 protocol on top of our | ||||||
|  |     Sanic Server transport | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     __touchup__ = ( | ||||||
|  |         "send", | ||||||
|  |         "connection_task", | ||||||
|  |     ) | ||||||
|  |     __slots__ = ( | ||||||
|  |         # request params | ||||||
|  |         "request", | ||||||
|  |         # request config | ||||||
|  |         "request_handler", | ||||||
|  |         "request_timeout", | ||||||
|  |         "response_timeout", | ||||||
|  |         "keep_alive_timeout", | ||||||
|  |         "request_max_size", | ||||||
|  |         "request_class", | ||||||
|  |         "error_handler", | ||||||
|  |         # enable or disable access log purpose | ||||||
|  |         "access_log", | ||||||
|  |         # connection management | ||||||
|  |         "state", | ||||||
|  |         "url", | ||||||
|  |         "_handler_task", | ||||||
|  |         "_http", | ||||||
|  |         "_exception", | ||||||
|  |         "recv_buffer", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         loop, | ||||||
|  |         app: Sanic, | ||||||
|  |         signal=None, | ||||||
|  |         connections=None, | ||||||
|  |         state=None, | ||||||
|  |         unix=None, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         super().__init__( | ||||||
|  |             loop=loop, | ||||||
|  |             app=app, | ||||||
|  |             signal=signal, | ||||||
|  |             connections=connections, | ||||||
|  |             unix=unix, | ||||||
|  |         ) | ||||||
|  |         self.url = None | ||||||
|  |         self.request: Optional[Request] = None | ||||||
|  |         self.access_log = self.app.config.ACCESS_LOG | ||||||
|  |         self.request_handler = self.app.handle_request | ||||||
|  |         self.error_handler = self.app.error_handler | ||||||
|  |         self.request_timeout = self.app.config.REQUEST_TIMEOUT | ||||||
|  |         self.response_timeout = self.app.config.RESPONSE_TIMEOUT | ||||||
|  |         self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT | ||||||
|  |         self.request_max_size = self.app.config.REQUEST_MAX_SIZE | ||||||
|  |         self.request_class = self.app.request_class or Request | ||||||
|  |         self.state = state if state else {} | ||||||
|  |         if "requests_count" not in self.state: | ||||||
|  |             self.state["requests_count"] = 0 | ||||||
|  |         self._exception = None | ||||||
|  |  | ||||||
|  |     def _setup_connection(self): | ||||||
|  |         self._http = Http(self) | ||||||
|  |         self._time = current_time() | ||||||
|  |         self.check_timeouts() | ||||||
|  |  | ||||||
|  |     async def connection_task(self):  # no cov | ||||||
|  |         """ | ||||||
|  |         Run a HTTP connection. | ||||||
|  |  | ||||||
|  |         Timeouts and some additional error handling occur here, while most of | ||||||
|  |         everything else happens in class Http or in code called from there. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             self._setup_connection() | ||||||
|  |             await self.app.dispatch( | ||||||
|  |                 "http.lifecycle.begin", | ||||||
|  |                 inline=True, | ||||||
|  |                 context={"conn_info": self.conn_info}, | ||||||
|  |             ) | ||||||
|  |             await self._http.http1() | ||||||
|  |         except CancelledError: | ||||||
|  |             pass | ||||||
|  |         except Exception: | ||||||
|  |             error_logger.exception("protocol.connection_task uncaught") | ||||||
|  |         finally: | ||||||
|  |             if ( | ||||||
|  |                 self.app.debug | ||||||
|  |                 and self._http | ||||||
|  |                 and self.transport | ||||||
|  |                 and not self._http.upgrade_websocket | ||||||
|  |             ): | ||||||
|  |                 ip = self.transport.get_extra_info("peername") | ||||||
|  |                 error_logger.error( | ||||||
|  |                     "Connection lost before response written" | ||||||
|  |                     f" @ {ip} {self._http.request}" | ||||||
|  |                 ) | ||||||
|  |             self._http = None | ||||||
|  |             self._task = None | ||||||
|  |             try: | ||||||
|  |                 self.close() | ||||||
|  |             except BaseException: | ||||||
|  |                 error_logger.exception("Closing failed") | ||||||
|  |             finally: | ||||||
|  |                 await self.app.dispatch( | ||||||
|  |                     "http.lifecycle.complete", | ||||||
|  |                     inline=True, | ||||||
|  |                     context={"conn_info": self.conn_info}, | ||||||
|  |                 ) | ||||||
|  |                 # Important to keep this Ellipsis here for the TouchUp module | ||||||
|  |                 ... | ||||||
|  |  | ||||||
|  |     def check_timeouts(self): | ||||||
|  |         """ | ||||||
|  |         Runs itself periodically to enforce any expired timeouts. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             if not self._task: | ||||||
|  |                 return | ||||||
|  |             duration = current_time() - self._time | ||||||
|  |             stage = self._http.stage | ||||||
|  |             if stage is Stage.IDLE and duration > self.keep_alive_timeout: | ||||||
|  |                 logger.debug("KeepAlive Timeout. Closing connection.") | ||||||
|  |             elif stage is Stage.REQUEST and duration > self.request_timeout: | ||||||
|  |                 logger.debug("Request Timeout. Closing connection.") | ||||||
|  |                 self._http.exception = RequestTimeout("Request Timeout") | ||||||
|  |             elif stage is Stage.HANDLER and self._http.upgrade_websocket: | ||||||
|  |                 logger.debug("Handling websocket. Timeouts disabled.") | ||||||
|  |                 return | ||||||
|  |             elif ( | ||||||
|  |                 stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) | ||||||
|  |                 and duration > self.response_timeout | ||||||
|  |             ): | ||||||
|  |                 logger.debug("Response Timeout. Closing connection.") | ||||||
|  |                 self._http.exception = ServiceUnavailable("Response Timeout") | ||||||
|  |             else: | ||||||
|  |                 interval = ( | ||||||
|  |                     min( | ||||||
|  |                         self.keep_alive_timeout, | ||||||
|  |                         self.request_timeout, | ||||||
|  |                         self.response_timeout, | ||||||
|  |                     ) | ||||||
|  |                     / 2 | ||||||
|  |                 ) | ||||||
|  |                 self.loop.call_later(max(0.1, interval), self.check_timeouts) | ||||||
|  |                 return | ||||||
|  |             self._task.cancel() | ||||||
|  |         except Exception: | ||||||
|  |             error_logger.exception("protocol.check_timeouts") | ||||||
|  |  | ||||||
|  |     async def send(self, data):  # no cov | ||||||
|  |         """ | ||||||
|  |         Writes HTTP data with backpressure control. | ||||||
|  |         """ | ||||||
|  |         await self._can_write.wait() | ||||||
|  |         if self.transport.is_closing(): | ||||||
|  |             raise CancelledError | ||||||
|  |         await self.app.dispatch( | ||||||
|  |             "http.lifecycle.send", | ||||||
|  |             inline=True, | ||||||
|  |             context={"data": data}, | ||||||
|  |         ) | ||||||
|  |         self.transport.write(data) | ||||||
|  |         self._time = current_time() | ||||||
|  |  | ||||||
|  |     def close_if_idle(self) -> bool: | ||||||
|  |         """ | ||||||
|  |         Close the connection if a request is not being sent or received | ||||||
|  |  | ||||||
|  |         :return: boolean - True if closed, false if staying open | ||||||
|  |         """ | ||||||
|  |         if self._http is None or self._http.stage is Stage.IDLE: | ||||||
|  |             self.close() | ||||||
|  |             return True | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     # -------------------------------------------- # | ||||||
|  |     # Only asyncio.Protocol callbacks below this | ||||||
|  |     # -------------------------------------------- # | ||||||
|  |  | ||||||
|  |     def connection_made(self, transport): | ||||||
|  |         """ | ||||||
|  |         HTTP-protocol-specific new connection handler | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # TODO: Benchmark to find suitable write buffer limits | ||||||
|  |             transport.set_write_buffer_limits(low=16384, high=65536) | ||||||
|  |             self.connections.add(self) | ||||||
|  |             self.transport = transport | ||||||
|  |             self._task = self.loop.create_task(self.connection_task()) | ||||||
|  |             self.recv_buffer = bytearray() | ||||||
|  |             self.conn_info = ConnInfo(self.transport, unix=self._unix) | ||||||
|  |         except Exception: | ||||||
|  |             error_logger.exception("protocol.connect_made") | ||||||
|  |  | ||||||
|  |     def data_received(self, data: bytes): | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             self._time = current_time() | ||||||
|  |             if not data: | ||||||
|  |                 return self.close() | ||||||
|  |             self.recv_buffer += data | ||||||
|  |  | ||||||
|  |             if ( | ||||||
|  |                 len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE | ||||||
|  |                 and self.transport | ||||||
|  |             ): | ||||||
|  |                 self.transport.pause_reading() | ||||||
|  |  | ||||||
|  |             if self._data_received: | ||||||
|  |                 self._data_received.set() | ||||||
|  |         except Exception: | ||||||
|  |             error_logger.exception("protocol.data_received") | ||||||
							
								
								
									
										164
									
								
								sanic/server/protocols/websocket_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								sanic/server/protocols/websocket_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,164 @@ | |||||||
|  | from typing import TYPE_CHECKING, Optional, Sequence | ||||||
|  |  | ||||||
|  | from websockets.connection import CLOSED, CLOSING, OPEN | ||||||
|  | from websockets.server import ServerConnection | ||||||
|  |  | ||||||
|  | from sanic.exceptions import ServerError | ||||||
|  | from sanic.log import error_logger | ||||||
|  | from sanic.server import HttpProtocol | ||||||
|  |  | ||||||
|  | from ..websockets.impl import WebsocketImplProtocol | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from websockets import http11 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebSocketProtocol(HttpProtocol): | ||||||
|  |  | ||||||
|  |     websocket: Optional[WebsocketImplProtocol] | ||||||
|  |     websocket_timeout: float | ||||||
|  |     websocket_max_size = Optional[int] | ||||||
|  |     websocket_ping_interval = Optional[float] | ||||||
|  |     websocket_ping_timeout = Optional[float] | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         *args, | ||||||
|  |         websocket_timeout: float = 10.0, | ||||||
|  |         websocket_max_size: Optional[int] = None, | ||||||
|  |         websocket_max_queue: Optional[int] = None,  # max_queue is deprecated | ||||||
|  |         websocket_read_limit: Optional[int] = None,  # read_limit is deprecated | ||||||
|  |         websocket_write_limit: Optional[int] = None,  # write_limit deprecated | ||||||
|  |         websocket_ping_interval: Optional[float] = 20.0, | ||||||
|  |         websocket_ping_timeout: Optional[float] = 20.0, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.websocket = None | ||||||
|  |         self.websocket_timeout = websocket_timeout | ||||||
|  |         self.websocket_max_size = websocket_max_size | ||||||
|  |         if websocket_max_queue is not None and websocket_max_queue > 0: | ||||||
|  |             # TODO: Reminder remove this warning in v22.3 | ||||||
|  |             error_logger.warning( | ||||||
|  |                 DeprecationWarning( | ||||||
|  |                     "Websocket no longer uses queueing, so websocket_max_queue" | ||||||
|  |                     " is no longer required." | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         if websocket_read_limit is not None and websocket_read_limit > 0: | ||||||
|  |             # TODO: Reminder remove this warning in v22.3 | ||||||
|  |             error_logger.warning( | ||||||
|  |                 DeprecationWarning( | ||||||
|  |                     "Websocket no longer uses read buffers, so " | ||||||
|  |                     "websocket_read_limit is not required." | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         if websocket_write_limit is not None and websocket_write_limit > 0: | ||||||
|  |             # TODO: Reminder remove this warning in v22.3 | ||||||
|  |             error_logger.warning( | ||||||
|  |                 DeprecationWarning( | ||||||
|  |                     "Websocket no longer uses write buffers, so " | ||||||
|  |                     "websocket_write_limit is not required." | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         self.websocket_ping_interval = websocket_ping_interval | ||||||
|  |         self.websocket_ping_timeout = websocket_ping_timeout | ||||||
|  |  | ||||||
|  |     def connection_lost(self, exc): | ||||||
|  |         if self.websocket is not None: | ||||||
|  |             self.websocket.connection_lost(exc) | ||||||
|  |         super().connection_lost(exc) | ||||||
|  |  | ||||||
|  |     def data_received(self, data): | ||||||
|  |         if self.websocket is not None: | ||||||
|  |             self.websocket.data_received(data) | ||||||
|  |         else: | ||||||
|  |             # Pass it to HttpProtocol handler first | ||||||
|  |             # That will (hopefully) upgrade it to a websocket. | ||||||
|  |             super().data_received(data) | ||||||
|  |  | ||||||
|  |     def eof_received(self) -> Optional[bool]: | ||||||
|  |         if self.websocket is not None: | ||||||
|  |             return self.websocket.eof_received() | ||||||
|  |         else: | ||||||
|  |             return False | ||||||
|  |  | ||||||
|  |     def close(self, timeout: Optional[float] = None): | ||||||
|  |         # Called by HttpProtocol at the end of connection_task | ||||||
|  |         # If we've upgraded to websocket, we do our own closing | ||||||
|  |         if self.websocket is not None: | ||||||
|  |             # Note, we don't want to use websocket.close() | ||||||
|  |             # That is used for user's application code to send a | ||||||
|  |             # websocket close packet. This is different. | ||||||
|  |             self.websocket.end_connection(1001) | ||||||
|  |         else: | ||||||
|  |             super().close() | ||||||
|  |  | ||||||
|  |     def close_if_idle(self): | ||||||
|  |         # Called by Sanic Server when shutting down | ||||||
|  |         # If we've upgraded to websocket, shut it down | ||||||
|  |         if self.websocket is not None: | ||||||
|  |             if self.websocket.connection.state in (CLOSING, CLOSED): | ||||||
|  |                 return True | ||||||
|  |             elif self.websocket.loop is not None: | ||||||
|  |                 self.websocket.loop.create_task(self.websocket.close(1001)) | ||||||
|  |             else: | ||||||
|  |                 self.websocket.end_connection(1001) | ||||||
|  |         else: | ||||||
|  |             return super().close_if_idle() | ||||||
|  |  | ||||||
|  |     async def websocket_handshake( | ||||||
|  |         self, request, subprotocols=Optional[Sequence[str]] | ||||||
|  |     ): | ||||||
|  |         # let the websockets package do the handshake with the client | ||||||
|  |         try: | ||||||
|  |             if subprotocols is not None: | ||||||
|  |                 # subprotocols can be a set or frozenset, | ||||||
|  |                 # but ServerConnection needs a list | ||||||
|  |                 subprotocols = list(subprotocols) | ||||||
|  |             ws_conn = ServerConnection( | ||||||
|  |                 max_size=self.websocket_max_size, | ||||||
|  |                 subprotocols=subprotocols, | ||||||
|  |                 state=OPEN, | ||||||
|  |                 logger=error_logger, | ||||||
|  |             ) | ||||||
|  |             resp: "http11.Response" = ws_conn.accept(request) | ||||||
|  |         except Exception: | ||||||
|  |             msg = ( | ||||||
|  |                 "Failed to open a WebSocket connection.\n" | ||||||
|  |                 "See server log for more information.\n" | ||||||
|  |             ) | ||||||
|  |             raise ServerError(msg, status_code=500) | ||||||
|  |         if 100 <= resp.status_code <= 299: | ||||||
|  |             rbody = "".join( | ||||||
|  |                 [ | ||||||
|  |                     "HTTP/1.1 ", | ||||||
|  |                     str(resp.status_code), | ||||||
|  |                     " ", | ||||||
|  |                     resp.reason_phrase, | ||||||
|  |                     "\r\n", | ||||||
|  |                 ] | ||||||
|  |             ) | ||||||
|  |             rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) | ||||||
|  |             if resp.body is not None: | ||||||
|  |                 rbody += f"\r\n{resp.body}\r\n\r\n" | ||||||
|  |             else: | ||||||
|  |                 rbody += "\r\n" | ||||||
|  |             await super().send(rbody.encode()) | ||||||
|  |         else: | ||||||
|  |             raise ServerError(resp.body, resp.status_code) | ||||||
|  |         self.websocket = WebsocketImplProtocol( | ||||||
|  |             ws_conn, | ||||||
|  |             ping_interval=self.websocket_ping_interval, | ||||||
|  |             ping_timeout=self.websocket_ping_timeout, | ||||||
|  |             close_timeout=self.websocket_timeout, | ||||||
|  |         ) | ||||||
|  |         loop = ( | ||||||
|  |             request.transport.loop | ||||||
|  |             if hasattr(request, "transport") | ||||||
|  |             and hasattr(request.transport, "loop") | ||||||
|  |             else None | ||||||
|  |         ) | ||||||
|  |         await self.websocket.connection_made(self, loop=loop) | ||||||
|  |         return self.websocket | ||||||
							
								
								
									
										280
									
								
								sanic/server/runners.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										280
									
								
								sanic/server/runners.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,280 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | from ssl import SSLContext | ||||||
|  | from typing import TYPE_CHECKING, Dict, Optional, Type, Union | ||||||
|  |  | ||||||
|  | from sanic.config import Config | ||||||
|  | from sanic.server.events import trigger_events | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from sanic.app import Sanic | ||||||
|  |  | ||||||
|  | import asyncio | ||||||
|  | import multiprocessing | ||||||
|  | import os | ||||||
|  | import socket | ||||||
|  |  | ||||||
|  | from functools import partial | ||||||
|  | from signal import SIG_IGN, SIGINT, SIGTERM, Signals | ||||||
|  | from signal import signal as signal_func | ||||||
|  |  | ||||||
|  | from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows | ||||||
|  | from sanic.log import error_logger, logger | ||||||
|  | from sanic.models.server_types import Signal | ||||||
|  | from sanic.server.async_server import AsyncioServer | ||||||
|  | from sanic.server.protocols.http_protocol import HttpProtocol | ||||||
|  | from sanic.server.socket import ( | ||||||
|  |     bind_socket, | ||||||
|  |     bind_unix_socket, | ||||||
|  |     remove_unix_socket, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def serve( | ||||||
|  |     host, | ||||||
|  |     port, | ||||||
|  |     app: Sanic, | ||||||
|  |     ssl: Optional[SSLContext] = None, | ||||||
|  |     sock: Optional[socket.socket] = None, | ||||||
|  |     unix: Optional[str] = None, | ||||||
|  |     reuse_port: bool = False, | ||||||
|  |     loop=None, | ||||||
|  |     protocol: Type[asyncio.Protocol] = HttpProtocol, | ||||||
|  |     backlog: int = 100, | ||||||
|  |     register_sys_signals: bool = True, | ||||||
|  |     run_multiple: bool = False, | ||||||
|  |     run_async: bool = False, | ||||||
|  |     connections=None, | ||||||
|  |     signal=Signal(), | ||||||
|  |     state=None, | ||||||
|  |     asyncio_server_kwargs=None, | ||||||
|  | ): | ||||||
|  |     """Start asynchronous HTTP Server on an individual process. | ||||||
|  |  | ||||||
|  |     :param host: Address to host on | ||||||
|  |     :param port: Port to host on | ||||||
|  |     :param before_start: function to be executed before the server starts | ||||||
|  |                          listening. Takes arguments `app` instance and `loop` | ||||||
|  |     :param after_start: function to be executed after the server starts | ||||||
|  |                         listening. Takes  arguments `app` instance and `loop` | ||||||
|  |     :param before_stop: function to be executed when a stop signal is | ||||||
|  |                         received before it is respected. Takes arguments | ||||||
|  |                         `app` instance and `loop` | ||||||
|  |     :param after_stop: function to be executed when a stop signal is | ||||||
|  |                        received after it is respected. Takes arguments | ||||||
|  |                        `app` instance and `loop` | ||||||
|  |     :param ssl: SSLContext | ||||||
|  |     :param sock: Socket for the server to accept connections from | ||||||
|  |     :param unix: Unix socket to listen on instead of TCP port | ||||||
|  |     :param reuse_port: `True` for multiple workers | ||||||
|  |     :param loop: asyncio compatible event loop | ||||||
|  |     :param run_async: bool: Do not create a new event loop for the server, | ||||||
|  |                       and return an AsyncServer object rather than running it | ||||||
|  |     :param asyncio_server_kwargs: key-value args for asyncio/uvloop | ||||||
|  |                                   create_server method | ||||||
|  |     :return: Nothing | ||||||
|  |     """ | ||||||
|  |     if not run_async and not loop: | ||||||
|  |         # create new event_loop after fork | ||||||
|  |         loop = asyncio.new_event_loop() | ||||||
|  |         asyncio.set_event_loop(loop) | ||||||
|  |  | ||||||
|  |     if app.debug: | ||||||
|  |         loop.set_debug(app.debug) | ||||||
|  |  | ||||||
|  |     app.asgi = False | ||||||
|  |  | ||||||
|  |     connections = connections if connections is not None else set() | ||||||
|  |     protocol_kwargs = _build_protocol_kwargs(protocol, app.config) | ||||||
|  |     server = partial( | ||||||
|  |         protocol, | ||||||
|  |         loop=loop, | ||||||
|  |         connections=connections, | ||||||
|  |         signal=signal, | ||||||
|  |         app=app, | ||||||
|  |         state=state, | ||||||
|  |         unix=unix, | ||||||
|  |         **protocol_kwargs, | ||||||
|  |     ) | ||||||
|  |     asyncio_server_kwargs = ( | ||||||
|  |         asyncio_server_kwargs if asyncio_server_kwargs else {} | ||||||
|  |     ) | ||||||
|  |     # UNIX sockets are always bound by us (to preserve semantics between modes) | ||||||
|  |     if unix: | ||||||
|  |         sock = bind_unix_socket(unix, backlog=backlog) | ||||||
|  |     server_coroutine = loop.create_server( | ||||||
|  |         server, | ||||||
|  |         None if sock else host, | ||||||
|  |         None if sock else port, | ||||||
|  |         ssl=ssl, | ||||||
|  |         reuse_port=reuse_port, | ||||||
|  |         sock=sock, | ||||||
|  |         backlog=backlog, | ||||||
|  |         **asyncio_server_kwargs, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     if run_async: | ||||||
|  |         return AsyncioServer( | ||||||
|  |             app=app, | ||||||
|  |             loop=loop, | ||||||
|  |             serve_coro=server_coroutine, | ||||||
|  |             connections=connections, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     loop.run_until_complete(app._startup()) | ||||||
|  |     loop.run_until_complete(app._server_event("init", "before")) | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         http_server = loop.run_until_complete(server_coroutine) | ||||||
|  |     except BaseException: | ||||||
|  |         error_logger.exception("Unable to start server") | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     # Ignore SIGINT when run_multiple | ||||||
|  |     if run_multiple: | ||||||
|  |         signal_func(SIGINT, SIG_IGN) | ||||||
|  |  | ||||||
|  |     # Register signals for graceful termination | ||||||
|  |     if register_sys_signals: | ||||||
|  |         if OS_IS_WINDOWS: | ||||||
|  |             ctrlc_workaround_for_windows(app) | ||||||
|  |         else: | ||||||
|  |             for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: | ||||||
|  |                 loop.add_signal_handler(_signal, app.stop) | ||||||
|  |  | ||||||
|  |     loop.run_until_complete(app._server_event("init", "after")) | ||||||
|  |     pid = os.getpid() | ||||||
|  |     try: | ||||||
|  |         logger.info("Starting worker [%s]", pid) | ||||||
|  |         loop.run_forever() | ||||||
|  |     finally: | ||||||
|  |         logger.info("Stopping worker [%s]", pid) | ||||||
|  |  | ||||||
|  |         # Run the on_stop function if provided | ||||||
|  |         loop.run_until_complete(app._server_event("shutdown", "before")) | ||||||
|  |  | ||||||
|  |         # Wait for event loop to finish and all connections to drain | ||||||
|  |         http_server.close() | ||||||
|  |         loop.run_until_complete(http_server.wait_closed()) | ||||||
|  |  | ||||||
|  |         # Complete all tasks on the loop | ||||||
|  |         signal.stopped = True | ||||||
|  |         for connection in connections: | ||||||
|  |             connection.close_if_idle() | ||||||
|  |  | ||||||
|  |         # Gracefully shutdown timeout. | ||||||
|  |         # We should provide graceful_shutdown_timeout, | ||||||
|  |         # instead of letting connection hangs forever. | ||||||
|  |         # Let's roughly calcucate time. | ||||||
|  |         graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT | ||||||
|  |         start_shutdown: float = 0 | ||||||
|  |         while connections and (start_shutdown < graceful): | ||||||
|  |             loop.run_until_complete(asyncio.sleep(0.1)) | ||||||
|  |             start_shutdown = start_shutdown + 0.1 | ||||||
|  |  | ||||||
|  |         # Force close non-idle connection after waiting for | ||||||
|  |         # graceful_shutdown_timeout | ||||||
|  |         for conn in connections: | ||||||
|  |             if hasattr(conn, "websocket") and conn.websocket: | ||||||
|  |                 conn.websocket.fail_connection(code=1001) | ||||||
|  |             else: | ||||||
|  |                 conn.abort() | ||||||
|  |         loop.run_until_complete(app._server_event("shutdown", "after")) | ||||||
|  |  | ||||||
|  |         remove_unix_socket(unix) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def serve_single(server_settings): | ||||||
|  |     main_start = server_settings.pop("main_start", None) | ||||||
|  |     main_stop = server_settings.pop("main_stop", None) | ||||||
|  |  | ||||||
|  |     if not server_settings.get("run_async"): | ||||||
|  |         # create new event_loop after fork | ||||||
|  |         loop = asyncio.new_event_loop() | ||||||
|  |         asyncio.set_event_loop(loop) | ||||||
|  |         server_settings["loop"] = loop | ||||||
|  |  | ||||||
|  |     trigger_events(main_start, server_settings["loop"]) | ||||||
|  |     serve(**server_settings) | ||||||
|  |     trigger_events(main_stop, server_settings["loop"]) | ||||||
|  |  | ||||||
|  |     server_settings["loop"].close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def serve_multiple(server_settings, workers): | ||||||
|  |     """Start multiple server processes simultaneously.  Stop on interrupt | ||||||
|  |     and terminate signals, and drain connections when complete. | ||||||
|  |  | ||||||
|  |     :param server_settings: kw arguments to be passed to the serve function | ||||||
|  |     :param workers: number of workers to launch | ||||||
|  |     :param stop_event: if provided, is used as a stop signal | ||||||
|  |     :return: | ||||||
|  |     """ | ||||||
|  |     server_settings["reuse_port"] = True | ||||||
|  |     server_settings["run_multiple"] = True | ||||||
|  |  | ||||||
|  |     main_start = server_settings.pop("main_start", None) | ||||||
|  |     main_stop = server_settings.pop("main_stop", None) | ||||||
|  |     loop = asyncio.new_event_loop() | ||||||
|  |     asyncio.set_event_loop(loop) | ||||||
|  |  | ||||||
|  |     trigger_events(main_start, loop) | ||||||
|  |  | ||||||
|  |     # Create a listening socket or use the one in settings | ||||||
|  |     sock = server_settings.get("sock") | ||||||
|  |     unix = server_settings["unix"] | ||||||
|  |     backlog = server_settings["backlog"] | ||||||
|  |     if unix: | ||||||
|  |         sock = bind_unix_socket(unix, backlog=backlog) | ||||||
|  |         server_settings["unix"] = unix | ||||||
|  |     if sock is None: | ||||||
|  |         sock = bind_socket( | ||||||
|  |             server_settings["host"], server_settings["port"], backlog=backlog | ||||||
|  |         ) | ||||||
|  |         sock.set_inheritable(True) | ||||||
|  |         server_settings["sock"] = sock | ||||||
|  |         server_settings["host"] = None | ||||||
|  |         server_settings["port"] = None | ||||||
|  |  | ||||||
|  |     processes = [] | ||||||
|  |  | ||||||
|  |     def sig_handler(signal, frame): | ||||||
|  |         logger.info("Received signal %s. Shutting down.", Signals(signal).name) | ||||||
|  |         for process in processes: | ||||||
|  |             os.kill(process.pid, SIGTERM) | ||||||
|  |  | ||||||
|  |     signal_func(SIGINT, lambda s, f: sig_handler(s, f)) | ||||||
|  |     signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) | ||||||
|  |     mp = multiprocessing.get_context("fork") | ||||||
|  |  | ||||||
|  |     for _ in range(workers): | ||||||
|  |         process = mp.Process(target=serve, kwargs=server_settings) | ||||||
|  |         process.daemon = True | ||||||
|  |         process.start() | ||||||
|  |         processes.append(process) | ||||||
|  |  | ||||||
|  |     for process in processes: | ||||||
|  |         process.join() | ||||||
|  |  | ||||||
|  |     # the above processes will block this until they're stopped | ||||||
|  |     for process in processes: | ||||||
|  |         process.terminate() | ||||||
|  |  | ||||||
|  |     trigger_events(main_stop, loop) | ||||||
|  |  | ||||||
|  |     sock.close() | ||||||
|  |     loop.close() | ||||||
|  |     remove_unix_socket(unix) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _build_protocol_kwargs( | ||||||
|  |     protocol: Type[asyncio.Protocol], config: Config | ||||||
|  | ) -> Dict[str, Union[int, float]]: | ||||||
|  |     if hasattr(protocol, "websocket_handshake"): | ||||||
|  |         return { | ||||||
|  |             "websocket_max_size": config.WEBSOCKET_MAX_SIZE, | ||||||
|  |             "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, | ||||||
|  |             "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, | ||||||
|  |         } | ||||||
|  |     return {} | ||||||
							
								
								
									
										87
									
								
								sanic/server/socket.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								sanic/server/socket.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import secrets | ||||||
|  | import socket | ||||||
|  | import stat | ||||||
|  |  | ||||||
|  | from ipaddress import ip_address | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: | ||||||
|  |     """Create TCP server socket. | ||||||
|  |     :param host: IPv4, IPv6 or hostname may be specified | ||||||
|  |     :param port: TCP port number | ||||||
|  |     :param backlog: Maximum number of connections to queue | ||||||
|  |     :return: socket.socket object | ||||||
|  |     """ | ||||||
|  |     try:  # IP address: family must be specified for IPv6 at least | ||||||
|  |         ip = ip_address(host) | ||||||
|  |         host = str(ip) | ||||||
|  |         sock = socket.socket( | ||||||
|  |             socket.AF_INET6 if ip.version == 6 else socket.AF_INET | ||||||
|  |         ) | ||||||
|  |     except ValueError:  # Hostname, may become AF_INET or AF_INET6 | ||||||
|  |         sock = socket.socket() | ||||||
|  |     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||||||
|  |     sock.bind((host, port)) | ||||||
|  |     sock.listen(backlog) | ||||||
|  |     return sock | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: | ||||||
|  |     """Create unix socket. | ||||||
|  |     :param path: filesystem path | ||||||
|  |     :param backlog: Maximum number of connections to queue | ||||||
|  |     :return: socket.socket object | ||||||
|  |     """ | ||||||
|  |     """Open or atomically replace existing socket with zero downtime.""" | ||||||
|  |     # Sanitise and pre-verify socket path | ||||||
|  |     path = os.path.abspath(path) | ||||||
|  |     folder = os.path.dirname(path) | ||||||
|  |     if not os.path.isdir(folder): | ||||||
|  |         raise FileNotFoundError(f"Socket folder does not exist: {folder}") | ||||||
|  |     try: | ||||||
|  |         if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): | ||||||
|  |             raise FileExistsError(f"Existing file is not a socket: {path}") | ||||||
|  |     except FileNotFoundError: | ||||||
|  |         pass | ||||||
|  |     # Create new socket with a random temporary name | ||||||
|  |     tmp_path = f"{path}.{secrets.token_urlsafe()}" | ||||||
|  |     sock = socket.socket(socket.AF_UNIX) | ||||||
|  |     try: | ||||||
|  |         # Critical section begins (filename races) | ||||||
|  |         sock.bind(tmp_path) | ||||||
|  |         try: | ||||||
|  |             os.chmod(tmp_path, mode) | ||||||
|  |             # Start listening before rename to avoid connection failures | ||||||
|  |             sock.listen(backlog) | ||||||
|  |             os.rename(tmp_path, path) | ||||||
|  |         except:  # noqa: E722 | ||||||
|  |             try: | ||||||
|  |                 os.unlink(tmp_path) | ||||||
|  |             finally: | ||||||
|  |                 raise | ||||||
|  |     except:  # noqa: E722 | ||||||
|  |         try: | ||||||
|  |             sock.close() | ||||||
|  |         finally: | ||||||
|  |             raise | ||||||
|  |     return sock | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def remove_unix_socket(path: Optional[str]) -> None: | ||||||
|  |     """Remove dead unix socket during server exit.""" | ||||||
|  |     if not path: | ||||||
|  |         return | ||||||
|  |     try: | ||||||
|  |         if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): | ||||||
|  |             # Is it actually dead (doesn't belong to a new server instance)? | ||||||
|  |             with socket.socket(socket.AF_UNIX) as testsock: | ||||||
|  |                 try: | ||||||
|  |                     testsock.connect(path) | ||||||
|  |                 except ConnectionRefusedError: | ||||||
|  |                     os.unlink(path) | ||||||
|  |     except FileNotFoundError: | ||||||
|  |         pass | ||||||
							
								
								
									
										0
									
								
								sanic/server/websockets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sanic/server/websockets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										82
									
								
								sanic/server/websockets/connection.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								sanic/server/websockets/connection.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | |||||||
|  | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Awaitable, | ||||||
|  |     Callable, | ||||||
|  |     Dict, | ||||||
|  |     List, | ||||||
|  |     MutableMapping, | ||||||
|  |     Optional, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ASIMessage = MutableMapping[str, Any] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebSocketConnection: | ||||||
|  |     """ | ||||||
|  |     This is for ASGI Connections. | ||||||
|  |     It provides an interface similar to WebsocketProtocol, but | ||||||
|  |     sends/receives over an ASGI connection. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     # TODO | ||||||
|  |     # - Implement ping/pong | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         send: Callable[[ASIMessage], Awaitable[None]], | ||||||
|  |         receive: Callable[[], Awaitable[ASIMessage]], | ||||||
|  |         subprotocols: Optional[List[str]] = None, | ||||||
|  |     ) -> None: | ||||||
|  |         self._send = send | ||||||
|  |         self._receive = receive | ||||||
|  |         self._subprotocols = subprotocols or [] | ||||||
|  |  | ||||||
|  |     async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: | ||||||
|  |         message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} | ||||||
|  |  | ||||||
|  |         if isinstance(data, bytes): | ||||||
|  |             message.update({"bytes": data}) | ||||||
|  |         else: | ||||||
|  |             message.update({"text": str(data)}) | ||||||
|  |  | ||||||
|  |         await self._send(message) | ||||||
|  |  | ||||||
|  |     async def recv(self, *args, **kwargs) -> Optional[str]: | ||||||
|  |         message = await self._receive() | ||||||
|  |  | ||||||
|  |         if message["type"] == "websocket.receive": | ||||||
|  |             return message["text"] | ||||||
|  |         elif message["type"] == "websocket.disconnect": | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     receive = recv | ||||||
|  |  | ||||||
|  |     async def accept(self, subprotocols: Optional[List[str]] = None) -> None: | ||||||
|  |         subprotocol = None | ||||||
|  |         if subprotocols: | ||||||
|  |             for subp in subprotocols: | ||||||
|  |                 if subp in self.subprotocols: | ||||||
|  |                     subprotocol = subp | ||||||
|  |                     break | ||||||
|  |  | ||||||
|  |         await self._send( | ||||||
|  |             { | ||||||
|  |                 "type": "websocket.accept", | ||||||
|  |                 "subprotocol": subprotocol, | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     async def close(self, code: int = 1000, reason: str = "") -> None: | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def subprotocols(self): | ||||||
|  |         return self._subprotocols | ||||||
|  |  | ||||||
|  |     @subprotocols.setter | ||||||
|  |     def subprotocols(self, subprotocols: Optional[List[str]] = None): | ||||||
|  |         self._subprotocols = subprotocols or [] | ||||||
							
								
								
									
										294
									
								
								sanic/server/websockets/frame.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										294
									
								
								sanic/server/websockets/frame.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,294 @@ | |||||||
|  | import asyncio | ||||||
|  | import codecs | ||||||
|  |  | ||||||
|  | from typing import TYPE_CHECKING, AsyncIterator, List, Optional | ||||||
|  |  | ||||||
|  | from websockets.frames import Frame, Opcode | ||||||
|  | from websockets.typing import Data | ||||||
|  |  | ||||||
|  | from sanic.exceptions import ServerError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from .impl import WebsocketImplProtocol | ||||||
|  |  | ||||||
|  | UTF8Decoder = codecs.getincrementaldecoder("utf-8") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebsocketFrameAssembler: | ||||||
|  |     """ | ||||||
|  |     Assemble a message from frames. | ||||||
|  |     Code borrowed from aaugustin/websockets project: | ||||||
|  |     https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     __slots__ = ( | ||||||
|  |         "protocol", | ||||||
|  |         "read_mutex", | ||||||
|  |         "write_mutex", | ||||||
|  |         "message_complete", | ||||||
|  |         "message_fetched", | ||||||
|  |         "get_in_progress", | ||||||
|  |         "decoder", | ||||||
|  |         "completed_queue", | ||||||
|  |         "chunks", | ||||||
|  |         "chunks_queue", | ||||||
|  |         "paused", | ||||||
|  |         "get_id", | ||||||
|  |         "put_id", | ||||||
|  |     ) | ||||||
|  |     if TYPE_CHECKING: | ||||||
|  |         protocol: "WebsocketImplProtocol" | ||||||
|  |         read_mutex: asyncio.Lock | ||||||
|  |         write_mutex: asyncio.Lock | ||||||
|  |         message_complete: asyncio.Event | ||||||
|  |         message_fetched: asyncio.Event | ||||||
|  |         completed_queue: asyncio.Queue | ||||||
|  |         get_in_progress: bool | ||||||
|  |         decoder: Optional[codecs.IncrementalDecoder] | ||||||
|  |         # For streaming chunks rather than messages: | ||||||
|  |         chunks: List[Data] | ||||||
|  |         chunks_queue: Optional[asyncio.Queue[Optional[Data]]] | ||||||
|  |         paused: bool | ||||||
|  |  | ||||||
|  |     def __init__(self, protocol) -> None: | ||||||
|  |  | ||||||
|  |         self.protocol = protocol | ||||||
|  |  | ||||||
|  |         self.read_mutex = asyncio.Lock() | ||||||
|  |         self.write_mutex = asyncio.Lock() | ||||||
|  |  | ||||||
|  |         self.completed_queue = asyncio.Queue( | ||||||
|  |             maxsize=1 | ||||||
|  |         )  # type: asyncio.Queue[Data] | ||||||
|  |  | ||||||
|  |         # put() sets this event to tell get() that a message can be fetched. | ||||||
|  |         self.message_complete = asyncio.Event() | ||||||
|  |         # get() sets this event to let put() | ||||||
|  |         self.message_fetched = asyncio.Event() | ||||||
|  |  | ||||||
|  |         # This flag prevents concurrent calls to get() by user code. | ||||||
|  |         self.get_in_progress = False | ||||||
|  |  | ||||||
|  |         # Decoder for text frames, None for binary frames. | ||||||
|  |         self.decoder = None | ||||||
|  |  | ||||||
|  |         # Buffer data from frames belonging to the same message. | ||||||
|  |         self.chunks = [] | ||||||
|  |  | ||||||
|  |         # When switching from "buffering" to "streaming", we use a thread-safe | ||||||
|  |         # queue for transferring frames from the writing thread (library code) | ||||||
|  |         # to the reading thread (user code). We're buffering when chunks_queue | ||||||
|  |         # is None and streaming when it's a Queue. None is a sentinel | ||||||
|  |         # value marking the end of the stream, superseding message_complete. | ||||||
|  |  | ||||||
|  |         # Stream data from frames belonging to the same message. | ||||||
|  |         self.chunks_queue = None | ||||||
|  |  | ||||||
|  |         # Flag to indicate we've paused the protocol | ||||||
|  |         self.paused = False | ||||||
|  |  | ||||||
|  |     async def get(self, timeout: Optional[float] = None) -> Optional[Data]: | ||||||
|  |         """ | ||||||
|  |         Read the next message. | ||||||
|  |         :meth:`get` returns a single :class:`str` or :class:`bytes`. | ||||||
|  |         If the :message was fragmented, :meth:`get` waits until the last frame | ||||||
|  |         is received, then it reassembles the message. | ||||||
|  |         If ``timeout`` is set and elapses before a complete message is | ||||||
|  |         received, :meth:`get` returns ``None``. | ||||||
|  |         """ | ||||||
|  |         async with self.read_mutex: | ||||||
|  |             if timeout is not None and timeout <= 0: | ||||||
|  |                 if not self.message_complete.is_set(): | ||||||
|  |                     return None | ||||||
|  |             if self.get_in_progress: | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # exception is only here as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Called get() on Websocket frame assembler " | ||||||
|  |                     "while asynchronous get is already in progress." | ||||||
|  |                 ) | ||||||
|  |             self.get_in_progress = True | ||||||
|  |  | ||||||
|  |             # If the message_complete event isn't set yet, release the lock to | ||||||
|  |             # allow put() to run and eventually set it. | ||||||
|  |             # Locking with get_in_progress ensures only one task can get here. | ||||||
|  |             if timeout is None: | ||||||
|  |                 completed = await self.message_complete.wait() | ||||||
|  |             elif timeout <= 0: | ||||||
|  |                 completed = self.message_complete.is_set() | ||||||
|  |             else: | ||||||
|  |                 try: | ||||||
|  |                     await asyncio.wait_for( | ||||||
|  |                         self.message_complete.wait(), timeout=timeout | ||||||
|  |                     ) | ||||||
|  |                 except asyncio.TimeoutError: | ||||||
|  |                     ... | ||||||
|  |                 finally: | ||||||
|  |                     completed = self.message_complete.is_set() | ||||||
|  |  | ||||||
|  |             # Unpause the transport, if its paused | ||||||
|  |             if self.paused: | ||||||
|  |                 self.protocol.resume_frames() | ||||||
|  |                 self.paused = False | ||||||
|  |             if not self.get_in_progress: | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # exception is here as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "State of Websocket frame assembler was modified while an " | ||||||
|  |                     "asynchronous get was in progress." | ||||||
|  |                 ) | ||||||
|  |             self.get_in_progress = False | ||||||
|  |  | ||||||
|  |             # Waiting for a complete message timed out. | ||||||
|  |             if not completed: | ||||||
|  |                 return None | ||||||
|  |             if not self.message_complete.is_set(): | ||||||
|  |                 return None | ||||||
|  |  | ||||||
|  |             self.message_complete.clear() | ||||||
|  |  | ||||||
|  |             joiner: Data = b"" if self.decoder is None else "" | ||||||
|  |             # mypy cannot figure out that chunks have the proper type. | ||||||
|  |             message: Data = joiner.join(self.chunks)  # type: ignore | ||||||
|  |             if self.message_fetched.is_set(): | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # and get_in_progress check, this exception is here | ||||||
|  |                 # as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Websocket get() found a message when " | ||||||
|  |                     "state was already fetched." | ||||||
|  |                 ) | ||||||
|  |             self.message_fetched.set() | ||||||
|  |             self.chunks = [] | ||||||
|  |             # this should already be None, but set it here for safety | ||||||
|  |             self.chunks_queue = None | ||||||
|  |             return message | ||||||
|  |  | ||||||
|  |     async def get_iter(self) -> AsyncIterator[Data]: | ||||||
|  |         """ | ||||||
|  |         Stream the next message. | ||||||
|  |         Iterating the return value of :meth:`get_iter` yields a :class:`str` | ||||||
|  |         or :class:`bytes` for each frame in the message. | ||||||
|  |         """ | ||||||
|  |         async with self.read_mutex: | ||||||
|  |             if self.get_in_progress: | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # exception is only here as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Called get_iter on Websocket frame assembler " | ||||||
|  |                     "while asynchronous get is already in progress." | ||||||
|  |                 ) | ||||||
|  |             self.get_in_progress = True | ||||||
|  |  | ||||||
|  |             chunks = self.chunks | ||||||
|  |             self.chunks = [] | ||||||
|  |             self.chunks_queue = asyncio.Queue() | ||||||
|  |  | ||||||
|  |             # Sending None in chunk_queue supersedes setting message_complete | ||||||
|  |             # when switching to "streaming". If message is already complete | ||||||
|  |             # when the switch happens, put() didn't send None, so we have to. | ||||||
|  |             if self.message_complete.is_set(): | ||||||
|  |                 await self.chunks_queue.put(None) | ||||||
|  |  | ||||||
|  |             # Locking with get_in_progress ensures only one task can get here | ||||||
|  |             for c in chunks: | ||||||
|  |                 yield c | ||||||
|  |             while True: | ||||||
|  |                 chunk = await self.chunks_queue.get() | ||||||
|  |                 if chunk is None: | ||||||
|  |                     break | ||||||
|  |                 yield chunk | ||||||
|  |  | ||||||
|  |             # Unpause the transport, if its paused | ||||||
|  |             if self.paused: | ||||||
|  |                 self.protocol.resume_frames() | ||||||
|  |                 self.paused = False | ||||||
|  |             if not self.get_in_progress: | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # exception is here as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "State of Websocket frame assembler was modified while an " | ||||||
|  |                     "asynchronous get was in progress." | ||||||
|  |                 ) | ||||||
|  |             self.get_in_progress = False | ||||||
|  |             if not self.message_complete.is_set(): | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # exception is here as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Websocket frame assembler chunks queue ended before " | ||||||
|  |                     "message was complete." | ||||||
|  |                 ) | ||||||
|  |             self.message_complete.clear() | ||||||
|  |             if self.message_fetched.is_set(): | ||||||
|  |                 # This should be guarded against with the read_mutex, | ||||||
|  |                 # and get_in_progress check, this exception is | ||||||
|  |                 # here as a failsafe | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Websocket get_iter() found a message when state was " | ||||||
|  |                     "already fetched." | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             self.message_fetched.set() | ||||||
|  |             # this should already be empty, but set it here for safety | ||||||
|  |             self.chunks = [] | ||||||
|  |             self.chunks_queue = None | ||||||
|  |  | ||||||
|  |     async def put(self, frame: Frame) -> None: | ||||||
|  |         """ | ||||||
|  |         Add ``frame`` to the next message. | ||||||
|  |         When ``frame`` is the final frame in a message, :meth:`put` waits | ||||||
|  |         until the message is fetched, either by calling :meth:`get` or by | ||||||
|  |         iterating the return value of :meth:`get_iter`. | ||||||
|  |         :meth:`put` assumes that the stream of frames respects the protocol. | ||||||
|  |         If it doesn't, the behavior is undefined. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         async with self.write_mutex: | ||||||
|  |             if frame.opcode is Opcode.TEXT: | ||||||
|  |                 self.decoder = UTF8Decoder(errors="strict") | ||||||
|  |             elif frame.opcode is Opcode.BINARY: | ||||||
|  |                 self.decoder = None | ||||||
|  |             elif frame.opcode is Opcode.CONT: | ||||||
|  |                 pass | ||||||
|  |             else: | ||||||
|  |                 # Ignore control frames. | ||||||
|  |                 return | ||||||
|  |             data: Data | ||||||
|  |             if self.decoder is not None: | ||||||
|  |                 data = self.decoder.decode(frame.data, frame.fin) | ||||||
|  |             else: | ||||||
|  |                 data = frame.data | ||||||
|  |             if self.chunks_queue is None: | ||||||
|  |                 self.chunks.append(data) | ||||||
|  |             else: | ||||||
|  |                 await self.chunks_queue.put(data) | ||||||
|  |  | ||||||
|  |             if not frame.fin: | ||||||
|  |                 return | ||||||
|  |             if not self.get_in_progress: | ||||||
|  |                 # nobody is waiting for this frame, so try to pause subsequent | ||||||
|  |                 # frames at the protocol level | ||||||
|  |                 self.paused = self.protocol.pause_frames() | ||||||
|  |             # Message is complete. Wait until it's fetched to return. | ||||||
|  |  | ||||||
|  |             if self.chunks_queue is not None: | ||||||
|  |                 await self.chunks_queue.put(None) | ||||||
|  |             if self.message_complete.is_set(): | ||||||
|  |                 # This should be guarded against with the write_mutex | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Websocket put() got a new message when a message was " | ||||||
|  |                     "already in its chamber." | ||||||
|  |                 ) | ||||||
|  |             self.message_complete.set()  # Signal to get() it can serve the | ||||||
|  |             if self.message_fetched.is_set(): | ||||||
|  |                 # This should be guarded against with the write_mutex | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Websocket put() got a new message when the previous " | ||||||
|  |                     "message was not yet fetched." | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             # Allow get() to run and eventually set the event. | ||||||
|  |             await self.message_fetched.wait() | ||||||
|  |             self.message_fetched.clear() | ||||||
|  |             self.decoder = None | ||||||
							
								
								
									
										834
									
								
								sanic/server/websockets/impl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										834
									
								
								sanic/server/websockets/impl.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,834 @@ | |||||||
|  | import asyncio | ||||||
|  | import random | ||||||
|  | import struct | ||||||
|  |  | ||||||
|  | from typing import ( | ||||||
|  |     AsyncIterator, | ||||||
|  |     Dict, | ||||||
|  |     Iterable, | ||||||
|  |     Mapping, | ||||||
|  |     Optional, | ||||||
|  |     Sequence, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | from websockets.connection import CLOSED, CLOSING, OPEN, Event | ||||||
|  | from websockets.exceptions import ConnectionClosed, ConnectionClosedError | ||||||
|  | from websockets.frames import Frame, Opcode | ||||||
|  | from websockets.server import ServerConnection | ||||||
|  | from websockets.typing import Data | ||||||
|  |  | ||||||
|  | from sanic.log import error_logger, logger | ||||||
|  | from sanic.server.protocols.base_protocol import SanicProtocol | ||||||
|  |  | ||||||
|  | from ...exceptions import ServerError, WebsocketClosed | ||||||
|  | from .frame import WebsocketFrameAssembler | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebsocketImplProtocol: | ||||||
|  |     connection: ServerConnection | ||||||
|  |     io_proto: Optional[SanicProtocol] | ||||||
|  |     loop: Optional[asyncio.AbstractEventLoop] | ||||||
|  |     max_queue: int | ||||||
|  |     close_timeout: float | ||||||
|  |     ping_interval: Optional[float] | ||||||
|  |     ping_timeout: Optional[float] | ||||||
|  |     assembler: WebsocketFrameAssembler | ||||||
|  |     # Dict[bytes, asyncio.Future[None]] | ||||||
|  |     pings: Dict[bytes, asyncio.Future] | ||||||
|  |     conn_mutex: asyncio.Lock | ||||||
|  |     recv_lock: asyncio.Lock | ||||||
|  |     recv_cancel: Optional[asyncio.Future] | ||||||
|  |     process_event_mutex: asyncio.Lock | ||||||
|  |     can_pause: bool | ||||||
|  |     # Optional[asyncio.Future[None]] | ||||||
|  |     data_finished_fut: Optional[asyncio.Future] | ||||||
|  |     # Optional[asyncio.Future[None]] | ||||||
|  |     pause_frame_fut: Optional[asyncio.Future] | ||||||
|  |     # Optional[asyncio.Future[None]] | ||||||
|  |     connection_lost_waiter: Optional[asyncio.Future] | ||||||
|  |     keepalive_ping_task: Optional[asyncio.Task] | ||||||
|  |     auto_closer_task: Optional[asyncio.Task] | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         connection, | ||||||
|  |         max_queue=None, | ||||||
|  |         ping_interval: Optional[float] = 20, | ||||||
|  |         ping_timeout: Optional[float] = 20, | ||||||
|  |         close_timeout: float = 10, | ||||||
|  |         loop=None, | ||||||
|  |     ): | ||||||
|  |         self.connection = connection | ||||||
|  |         self.io_proto = None | ||||||
|  |         self.loop = None | ||||||
|  |         self.max_queue = max_queue | ||||||
|  |         self.close_timeout = close_timeout | ||||||
|  |         self.ping_interval = ping_interval | ||||||
|  |         self.ping_timeout = ping_timeout | ||||||
|  |         self.assembler = WebsocketFrameAssembler(self) | ||||||
|  |         self.pings = {} | ||||||
|  |         self.conn_mutex = asyncio.Lock() | ||||||
|  |         self.recv_lock = asyncio.Lock() | ||||||
|  |         self.recv_cancel = None | ||||||
|  |         self.process_event_mutex = asyncio.Lock() | ||||||
|  |         self.data_finished_fut = None | ||||||
|  |         self.can_pause = True | ||||||
|  |         self.pause_frame_fut = None | ||||||
|  |         self.keepalive_ping_task = None | ||||||
|  |         self.auto_closer_task = None | ||||||
|  |         self.connection_lost_waiter = None | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def subprotocol(self): | ||||||
|  |         return self.connection.subprotocol | ||||||
|  |  | ||||||
|  |     def pause_frames(self): | ||||||
|  |         if not self.can_pause: | ||||||
|  |             return False | ||||||
|  |         if self.pause_frame_fut: | ||||||
|  |             logger.debug("Websocket connection already paused.") | ||||||
|  |             return False | ||||||
|  |         if (not self.loop) or (not self.io_proto): | ||||||
|  |             return False | ||||||
|  |         if self.io_proto.transport: | ||||||
|  |             self.io_proto.transport.pause_reading() | ||||||
|  |         self.pause_frame_fut = self.loop.create_future() | ||||||
|  |         logger.debug("Websocket connection paused.") | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def resume_frames(self): | ||||||
|  |         if not self.pause_frame_fut: | ||||||
|  |             logger.debug("Websocket connection not paused.") | ||||||
|  |             return False | ||||||
|  |         if (not self.loop) or (not self.io_proto): | ||||||
|  |             logger.debug( | ||||||
|  |                 "Websocket attempting to resume reading frames, " | ||||||
|  |                 "but connection is gone." | ||||||
|  |             ) | ||||||
|  |             return False | ||||||
|  |         if self.io_proto.transport: | ||||||
|  |             self.io_proto.transport.resume_reading() | ||||||
|  |         self.pause_frame_fut.set_result(None) | ||||||
|  |         self.pause_frame_fut = None | ||||||
|  |         logger.debug("Websocket connection unpaused.") | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     async def connection_made( | ||||||
|  |         self, | ||||||
|  |         io_proto: SanicProtocol, | ||||||
|  |         loop: Optional[asyncio.AbstractEventLoop] = None, | ||||||
|  |     ): | ||||||
|  |         if not loop: | ||||||
|  |             try: | ||||||
|  |                 loop = getattr(io_proto, "loop") | ||||||
|  |             except AttributeError: | ||||||
|  |                 loop = asyncio.get_event_loop() | ||||||
|  |         if not loop: | ||||||
|  |             # This catch is for mypy type checker | ||||||
|  |             # to assert loop is not None here. | ||||||
|  |             raise ServerError("Connection received with no asyncio loop.") | ||||||
|  |         if self.auto_closer_task: | ||||||
|  |             raise ServerError( | ||||||
|  |                 "Cannot call connection_made more than once " | ||||||
|  |                 "on a websocket connection." | ||||||
|  |             ) | ||||||
|  |         self.loop = loop | ||||||
|  |         self.io_proto = io_proto | ||||||
|  |         self.connection_lost_waiter = self.loop.create_future() | ||||||
|  |         self.data_finished_fut = asyncio.shield(self.loop.create_future()) | ||||||
|  |  | ||||||
|  |         if self.ping_interval: | ||||||
|  |             self.keepalive_ping_task = asyncio.create_task( | ||||||
|  |                 self.keepalive_ping() | ||||||
|  |             ) | ||||||
|  |         self.auto_closer_task = asyncio.create_task( | ||||||
|  |             self.auto_close_connection() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     async def wait_for_connection_lost(self, timeout=None) -> bool: | ||||||
|  |         """ | ||||||
|  |         Wait until the TCP connection is closed or ``timeout`` elapses. | ||||||
|  |         If timeout is None, wait forever. | ||||||
|  |         Recommend you should pass in self.close_timeout as timeout | ||||||
|  |  | ||||||
|  |         Return ``True`` if the connection is closed and ``False`` otherwise. | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |         if not self.connection_lost_waiter: | ||||||
|  |             return False | ||||||
|  |         if self.connection_lost_waiter.done(): | ||||||
|  |             return True | ||||||
|  |         else: | ||||||
|  |             try: | ||||||
|  |                 await asyncio.wait_for( | ||||||
|  |                     asyncio.shield(self.connection_lost_waiter), timeout | ||||||
|  |                 ) | ||||||
|  |                 return True | ||||||
|  |             except asyncio.TimeoutError: | ||||||
|  |                 # Re-check self.connection_lost_waiter.done() synchronously | ||||||
|  |                 # because connection_lost() could run between the moment the | ||||||
|  |                 # timeout occurs and the moment this coroutine resumes running | ||||||
|  |                 return self.connection_lost_waiter.done() | ||||||
|  |  | ||||||
|  |     async def process_events(self, events: Sequence[Event]) -> None: | ||||||
|  |         """ | ||||||
|  |         Process a list of incoming events. | ||||||
|  |         """ | ||||||
|  |         # Wrapped in a mutex lock, to prevent other incoming events | ||||||
|  |         # from processing at the same time | ||||||
|  |         async with self.process_event_mutex: | ||||||
|  |             for event in events: | ||||||
|  |                 if not isinstance(event, Frame): | ||||||
|  |                     # Event is not a frame. Ignore it. | ||||||
|  |                     continue | ||||||
|  |                 if event.opcode == Opcode.PONG: | ||||||
|  |                     await self.process_pong(event) | ||||||
|  |                 elif event.opcode == Opcode.CLOSE: | ||||||
|  |                     if self.recv_cancel: | ||||||
|  |                         self.recv_cancel.cancel() | ||||||
|  |                 else: | ||||||
|  |                     await self.assembler.put(event) | ||||||
|  |  | ||||||
|  |     async def process_pong(self, frame: Frame) -> None: | ||||||
|  |         if frame.data in self.pings: | ||||||
|  |             # Acknowledge all pings up to the one matching this pong. | ||||||
|  |             ping_ids = [] | ||||||
|  |             for ping_id, ping in self.pings.items(): | ||||||
|  |                 ping_ids.append(ping_id) | ||||||
|  |                 if not ping.done(): | ||||||
|  |                     ping.set_result(None) | ||||||
|  |                 if ping_id == frame.data: | ||||||
|  |                     break | ||||||
|  |             else:  # noqa | ||||||
|  |                 raise ServerError("ping_id is not in self.pings") | ||||||
|  |             # Remove acknowledged pings from self.pings. | ||||||
|  |             for ping_id in ping_ids: | ||||||
|  |                 del self.pings[ping_id] | ||||||
|  |  | ||||||
|  |     async def keepalive_ping(self) -> None: | ||||||
|  |         """ | ||||||
|  |         Send a Ping frame and wait for a Pong frame at regular intervals. | ||||||
|  |         This coroutine exits when the connection terminates and one of the | ||||||
|  |         following happens: | ||||||
|  |         - :meth:`ping` raises :exc:`ConnectionClosed`, or | ||||||
|  |         - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. | ||||||
|  |         """ | ||||||
|  |         if self.ping_interval is None: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             while True: | ||||||
|  |                 await asyncio.sleep(self.ping_interval) | ||||||
|  |  | ||||||
|  |                 # ping() raises CancelledError if the connection is closed, | ||||||
|  |                 # when auto_close_connection() cancels keepalive_ping_task. | ||||||
|  |  | ||||||
|  |                 # ping() raises ConnectionClosed if the connection is lost, | ||||||
|  |                 # when connection_lost() calls abort_pings(). | ||||||
|  |  | ||||||
|  |                 ping_waiter = await self.ping() | ||||||
|  |  | ||||||
|  |                 if self.ping_timeout is not None: | ||||||
|  |                     try: | ||||||
|  |                         await asyncio.wait_for(ping_waiter, self.ping_timeout) | ||||||
|  |                     except asyncio.TimeoutError: | ||||||
|  |                         error_logger.warning( | ||||||
|  |                             "Websocket timed out waiting for pong" | ||||||
|  |                         ) | ||||||
|  |                         self.fail_connection(1011) | ||||||
|  |                         break | ||||||
|  |         except asyncio.CancelledError: | ||||||
|  |             # It is expected for this task to be cancelled during during | ||||||
|  |             # normal operation, when the connection is closed. | ||||||
|  |             logger.debug("Websocket keepalive ping task was cancelled.") | ||||||
|  |         except (ConnectionClosed, WebsocketClosed): | ||||||
|  |             logger.debug("Websocket closed. Keepalive ping task exiting.") | ||||||
|  |         except Exception as e: | ||||||
|  |             error_logger.warning( | ||||||
|  |                 "Unexpected exception in websocket keepalive ping task." | ||||||
|  |             ) | ||||||
|  |             logger.debug(str(e)) | ||||||
|  |  | ||||||
|  |     def _force_disconnect(self) -> bool: | ||||||
|  |         """ | ||||||
|  |         Internal methdod used by end_connection and fail_connection | ||||||
|  |         only when the graceful auto-closer cannot be used | ||||||
|  |         """ | ||||||
|  |         if self.auto_closer_task and not self.auto_closer_task.done(): | ||||||
|  |             self.auto_closer_task.cancel() | ||||||
|  |         if self.data_finished_fut and not self.data_finished_fut.done(): | ||||||
|  |             self.data_finished_fut.cancel() | ||||||
|  |             self.data_finished_fut = None | ||||||
|  |         if self.keepalive_ping_task and not self.keepalive_ping_task.done(): | ||||||
|  |             self.keepalive_ping_task.cancel() | ||||||
|  |             self.keepalive_ping_task = None | ||||||
|  |         if self.loop and self.io_proto and self.io_proto.transport: | ||||||
|  |             self.io_proto.transport.close() | ||||||
|  |             self.loop.call_later( | ||||||
|  |                 self.close_timeout, self.io_proto.transport.abort | ||||||
|  |             ) | ||||||
|  |         # We were never open, or already closed | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def fail_connection(self, code: int = 1006, reason: str = "") -> bool: | ||||||
|  |         """ | ||||||
|  |         Fail the WebSocket Connection | ||||||
|  |         This requires: | ||||||
|  |         1. Stopping all processing of incoming data, which means cancelling | ||||||
|  |            pausing the underlying io protocol. The close code will be 1006 | ||||||
|  |            unless a close frame was received earlier. | ||||||
|  |         2. Sending a close frame with an appropriate code if the opening | ||||||
|  |            handshake succeeded and the other side is likely to process it. | ||||||
|  |         3. Closing the connection. :meth:`auto_close_connection` takes care | ||||||
|  |            of this. | ||||||
|  |         (The specification describes these steps in the opposite order.) | ||||||
|  |         """ | ||||||
|  |         if self.io_proto and self.io_proto.transport: | ||||||
|  |             # Stop new data coming in | ||||||
|  |             # In Python Version 3.7: pause_reading is idempotent | ||||||
|  |             # ut can be called when the transport is already paused or closed | ||||||
|  |             self.io_proto.transport.pause_reading() | ||||||
|  |  | ||||||
|  |             # Keeping fail_connection() synchronous guarantees it can't | ||||||
|  |             # get stuck and simplifies the implementation of the callers. | ||||||
|  |             # Not draining the write buffer is acceptable in this context. | ||||||
|  |  | ||||||
|  |             # clear the send buffer | ||||||
|  |             _ = self.connection.data_to_send() | ||||||
|  |             # If we're not already CLOSED or CLOSING, then send the close. | ||||||
|  |             if self.connection.state is OPEN: | ||||||
|  |                 if code in (1000, 1001): | ||||||
|  |                     self.connection.send_close(code, reason) | ||||||
|  |                 else: | ||||||
|  |                     self.connection.fail(code, reason) | ||||||
|  |                 try: | ||||||
|  |                     data_to_send = self.connection.data_to_send() | ||||||
|  |                     while ( | ||||||
|  |                         len(data_to_send) | ||||||
|  |                         and self.io_proto | ||||||
|  |                         and self.io_proto.transport | ||||||
|  |                     ): | ||||||
|  |                         frame_data = data_to_send.pop(0) | ||||||
|  |                         self.io_proto.transport.write(frame_data) | ||||||
|  |                 except Exception: | ||||||
|  |                     # sending close frames may fail if the | ||||||
|  |                     # transport closes during this period | ||||||
|  |                     ... | ||||||
|  |         if code == 1006: | ||||||
|  |             # Special case: 1006 consider the transport already closed | ||||||
|  |             self.connection.state = CLOSED | ||||||
|  |         if self.data_finished_fut and not self.data_finished_fut.done(): | ||||||
|  |             # We have a graceful auto-closer. Use it to close the connection. | ||||||
|  |             self.data_finished_fut.cancel() | ||||||
|  |             self.data_finished_fut = None | ||||||
|  |         if (not self.auto_closer_task) or self.auto_closer_task.done(): | ||||||
|  |             return self._force_disconnect() | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def end_connection(self, code=1000, reason=""): | ||||||
|  |         # This is like slightly more graceful form of fail_connection | ||||||
|  |         # Use this instead of close() when you need an immediate | ||||||
|  |         # close and cannot await websocket.close() handshake. | ||||||
|  |  | ||||||
|  |         if code == 1006 or not self.io_proto or not self.io_proto.transport: | ||||||
|  |             return self.fail_connection(code, reason) | ||||||
|  |  | ||||||
|  |         # Stop new data coming in | ||||||
|  |         # In Python Version 3.7: pause_reading is idempotent | ||||||
|  |         # i.e. it can be called when the transport is already paused or closed. | ||||||
|  |         self.io_proto.transport.pause_reading() | ||||||
|  |         if self.connection.state == OPEN: | ||||||
|  |             data_to_send = self.connection.data_to_send() | ||||||
|  |             self.connection.send_close(code, reason) | ||||||
|  |             data_to_send.extend(self.connection.data_to_send()) | ||||||
|  |             try: | ||||||
|  |                 while ( | ||||||
|  |                     len(data_to_send) | ||||||
|  |                     and self.io_proto | ||||||
|  |                     and self.io_proto.transport | ||||||
|  |                 ): | ||||||
|  |                     frame_data = data_to_send.pop(0) | ||||||
|  |                     self.io_proto.transport.write(frame_data) | ||||||
|  |             except Exception: | ||||||
|  |                 # sending close frames may fail if the | ||||||
|  |                 # transport closes during this period | ||||||
|  |                 # But that doesn't matter at this point | ||||||
|  |                 ... | ||||||
|  |         if self.data_finished_fut and not self.data_finished_fut.done(): | ||||||
|  |             # We have the ability to signal the auto-closer | ||||||
|  |             # try to trigger it to auto-close the connection | ||||||
|  |             self.data_finished_fut.cancel() | ||||||
|  |             self.data_finished_fut = None | ||||||
|  |         if (not self.auto_closer_task) or self.auto_closer_task.done(): | ||||||
|  |             # Auto-closer is not running, do force disconnect | ||||||
|  |             return self._force_disconnect() | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     async def auto_close_connection(self) -> None: | ||||||
|  |         """ | ||||||
|  |         Close the WebSocket Connection | ||||||
|  |         When the opening handshake succeeds, :meth:`connection_open` starts | ||||||
|  |         this coroutine in a task. It waits for the data transfer phase to | ||||||
|  |         complete then it closes the TCP connection cleanly. | ||||||
|  |         When the opening handshake fails, :meth:`fail_connection` does the | ||||||
|  |         same. There's no data transfer phase in that case. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # Wait for the data transfer phase to complete. | ||||||
|  |             if self.data_finished_fut: | ||||||
|  |                 try: | ||||||
|  |                     await self.data_finished_fut | ||||||
|  |                     logger.debug( | ||||||
|  |                         "Websocket task finished. Closing the connection." | ||||||
|  |                     ) | ||||||
|  |                 except asyncio.CancelledError: | ||||||
|  |                     # Cancelled error is called when data phase is cancelled | ||||||
|  |                     # if an error occurred or the client closed the connection | ||||||
|  |                     logger.debug( | ||||||
|  |                         "Websocket handler cancelled. Closing the connection." | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |             # Cancel the keepalive ping task. | ||||||
|  |             if self.keepalive_ping_task: | ||||||
|  |                 self.keepalive_ping_task.cancel() | ||||||
|  |                 self.keepalive_ping_task = None | ||||||
|  |  | ||||||
|  |             # Half-close the TCP connection if possible (when there's no TLS). | ||||||
|  |             if ( | ||||||
|  |                 self.io_proto | ||||||
|  |                 and self.io_proto.transport | ||||||
|  |                 and self.io_proto.transport.can_write_eof() | ||||||
|  |             ): | ||||||
|  |                 logger.debug("Websocket half-closing TCP connection") | ||||||
|  |                 self.io_proto.transport.write_eof() | ||||||
|  |                 if self.connection_lost_waiter: | ||||||
|  |                     if await self.wait_for_connection_lost(timeout=0): | ||||||
|  |                         return | ||||||
|  |         except asyncio.CancelledError: | ||||||
|  |             ... | ||||||
|  |         finally: | ||||||
|  |             # The try/finally ensures that the transport never remains open, | ||||||
|  |             # even if this coroutine is cancelled (for example). | ||||||
|  |             if (not self.io_proto) or (not self.io_proto.transport): | ||||||
|  |                 # we were never open, or done. Can't do any finalization. | ||||||
|  |                 return | ||||||
|  |             elif ( | ||||||
|  |                 self.connection_lost_waiter | ||||||
|  |                 and self.connection_lost_waiter.done() | ||||||
|  |             ): | ||||||
|  |                 # connection confirmed closed already, proceed to abort waiter | ||||||
|  |                 ... | ||||||
|  |             elif self.io_proto.transport.is_closing(): | ||||||
|  |                 # Connection is already closing (due to half-close above) | ||||||
|  |                 # proceed to abort waiter | ||||||
|  |                 ... | ||||||
|  |             else: | ||||||
|  |                 self.io_proto.transport.close() | ||||||
|  |             if not self.connection_lost_waiter: | ||||||
|  |                 # Our connection monitor task isn't running. | ||||||
|  |                 try: | ||||||
|  |                     await asyncio.sleep(self.close_timeout) | ||||||
|  |                 except asyncio.CancelledError: | ||||||
|  |                     ... | ||||||
|  |                 if self.io_proto and self.io_proto.transport: | ||||||
|  |                     self.io_proto.transport.abort() | ||||||
|  |             else: | ||||||
|  |                 if await self.wait_for_connection_lost( | ||||||
|  |                     timeout=self.close_timeout | ||||||
|  |                 ): | ||||||
|  |                     # Connection aborted before the timeout expired. | ||||||
|  |                     return | ||||||
|  |                 error_logger.warning( | ||||||
|  |                     "Timeout waiting for TCP connection to close. Aborting" | ||||||
|  |                 ) | ||||||
|  |                 if self.io_proto and self.io_proto.transport: | ||||||
|  |                     self.io_proto.transport.abort() | ||||||
|  |  | ||||||
|  |     def abort_pings(self) -> None: | ||||||
|  |         """ | ||||||
|  |         Raise ConnectionClosed in pending keepalive pings. | ||||||
|  |         They'll never receive a pong once the connection is closed. | ||||||
|  |         """ | ||||||
|  |         if self.connection.state is not CLOSED: | ||||||
|  |             raise ServerError( | ||||||
|  |                 "Webscoket about_pings should only be called " | ||||||
|  |                 "after connection state is changed to CLOSED" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         for ping in self.pings.values(): | ||||||
|  |             ping.set_exception(ConnectionClosedError(None, None)) | ||||||
|  |             # If the exception is never retrieved, it will be logged when ping | ||||||
|  |             # is garbage-collected. This is confusing for users. | ||||||
|  |             # Given that ping is done (with an exception), canceling it does | ||||||
|  |             # nothing, but it prevents logging the exception. | ||||||
|  |             ping.cancel() | ||||||
|  |  | ||||||
|  |     async def close(self, code: int = 1000, reason: str = "") -> None: | ||||||
|  |         """ | ||||||
|  |         Perform the closing handshake. | ||||||
|  |         This is a websocket-protocol level close. | ||||||
|  |         :meth:`close` waits for the other end to complete the handshake and | ||||||
|  |         for the TCP connection to terminate. | ||||||
|  |         :meth:`close` is idempotent: it doesn't do anything once the | ||||||
|  |         connection is closed. | ||||||
|  |         :param code: WebSocket close code | ||||||
|  |         :param reason: WebSocket close reason | ||||||
|  |         """ | ||||||
|  |         if code == 1006: | ||||||
|  |             self.fail_connection(code, reason) | ||||||
|  |             return | ||||||
|  |         async with self.conn_mutex: | ||||||
|  |             if self.connection.state is OPEN: | ||||||
|  |                 self.connection.send_close(code, reason) | ||||||
|  |                 data_to_send = self.connection.data_to_send() | ||||||
|  |                 await self.send_data(data_to_send) | ||||||
|  |  | ||||||
|  |     async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: | ||||||
|  |         """ | ||||||
|  |         Receive the next message. | ||||||
|  |         Return a :class:`str` for a text frame and :class:`bytes` for a binary | ||||||
|  |         frame. | ||||||
|  |         When the end of the message stream is reached, :meth:`recv` raises | ||||||
|  |         :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it | ||||||
|  |         raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal | ||||||
|  |         connection closure and | ||||||
|  |         :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol | ||||||
|  |         error or a network failure. | ||||||
|  |         If ``timeout`` is ``None``, block until a message is received. Else, | ||||||
|  |         if no message is received within ``timeout`` seconds, return ``None``. | ||||||
|  |         Set ``timeout`` to ``0`` to check if a message was already received. | ||||||
|  |         :raises ~websockets.exceptions.ConnectionClosed: when the | ||||||
|  |             connection is closed | ||||||
|  |         :raises asyncio.CancelledError: if the websocket closes while waiting | ||||||
|  |         :raises ServerError: if two tasks call :meth:`recv` or | ||||||
|  |             :meth:`recv_streaming` concurrently | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if self.recv_lock.locked(): | ||||||
|  |             raise ServerError( | ||||||
|  |                 "cannot call recv while another task is " | ||||||
|  |                 "already waiting for the next message" | ||||||
|  |             ) | ||||||
|  |         await self.recv_lock.acquire() | ||||||
|  |         if self.connection.state is CLOSED: | ||||||
|  |             self.recv_lock.release() | ||||||
|  |             raise WebsocketClosed( | ||||||
|  |                 "Cannot receive from websocket interface after it is closed." | ||||||
|  |             ) | ||||||
|  |         try: | ||||||
|  |             self.recv_cancel = asyncio.Future() | ||||||
|  |             done, pending = await asyncio.wait( | ||||||
|  |                 (self.recv_cancel, self.assembler.get(timeout)), | ||||||
|  |                 return_when=asyncio.FIRST_COMPLETED, | ||||||
|  |             ) | ||||||
|  |             done_task = next(iter(done)) | ||||||
|  |             if done_task is self.recv_cancel: | ||||||
|  |                 # recv was cancelled | ||||||
|  |                 for p in pending: | ||||||
|  |                     p.cancel() | ||||||
|  |                 raise asyncio.CancelledError() | ||||||
|  |             else: | ||||||
|  |                 self.recv_cancel.cancel() | ||||||
|  |                 return done_task.result() | ||||||
|  |         finally: | ||||||
|  |             self.recv_cancel = None | ||||||
|  |             self.recv_lock.release() | ||||||
|  |  | ||||||
|  |     async def recv_burst(self, max_recv=256) -> Sequence[Data]: | ||||||
|  |         """ | ||||||
|  |         Receive the messages which have arrived since last checking. | ||||||
|  |         Return a :class:`list` containing :class:`str` for a text frame | ||||||
|  |         and :class:`bytes` for a binary frame. | ||||||
|  |         When the end of the message stream is reached, :meth:`recv_burst` | ||||||
|  |         raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, | ||||||
|  |         it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a | ||||||
|  |         normal connection closure and | ||||||
|  |         :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol | ||||||
|  |         error or a network failure. | ||||||
|  |         :raises ~websockets.exceptions.ConnectionClosed: when the | ||||||
|  |             connection is closed | ||||||
|  |         :raises ServerError: if two tasks call :meth:`recv_burst` or | ||||||
|  |             :meth:`recv_streaming` concurrently | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if self.recv_lock.locked(): | ||||||
|  |             raise ServerError( | ||||||
|  |                 "cannot call recv_burst while another task is already waiting " | ||||||
|  |                 "for the next message" | ||||||
|  |             ) | ||||||
|  |         await self.recv_lock.acquire() | ||||||
|  |         if self.connection.state is CLOSED: | ||||||
|  |             self.recv_lock.release() | ||||||
|  |             raise WebsocketClosed( | ||||||
|  |                 "Cannot receive from websocket interface after it is closed." | ||||||
|  |             ) | ||||||
|  |         messages = [] | ||||||
|  |         try: | ||||||
|  |             # Prevent pausing the transport when we're | ||||||
|  |             # receiving a burst of messages | ||||||
|  |             self.can_pause = False | ||||||
|  |             self.recv_cancel = asyncio.Future() | ||||||
|  |             while True: | ||||||
|  |                 done, pending = await asyncio.wait( | ||||||
|  |                     (self.recv_cancel, self.assembler.get(timeout=0)), | ||||||
|  |                     return_when=asyncio.FIRST_COMPLETED, | ||||||
|  |                 ) | ||||||
|  |                 done_task = next(iter(done)) | ||||||
|  |                 if done_task is self.recv_cancel: | ||||||
|  |                     # recv_burst was cancelled | ||||||
|  |                     for p in pending: | ||||||
|  |                         p.cancel() | ||||||
|  |                     raise asyncio.CancelledError() | ||||||
|  |                 m = done_task.result() | ||||||
|  |                 if m is None: | ||||||
|  |                     # None left in the burst. This is good! | ||||||
|  |                     break | ||||||
|  |                 messages.append(m) | ||||||
|  |                 if len(messages) >= max_recv: | ||||||
|  |                     # Too much data in the pipe. Hit our burst limit. | ||||||
|  |                     break | ||||||
|  |                 # Allow an eventloop iteration for the | ||||||
|  |                 # next message to pass into the Assembler | ||||||
|  |                 await asyncio.sleep(0) | ||||||
|  |             self.recv_cancel.cancel() | ||||||
|  |         finally: | ||||||
|  |             self.recv_cancel = None | ||||||
|  |             self.can_pause = True | ||||||
|  |             self.recv_lock.release() | ||||||
|  |         return messages | ||||||
|  |  | ||||||
|  |     async def recv_streaming(self) -> AsyncIterator[Data]: | ||||||
|  |         """ | ||||||
|  |         Receive the next message frame by frame. | ||||||
|  |         Return an iterator of :class:`str` for a text frame and :class:`bytes` | ||||||
|  |         for a binary frame. The iterator should be exhausted, or else the | ||||||
|  |         connection will become unusable. | ||||||
|  |         With the exception of the return value, :meth:`recv_streaming` behaves | ||||||
|  |         like :meth:`recv`. | ||||||
|  |         """ | ||||||
|  |         if self.recv_lock.locked(): | ||||||
|  |             raise ServerError( | ||||||
|  |                 "Cannot call recv_streaming while another task " | ||||||
|  |                 "is already waiting for the next message" | ||||||
|  |             ) | ||||||
|  |         await self.recv_lock.acquire() | ||||||
|  |         if self.connection.state is CLOSED: | ||||||
|  |             self.recv_lock.release() | ||||||
|  |             raise WebsocketClosed( | ||||||
|  |                 "Cannot receive from websocket interface after it is closed." | ||||||
|  |             ) | ||||||
|  |         try: | ||||||
|  |             cancelled = False | ||||||
|  |             self.recv_cancel = asyncio.Future() | ||||||
|  |             self.can_pause = False | ||||||
|  |             async for m in self.assembler.get_iter(): | ||||||
|  |                 if self.recv_cancel.done(): | ||||||
|  |                     cancelled = True | ||||||
|  |                     break | ||||||
|  |                 yield m | ||||||
|  |             if cancelled: | ||||||
|  |                 raise asyncio.CancelledError() | ||||||
|  |         finally: | ||||||
|  |             self.can_pause = True | ||||||
|  |             self.recv_cancel = None | ||||||
|  |             self.recv_lock.release() | ||||||
|  |  | ||||||
|  |     async def send(self, message: Union[Data, Iterable[Data]]) -> None: | ||||||
|  |         """ | ||||||
|  |         Send a message. | ||||||
|  |         A string (:class:`str`) is sent as a `Text frame`_. A bytestring or | ||||||
|  |         bytes-like object (:class:`bytes`, :class:`bytearray`, or | ||||||
|  |         :class:`memoryview`) is sent as a `Binary frame`_. | ||||||
|  |         .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 | ||||||
|  |         .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 | ||||||
|  |         :meth:`send` also accepts an iterable of strings, bytestrings, or | ||||||
|  |         bytes-like objects. In that case the message is fragmented. Each item | ||||||
|  |         is treated as a message fragment and sent in its own frame. All items | ||||||
|  |         must be of the same type, or else :meth:`send` will raise a | ||||||
|  |         :exc:`TypeError` and the connection will be closed. | ||||||
|  |         :meth:`send` rejects dict-like objects because this is often an error. | ||||||
|  |         If you wish to send the keys of a dict-like object as fragments, call | ||||||
|  |         its :meth:`~dict.keys` method and pass the result to :meth:`send`. | ||||||
|  |         :raises TypeError: for unsupported inputs | ||||||
|  |         """ | ||||||
|  |         async with self.conn_mutex: | ||||||
|  |  | ||||||
|  |             if self.connection.state in (CLOSED, CLOSING): | ||||||
|  |                 raise WebsocketClosed( | ||||||
|  |                     "Cannot write to websocket interface after it is closed." | ||||||
|  |                 ) | ||||||
|  |             if (not self.data_finished_fut) or self.data_finished_fut.done(): | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Cannot write to websocket interface after it is finished." | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             # Unfragmented message -- this case must be handled first because | ||||||
|  |             # strings and bytes-like objects are iterable. | ||||||
|  |  | ||||||
|  |             if isinstance(message, str): | ||||||
|  |                 self.connection.send_text(message.encode("utf-8")) | ||||||
|  |                 await self.send_data(self.connection.data_to_send()) | ||||||
|  |  | ||||||
|  |             elif isinstance(message, (bytes, bytearray, memoryview)): | ||||||
|  |                 self.connection.send_binary(message) | ||||||
|  |                 await self.send_data(self.connection.data_to_send()) | ||||||
|  |  | ||||||
|  |             elif isinstance(message, Mapping): | ||||||
|  |                 # Catch a common mistake -- passing a dict to send(). | ||||||
|  |                 raise TypeError("data is a dict-like object") | ||||||
|  |  | ||||||
|  |             elif isinstance(message, Iterable): | ||||||
|  |                 # Fragmented message -- regular iterator. | ||||||
|  |                 raise NotImplementedError( | ||||||
|  |                     "Fragmented websocket messages are not supported." | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 raise TypeError("Websocket data must be bytes, str.") | ||||||
|  |  | ||||||
|  |     async def ping(self, data: Optional[Data] = None) -> asyncio.Future: | ||||||
|  |         """ | ||||||
|  |         Send a ping. | ||||||
|  |         Return an :class:`~asyncio.Future` that will be resolved when the | ||||||
|  |         corresponding pong is received. You can ignore it if you don't intend | ||||||
|  |         to wait. | ||||||
|  |         A ping may serve as a keepalive or as a check that the remote endpoint | ||||||
|  |         received all messages up to this point:: | ||||||
|  |             await pong_event = ws.ping() | ||||||
|  |             await pong_event # only if you want to wait for the pong | ||||||
|  |         By default, the ping contains four random bytes. This payload may be | ||||||
|  |         overridden with the optional ``data`` argument which must be a string | ||||||
|  |         (which will be encoded to UTF-8) or a bytes-like object. | ||||||
|  |         """ | ||||||
|  |         async with self.conn_mutex: | ||||||
|  |             if self.connection.state in (CLOSED, CLOSING): | ||||||
|  |                 raise WebsocketClosed( | ||||||
|  |                     "Cannot send a ping when the websocket interface " | ||||||
|  |                     "is closed." | ||||||
|  |                 ) | ||||||
|  |             if (not self.io_proto) or (not self.io_proto.loop): | ||||||
|  |                 raise ServerError( | ||||||
|  |                     "Cannot send a ping when the websocket has no I/O " | ||||||
|  |                     "protocol attached." | ||||||
|  |                 ) | ||||||
|  |             if data is not None: | ||||||
|  |                 if isinstance(data, str): | ||||||
|  |                     data = data.encode("utf-8") | ||||||
|  |                 elif isinstance(data, (bytearray, memoryview)): | ||||||
|  |                     data = bytes(data) | ||||||
|  |  | ||||||
|  |             # Protect against duplicates if a payload is explicitly set. | ||||||
|  |             if data in self.pings: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "already waiting for a pong with the same data" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             # Generate a unique random payload otherwise. | ||||||
|  |             while data is None or data in self.pings: | ||||||
|  |                 data = struct.pack("!I", random.getrandbits(32)) | ||||||
|  |  | ||||||
|  |             self.pings[data] = self.io_proto.loop.create_future() | ||||||
|  |  | ||||||
|  |             self.connection.send_ping(data) | ||||||
|  |             await self.send_data(self.connection.data_to_send()) | ||||||
|  |  | ||||||
|  |             return asyncio.shield(self.pings[data]) | ||||||
|  |  | ||||||
|  |     async def pong(self, data: Data = b"") -> None: | ||||||
|  |         """ | ||||||
|  |         Send a pong. | ||||||
|  |         An unsolicited pong may serve as a unidirectional heartbeat. | ||||||
|  |         The payload may be set with the optional ``data`` argument which must | ||||||
|  |         be a string (which will be encoded to UTF-8) or a bytes-like object. | ||||||
|  |         """ | ||||||
|  |         async with self.conn_mutex: | ||||||
|  |             if self.connection.state in (CLOSED, CLOSING): | ||||||
|  |                 # Cannot send pong after transport is shutting down | ||||||
|  |                 return | ||||||
|  |             if isinstance(data, str): | ||||||
|  |                 data = data.encode("utf-8") | ||||||
|  |             elif isinstance(data, (bytearray, memoryview)): | ||||||
|  |                 data = bytes(data) | ||||||
|  |             self.connection.send_pong(data) | ||||||
|  |             await self.send_data(self.connection.data_to_send()) | ||||||
|  |  | ||||||
|  |     async def send_data(self, data_to_send): | ||||||
|  |         for data in data_to_send: | ||||||
|  |             if data: | ||||||
|  |                 await self.io_proto.send(data) | ||||||
|  |             else: | ||||||
|  |                 # Send an EOF - We don't actually send it, | ||||||
|  |                 # just trigger to autoclose the connection | ||||||
|  |                 if ( | ||||||
|  |                     self.auto_closer_task | ||||||
|  |                     and not self.auto_closer_task.done() | ||||||
|  |                     and self.data_finished_fut | ||||||
|  |                     and not self.data_finished_fut.done() | ||||||
|  |                 ): | ||||||
|  |                     # Auto-close the connection | ||||||
|  |                     self.data_finished_fut.set_result(None) | ||||||
|  |                 else: | ||||||
|  |                     # This will fail the connection appropriately | ||||||
|  |                     SanicProtocol.close(self.io_proto, timeout=1.0) | ||||||
|  |  | ||||||
|  |     async def async_data_received(self, data_to_send, events_to_process): | ||||||
|  |         if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: | ||||||
|  |             # receiving data can generate data to send (eg, pong for a ping) | ||||||
|  |             # send connection.data_to_send() | ||||||
|  |             await self.send_data(data_to_send) | ||||||
|  |         if len(events_to_process) > 0: | ||||||
|  |             await self.process_events(events_to_process) | ||||||
|  |  | ||||||
|  |     def data_received(self, data): | ||||||
|  |         self.connection.receive_data(data) | ||||||
|  |         data_to_send = self.connection.data_to_send() | ||||||
|  |         events_to_process = self.connection.events_received() | ||||||
|  |         if len(data_to_send) > 0 or len(events_to_process) > 0: | ||||||
|  |             asyncio.create_task( | ||||||
|  |                 self.async_data_received(data_to_send, events_to_process) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     async def async_eof_received(self, data_to_send, events_to_process): | ||||||
|  |         # receiving EOF can generate data to send | ||||||
|  |         # send connection.data_to_send() | ||||||
|  |         if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: | ||||||
|  |             await self.send_data(data_to_send) | ||||||
|  |         if len(events_to_process) > 0: | ||||||
|  |             await self.process_events(events_to_process) | ||||||
|  |         if self.recv_cancel: | ||||||
|  |             self.recv_cancel.cancel() | ||||||
|  |         if ( | ||||||
|  |             self.auto_closer_task | ||||||
|  |             and not self.auto_closer_task.done() | ||||||
|  |             and self.data_finished_fut | ||||||
|  |             and not self.data_finished_fut.done() | ||||||
|  |         ): | ||||||
|  |             # Auto-close the connection | ||||||
|  |             self.data_finished_fut.set_result(None) | ||||||
|  |             # Cancel the running handler if its waiting | ||||||
|  |         else: | ||||||
|  |             # This will fail the connection appropriately | ||||||
|  |             SanicProtocol.close(self.io_proto, timeout=1.0) | ||||||
|  |  | ||||||
|  |     def eof_received(self) -> Optional[bool]: | ||||||
|  |         self.connection.receive_eof() | ||||||
|  |         data_to_send = self.connection.data_to_send() | ||||||
|  |         events_to_process = self.connection.events_received() | ||||||
|  |         asyncio.create_task( | ||||||
|  |             self.async_eof_received(data_to_send, events_to_process) | ||||||
|  |         ) | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def connection_lost(self, exc): | ||||||
|  |         """ | ||||||
|  |         The WebSocket Connection is Closed. | ||||||
|  |         """ | ||||||
|  |         if not self.connection.state == CLOSED: | ||||||
|  |             # signal to the websocket connection handler | ||||||
|  |             # we've lost the connection | ||||||
|  |             self.connection.fail(code=1006) | ||||||
|  |             self.connection.state = CLOSED | ||||||
|  |  | ||||||
|  |         self.abort_pings() | ||||||
|  |         if self.connection_lost_waiter: | ||||||
|  |             self.connection_lost_waiter.set_result(None) | ||||||
| @@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound  # type: ignore | |||||||
| from sanic_routing.utils import path_to_parts  # type: ignore | from sanic_routing.utils import path_to_parts  # type: ignore | ||||||
|  |  | ||||||
| from sanic.exceptions import InvalidSignal | from sanic.exceptions import InvalidSignal | ||||||
|  | from sanic.log import error_logger, logger | ||||||
| from sanic.models.handler_types import SignalHandler | from sanic.models.handler_types import SignalHandler | ||||||
|  |  | ||||||
|  |  | ||||||
| RESERVED_NAMESPACES = ( | RESERVED_NAMESPACES = { | ||||||
|     "server", |     "server": ( | ||||||
|     "http", |         # "server.main.start", | ||||||
| ) |         # "server.main.stop", | ||||||
|  |         "server.init.before", | ||||||
|  |         "server.init.after", | ||||||
|  |         "server.shutdown.before", | ||||||
|  |         "server.shutdown.after", | ||||||
|  |     ), | ||||||
|  |     "http": ( | ||||||
|  |         "http.lifecycle.begin", | ||||||
|  |         "http.lifecycle.complete", | ||||||
|  |         "http.lifecycle.exception", | ||||||
|  |         "http.lifecycle.handle", | ||||||
|  |         "http.lifecycle.read_body", | ||||||
|  |         "http.lifecycle.read_head", | ||||||
|  |         "http.lifecycle.request", | ||||||
|  |         "http.lifecycle.response", | ||||||
|  |         "http.routing.after", | ||||||
|  |         "http.routing.before", | ||||||
|  |         "http.lifecycle.send", | ||||||
|  |         "http.middleware.after", | ||||||
|  |         "http.middleware.before", | ||||||
|  |     ), | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _blank(): | ||||||
|  |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
| class Signal(Route): | class Signal(Route): | ||||||
| @@ -59,8 +85,13 @@ class SignalRouter(BaseRouter): | |||||||
|                 terms.append(extra) |                 terms.append(extra) | ||||||
|             raise NotFound(message % tuple(terms)) |             raise NotFound(message % tuple(terms)) | ||||||
|  |  | ||||||
|  |         # Regex routes evaluate and can extract params directly. They are set | ||||||
|  |         # on param_basket["__params__"] | ||||||
|         params = param_basket["__params__"] |         params = param_basket["__params__"] | ||||||
|         if not params: |         if not params: | ||||||
|  |             # If param_basket["__params__"] does not exist, we might have | ||||||
|  |             # param_basket["__matches__"], which are indexed based matches | ||||||
|  |             # on path segments. They should already be cast types. | ||||||
|             params = { |             params = { | ||||||
|                 param.name: param_basket["__matches__"][idx] |                 param.name: param_basket["__matches__"][idx] | ||||||
|                 for idx, param in group.params.items() |                 for idx, param in group.params.items() | ||||||
| @@ -73,8 +104,18 @@ class SignalRouter(BaseRouter): | |||||||
|         event: str, |         event: str, | ||||||
|         context: Optional[Dict[str, Any]] = None, |         context: Optional[Dict[str, Any]] = None, | ||||||
|         condition: Optional[Dict[str, str]] = None, |         condition: Optional[Dict[str, str]] = None, | ||||||
|     ) -> None: |         fail_not_found: bool = True, | ||||||
|  |         reverse: bool = False, | ||||||
|  |     ) -> Any: | ||||||
|  |         try: | ||||||
|             group, handlers, params = self.get(event, condition=condition) |             group, handlers, params = self.get(event, condition=condition) | ||||||
|  |         except NotFound as e: | ||||||
|  |             if fail_not_found: | ||||||
|  |                 raise e | ||||||
|  |             else: | ||||||
|  |                 if self.ctx.app.debug: | ||||||
|  |                     error_logger.warning(str(e)) | ||||||
|  |                 return None | ||||||
|  |  | ||||||
|         events = [signal.ctx.event for signal in group] |         events = [signal.ctx.event for signal in group] | ||||||
|         for signal_event in events: |         for signal_event in events: | ||||||
| @@ -82,12 +123,19 @@ class SignalRouter(BaseRouter): | |||||||
|         if context: |         if context: | ||||||
|             params.update(context) |             params.update(context) | ||||||
|  |  | ||||||
|  |         if not reverse: | ||||||
|  |             handlers = handlers[::-1] | ||||||
|         try: |         try: | ||||||
|             for handler in handlers: |             for handler in handlers: | ||||||
|                 if condition is None or condition == handler.__requirements__: |                 if condition is None or condition == handler.__requirements__: | ||||||
|                     maybe_coroutine = handler(**params) |                     maybe_coroutine = handler(**params) | ||||||
|                     if isawaitable(maybe_coroutine): |                     if isawaitable(maybe_coroutine): | ||||||
|                         await maybe_coroutine |                         retval = await maybe_coroutine | ||||||
|  |                         if retval: | ||||||
|  |                             return retval | ||||||
|  |                     elif maybe_coroutine: | ||||||
|  |                         return maybe_coroutine | ||||||
|  |             return None | ||||||
|         finally: |         finally: | ||||||
|             for signal_event in events: |             for signal_event in events: | ||||||
|                 signal_event.clear() |                 signal_event.clear() | ||||||
| @@ -98,14 +146,23 @@ class SignalRouter(BaseRouter): | |||||||
|         *, |         *, | ||||||
|         context: Optional[Dict[str, Any]] = None, |         context: Optional[Dict[str, Any]] = None, | ||||||
|         condition: Optional[Dict[str, str]] = None, |         condition: Optional[Dict[str, str]] = None, | ||||||
|     ) -> asyncio.Task: |         fail_not_found: bool = True, | ||||||
|         task = self.ctx.loop.create_task( |         inline: bool = False, | ||||||
|             self._dispatch( |         reverse: bool = False, | ||||||
|  |     ) -> Union[asyncio.Task, Any]: | ||||||
|  |         dispatch = self._dispatch( | ||||||
|             event, |             event, | ||||||
|             context=context, |             context=context, | ||||||
|             condition=condition, |             condition=condition, | ||||||
|  |             fail_not_found=fail_not_found and inline, | ||||||
|  |             reverse=reverse, | ||||||
|         ) |         ) | ||||||
|         ) |         logger.debug(f"Dispatching signal: {event}") | ||||||
|  |  | ||||||
|  |         if inline: | ||||||
|  |             return await dispatch | ||||||
|  |  | ||||||
|  |         task = asyncio.get_running_loop().create_task(dispatch) | ||||||
|         await asyncio.sleep(0) |         await asyncio.sleep(0) | ||||||
|         return task |         return task | ||||||
|  |  | ||||||
| @@ -131,7 +188,9 @@ class SignalRouter(BaseRouter): | |||||||
|             append=True, |             append=True, | ||||||
|         )  # type: ignore |         )  # type: ignore | ||||||
|  |  | ||||||
|     def finalize(self, do_compile: bool = True): |     def finalize(self, do_compile: bool = True, do_optimize: bool = False): | ||||||
|  |         self.add(_blank, "sanic.__signal__.__init__") | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             self.ctx.loop = asyncio.get_running_loop() |             self.ctx.loop = asyncio.get_running_loop() | ||||||
|         except RuntimeError: |         except RuntimeError: | ||||||
| @@ -140,7 +199,7 @@ class SignalRouter(BaseRouter): | |||||||
|         for signal in self.routes: |         for signal in self.routes: | ||||||
|             signal.ctx.event = asyncio.Event() |             signal.ctx.event = asyncio.Event() | ||||||
|  |  | ||||||
|         return super().finalize(do_compile=do_compile) |         return super().finalize(do_compile=do_compile, do_optimize=do_optimize) | ||||||
|  |  | ||||||
|     def _build_event_parts(self, event: str) -> Tuple[str, str, str]: |     def _build_event_parts(self, event: str) -> Tuple[str, str, str]: | ||||||
|         parts = path_to_parts(event, self.delimiter) |         parts = path_to_parts(event, self.delimiter) | ||||||
| @@ -151,7 +210,11 @@ class SignalRouter(BaseRouter): | |||||||
|         ): |         ): | ||||||
|             raise InvalidSignal("Invalid signal event: %s" % event) |             raise InvalidSignal("Invalid signal event: %s" % event) | ||||||
|  |  | ||||||
|         if parts[0] in RESERVED_NAMESPACES: |         if ( | ||||||
|  |             parts[0] in RESERVED_NAMESPACES | ||||||
|  |             and event not in RESERVED_NAMESPACES[parts[0]] | ||||||
|  |             and not (parts[2].startswith("<") and parts[2].endswith(">")) | ||||||
|  |         ): | ||||||
|             raise InvalidSignal( |             raise InvalidSignal( | ||||||
|                 "Cannot declare reserved signal event: %s" % event |                 "Cannot declare reserved signal event: %s" % event | ||||||
|             ) |             ) | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								sanic/touchup/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								sanic/touchup/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | |||||||
|  | from .meta import TouchUpMeta | ||||||
|  | from .service import TouchUp | ||||||
|  |  | ||||||
|  |  | ||||||
|  | __all__ = ( | ||||||
|  |     "TouchUp", | ||||||
|  |     "TouchUpMeta", | ||||||
|  | ) | ||||||
							
								
								
									
										22
									
								
								sanic/touchup/meta.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								sanic/touchup/meta.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | from sanic.exceptions import SanicException | ||||||
|  |  | ||||||
|  | from .service import TouchUp | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TouchUpMeta(type): | ||||||
|  |     def __new__(cls, name, bases, attrs, **kwargs): | ||||||
|  |         gen_class = super().__new__(cls, name, bases, attrs, **kwargs) | ||||||
|  |  | ||||||
|  |         methods = attrs.get("__touchup__") | ||||||
|  |         attrs["__touched__"] = False | ||||||
|  |         if methods: | ||||||
|  |  | ||||||
|  |             for method in methods: | ||||||
|  |                 if method not in attrs: | ||||||
|  |                     raise SanicException( | ||||||
|  |                         "Cannot perform touchup on non-existent method: " | ||||||
|  |                         f"{name}.{method}" | ||||||
|  |                     ) | ||||||
|  |                 TouchUp.register(gen_class, method) | ||||||
|  |  | ||||||
|  |         return gen_class | ||||||
							
								
								
									
										5
									
								
								sanic/touchup/schemes/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								sanic/touchup/schemes/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | from .base import BaseScheme | ||||||
|  | from .ode import OptionalDispatchEvent  # noqa | ||||||
|  |  | ||||||
|  |  | ||||||
|  | __all__ = ("BaseScheme",) | ||||||
							
								
								
									
										20
									
								
								sanic/touchup/schemes/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								sanic/touchup/schemes/base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | from abc import ABC, abstractmethod | ||||||
|  | from typing import Set, Type | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BaseScheme(ABC): | ||||||
|  |     ident: str | ||||||
|  |     _registry: Set[Type] = set() | ||||||
|  |  | ||||||
|  |     def __init__(self, app) -> None: | ||||||
|  |         self.app = app | ||||||
|  |  | ||||||
|  |     @abstractmethod | ||||||
|  |     def run(self, method, module_globals) -> None: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     def __init_subclass__(cls): | ||||||
|  |         BaseScheme._registry.add(cls) | ||||||
|  |  | ||||||
|  |     def __call__(self, method, module_globals): | ||||||
|  |         return self.run(method, module_globals) | ||||||
							
								
								
									
										67
									
								
								sanic/touchup/schemes/ode.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								sanic/touchup/schemes/ode.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | |||||||
|  | from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse | ||||||
|  | from inspect import getsource | ||||||
|  | from textwrap import dedent | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
|  | from sanic.log import logger | ||||||
|  |  | ||||||
|  | from .base import BaseScheme | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OptionalDispatchEvent(BaseScheme): | ||||||
|  |     ident = "ODE" | ||||||
|  |  | ||||||
|  |     def __init__(self, app) -> None: | ||||||
|  |         super().__init__(app) | ||||||
|  |  | ||||||
|  |         self._registered_events = [ | ||||||
|  |             signal.path for signal in app.signal_router.routes | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |     def run(self, method, module_globals): | ||||||
|  |         raw_source = getsource(method) | ||||||
|  |         src = dedent(raw_source) | ||||||
|  |         tree = parse(src) | ||||||
|  |         node = RemoveDispatch(self._registered_events).visit(tree) | ||||||
|  |         compiled_src = compile(node, method.__name__, "exec") | ||||||
|  |         exec_locals: Dict[str, Any] = {} | ||||||
|  |         exec(compiled_src, module_globals, exec_locals)  # nosec | ||||||
|  |  | ||||||
|  |         return exec_locals[method.__name__] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RemoveDispatch(NodeTransformer): | ||||||
|  |     def __init__(self, registered_events) -> None: | ||||||
|  |         self._registered_events = registered_events | ||||||
|  |  | ||||||
|  |     def visit_Expr(self, node: Expr) -> Any: | ||||||
|  |         call = node.value | ||||||
|  |         if isinstance(call, Await): | ||||||
|  |             call = call.value | ||||||
|  |  | ||||||
|  |         func = getattr(call, "func", None) | ||||||
|  |         args = getattr(call, "args", None) | ||||||
|  |         if not func or not args: | ||||||
|  |             return node | ||||||
|  |  | ||||||
|  |         if isinstance(func, Attribute) and func.attr == "dispatch": | ||||||
|  |             event = args[0] | ||||||
|  |             if hasattr(event, "s"): | ||||||
|  |                 event_name = getattr(event, "value", event.s) | ||||||
|  |                 if self._not_registered(event_name): | ||||||
|  |                     logger.debug(f"Disabling event: {event_name}") | ||||||
|  |                     return None | ||||||
|  |         return node | ||||||
|  |  | ||||||
|  |     def _not_registered(self, event_name): | ||||||
|  |         dynamic = [] | ||||||
|  |         for event in self._registered_events: | ||||||
|  |             if event.endswith(">"): | ||||||
|  |                 namespace_concern, _ = event.rsplit(".", 1) | ||||||
|  |                 dynamic.append(namespace_concern) | ||||||
|  |  | ||||||
|  |         namespace_concern, _ = event_name.rsplit(".", 1) | ||||||
|  |         return ( | ||||||
|  |             event_name not in self._registered_events | ||||||
|  |             and namespace_concern not in dynamic | ||||||
|  |         ) | ||||||
							
								
								
									
										33
									
								
								sanic/touchup/service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								sanic/touchup/service.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | |||||||
|  | from inspect import getmembers, getmodule | ||||||
|  | from typing import Set, Tuple, Type | ||||||
|  |  | ||||||
|  | from .schemes import BaseScheme | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TouchUp: | ||||||
|  |     _registry: Set[Tuple[Type, str]] = set() | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def run(cls, app): | ||||||
|  |         for target, method_name in cls._registry: | ||||||
|  |             method = getattr(target, method_name) | ||||||
|  |  | ||||||
|  |             if app.test_mode: | ||||||
|  |                 placeholder = f"_{method_name}" | ||||||
|  |                 if hasattr(target, placeholder): | ||||||
|  |                     method = getattr(target, placeholder) | ||||||
|  |                 else: | ||||||
|  |                     setattr(target, placeholder, method) | ||||||
|  |  | ||||||
|  |             module = getmodule(target) | ||||||
|  |             module_globals = dict(getmembers(module)) | ||||||
|  |  | ||||||
|  |             for scheme in BaseScheme._registry: | ||||||
|  |                 modified = scheme(app)(method, module_globals) | ||||||
|  |                 setattr(target, method_name, modified) | ||||||
|  |  | ||||||
|  |             target.__touched__ = True | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def register(cls, target, method_name): | ||||||
|  |         cls._registry.add((target, method_name)) | ||||||
| @@ -13,6 +13,7 @@ from warnings import warn | |||||||
|  |  | ||||||
| from sanic.constants import HTTP_METHODS | from sanic.constants import HTTP_METHODS | ||||||
| from sanic.exceptions import InvalidUsage | from sanic.exceptions import InvalidUsage | ||||||
|  | from sanic.models.handler_types import RouteHandler | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
| @@ -86,7 +87,7 @@ class HTTPMethodView: | |||||||
|         return handler(request, *args, **kwargs) |         return handler(request, *args, **kwargs) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def as_view(cls, *class_args, **class_kwargs): |     def as_view(cls, *class_args: Any, **class_kwargs: Any) -> RouteHandler: | ||||||
|         """Return view function for use with the routing system, that |         """Return view function for use with the routing system, that | ||||||
|         dispatches request to appropriate handler method. |         dispatches request to appropriate handler method. | ||||||
|         """ |         """ | ||||||
| @@ -100,7 +101,7 @@ class HTTPMethodView: | |||||||
|             for decorator in cls.decorators: |             for decorator in cls.decorators: | ||||||
|                 view = decorator(view) |                 view = decorator(view) | ||||||
|  |  | ||||||
|         view.view_class = cls |         view.view_class = cls  # type: ignore | ||||||
|         view.__doc__ = cls.__doc__ |         view.__doc__ = cls.__doc__ | ||||||
|         view.__module__ = cls.__module__ |         view.__module__ = cls.__module__ | ||||||
|         view.__name__ = cls.__name__ |         view.__name__ = cls.__name__ | ||||||
|   | |||||||
| @@ -1,205 +0,0 @@ | |||||||
| from typing import ( |  | ||||||
|     Any, |  | ||||||
|     Awaitable, |  | ||||||
|     Callable, |  | ||||||
|     Dict, |  | ||||||
|     List, |  | ||||||
|     MutableMapping, |  | ||||||
|     Optional, |  | ||||||
|     Union, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| from httptools import HttpParserUpgrade  # type: ignore |  | ||||||
| from websockets import (  # type: ignore |  | ||||||
|     ConnectionClosed, |  | ||||||
|     InvalidHandshake, |  | ||||||
|     WebSocketCommonProtocol, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| # Despite the "legacy" namespace, the primary maintainer of websockets |  | ||||||
| # committed to maintaining backwards-compatibility until 2026 and will |  | ||||||
| # consider extending it if sanic continues depending on this module. |  | ||||||
| from websockets.legacy import handshake |  | ||||||
|  |  | ||||||
| from sanic.exceptions import InvalidUsage |  | ||||||
| from sanic.server import HttpProtocol |  | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"] |  | ||||||
|  |  | ||||||
| ASIMessage = MutableMapping[str, Any] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class WebSocketProtocol(HttpProtocol): |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         *args, |  | ||||||
|         websocket_timeout=10, |  | ||||||
|         websocket_max_size=None, |  | ||||||
|         websocket_max_queue=None, |  | ||||||
|         websocket_read_limit=2 ** 16, |  | ||||||
|         websocket_write_limit=2 ** 16, |  | ||||||
|         websocket_ping_interval=20, |  | ||||||
|         websocket_ping_timeout=20, |  | ||||||
|         **kwargs, |  | ||||||
|     ): |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|         self.websocket = None |  | ||||||
|         # self.app = None |  | ||||||
|         self.websocket_timeout = websocket_timeout |  | ||||||
|         self.websocket_max_size = websocket_max_size |  | ||||||
|         self.websocket_max_queue = websocket_max_queue |  | ||||||
|         self.websocket_read_limit = websocket_read_limit |  | ||||||
|         self.websocket_write_limit = websocket_write_limit |  | ||||||
|         self.websocket_ping_interval = websocket_ping_interval |  | ||||||
|         self.websocket_ping_timeout = websocket_ping_timeout |  | ||||||
|  |  | ||||||
|     # timeouts make no sense for websocket routes |  | ||||||
|     def request_timeout_callback(self): |  | ||||||
|         if self.websocket is None: |  | ||||||
|             super().request_timeout_callback() |  | ||||||
|  |  | ||||||
|     def response_timeout_callback(self): |  | ||||||
|         if self.websocket is None: |  | ||||||
|             super().response_timeout_callback() |  | ||||||
|  |  | ||||||
|     def keep_alive_timeout_callback(self): |  | ||||||
|         if self.websocket is None: |  | ||||||
|             super().keep_alive_timeout_callback() |  | ||||||
|  |  | ||||||
|     def connection_lost(self, exc): |  | ||||||
|         if self.websocket is not None: |  | ||||||
|             self.websocket.connection_lost(exc) |  | ||||||
|         super().connection_lost(exc) |  | ||||||
|  |  | ||||||
|     def data_received(self, data): |  | ||||||
|         if self.websocket is not None: |  | ||||||
|             # pass the data to the websocket protocol |  | ||||||
|             self.websocket.data_received(data) |  | ||||||
|         else: |  | ||||||
|             try: |  | ||||||
|                 super().data_received(data) |  | ||||||
|             except HttpParserUpgrade: |  | ||||||
|                 # this is okay, it just indicates we've got an upgrade request |  | ||||||
|                 pass |  | ||||||
|  |  | ||||||
|     def write_response(self, response): |  | ||||||
|         if self.websocket is not None: |  | ||||||
|             # websocket requests do not write a response |  | ||||||
|             self.transport.close() |  | ||||||
|         else: |  | ||||||
|             super().write_response(response) |  | ||||||
|  |  | ||||||
|     async def websocket_handshake(self, request, subprotocols=None): |  | ||||||
|         # let the websockets package do the handshake with the client |  | ||||||
|         headers = {} |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             key = handshake.check_request(request.headers) |  | ||||||
|             handshake.build_response(headers, key) |  | ||||||
|         except InvalidHandshake: |  | ||||||
|             raise InvalidUsage("Invalid websocket request") |  | ||||||
|  |  | ||||||
|         subprotocol = None |  | ||||||
|         if subprotocols and "Sec-Websocket-Protocol" in request.headers: |  | ||||||
|             # select a subprotocol |  | ||||||
|             client_subprotocols = [ |  | ||||||
|                 p.strip() |  | ||||||
|                 for p in request.headers["Sec-Websocket-Protocol"].split(",") |  | ||||||
|             ] |  | ||||||
|             for p in client_subprotocols: |  | ||||||
|                 if p in subprotocols: |  | ||||||
|                     subprotocol = p |  | ||||||
|                     headers["Sec-Websocket-Protocol"] = subprotocol |  | ||||||
|                     break |  | ||||||
|  |  | ||||||
|         # write the 101 response back to the client |  | ||||||
|         rv = b"HTTP/1.1 101 Switching Protocols\r\n" |  | ||||||
|         for k, v in headers.items(): |  | ||||||
|             rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" |  | ||||||
|         rv += b"\r\n" |  | ||||||
|         request.transport.write(rv) |  | ||||||
|  |  | ||||||
|         # hook up the websocket protocol |  | ||||||
|         self.websocket = WebSocketCommonProtocol( |  | ||||||
|             close_timeout=self.websocket_timeout, |  | ||||||
|             max_size=self.websocket_max_size, |  | ||||||
|             max_queue=self.websocket_max_queue, |  | ||||||
|             read_limit=self.websocket_read_limit, |  | ||||||
|             write_limit=self.websocket_write_limit, |  | ||||||
|             ping_interval=self.websocket_ping_interval, |  | ||||||
|             ping_timeout=self.websocket_ping_timeout, |  | ||||||
|         ) |  | ||||||
|         # we use WebSocketCommonProtocol because we don't want the handshake |  | ||||||
|         # logic from WebSocketServerProtocol; however, we must tell it that |  | ||||||
|         # we're running on the server side |  | ||||||
|         self.websocket.is_client = False |  | ||||||
|         self.websocket.side = "server" |  | ||||||
|         self.websocket.subprotocol = subprotocol |  | ||||||
|         self.websocket.connection_made(request.transport) |  | ||||||
|         self.websocket.connection_open() |  | ||||||
|         return self.websocket |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class WebSocketConnection: |  | ||||||
|  |  | ||||||
|     # TODO |  | ||||||
|     # - Implement ping/pong |  | ||||||
|  |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         send: Callable[[ASIMessage], Awaitable[None]], |  | ||||||
|         receive: Callable[[], Awaitable[ASIMessage]], |  | ||||||
|         subprotocols: Optional[List[str]] = None, |  | ||||||
|     ) -> None: |  | ||||||
|         self._send = send |  | ||||||
|         self._receive = receive |  | ||||||
|         self._subprotocols = subprotocols or [] |  | ||||||
|  |  | ||||||
|     async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: |  | ||||||
|         message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} |  | ||||||
|  |  | ||||||
|         if isinstance(data, bytes): |  | ||||||
|             message.update({"bytes": data}) |  | ||||||
|         else: |  | ||||||
|             message.update({"text": str(data)}) |  | ||||||
|  |  | ||||||
|         await self._send(message) |  | ||||||
|  |  | ||||||
|     async def recv(self, *args, **kwargs) -> Optional[str]: |  | ||||||
|         message = await self._receive() |  | ||||||
|  |  | ||||||
|         if message["type"] == "websocket.receive": |  | ||||||
|             return message["text"] |  | ||||||
|         elif message["type"] == "websocket.disconnect": |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|     receive = recv |  | ||||||
|  |  | ||||||
|     async def accept(self, subprotocols: Optional[List[str]] = None) -> None: |  | ||||||
|         subprotocol = None |  | ||||||
|         if subprotocols: |  | ||||||
|             for subp in subprotocols: |  | ||||||
|                 if subp in self.subprotocols: |  | ||||||
|                     subprotocol = subp |  | ||||||
|                     break |  | ||||||
|  |  | ||||||
|         await self._send( |  | ||||||
|             { |  | ||||||
|                 "type": "websocket.accept", |  | ||||||
|                 "subprotocol": subprotocol, |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     async def close(self) -> None: |  | ||||||
|         pass |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def subprotocols(self): |  | ||||||
|         return self._subprotocols |  | ||||||
|  |  | ||||||
|     @subprotocols.setter |  | ||||||
|     def subprotocols(self, subprotocols: Optional[List[str]] = None): |  | ||||||
|         self._subprotocols = subprotocols or [] |  | ||||||
| @@ -8,8 +8,8 @@ import traceback | |||||||
| from gunicorn.workers import base  # type: ignore | from gunicorn.workers import base  # type: ignore | ||||||
|  |  | ||||||
| from sanic.log import logger | from sanic.log import logger | ||||||
| from sanic.server import HttpProtocol, Signal, serve, trigger_events | from sanic.server import HttpProtocol, Signal, serve | ||||||
| from sanic.websocket import WebSocketProtocol | from sanic.server.protocols.websocket_protocol import WebSocketProtocol | ||||||
|  |  | ||||||
|  |  | ||||||
| try: | try: | ||||||
| @@ -68,10 +68,10 @@ class GunicornWorker(base.Worker): | |||||||
|         ) |         ) | ||||||
|         self._server_settings["signal"] = self.signal |         self._server_settings["signal"] = self.signal | ||||||
|         self._server_settings.pop("sock") |         self._server_settings.pop("sock") | ||||||
|         trigger_events( |         self._await(self.app.callable._startup()) | ||||||
|             self._server_settings.get("before_start", []), self.loop |         self._await( | ||||||
|  |             self.app.callable._server_event("init", "before", loop=self.loop) | ||||||
|         ) |         ) | ||||||
|         self._server_settings["before_start"] = () |  | ||||||
|  |  | ||||||
|         main_start = self._server_settings.pop("main_start", None) |         main_start = self._server_settings.pop("main_start", None) | ||||||
|         main_stop = self._server_settings.pop("main_stop", None) |         main_stop = self._server_settings.pop("main_stop", None) | ||||||
| @@ -82,24 +82,29 @@ class GunicornWorker(base.Worker): | |||||||
|                 "with GunicornWorker" |                 "with GunicornWorker" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         self._runner = asyncio.ensure_future(self._run(), loop=self.loop) |  | ||||||
|         try: |         try: | ||||||
|             self.loop.run_until_complete(self._runner) |             self._await(self._run()) | ||||||
|             self.app.callable.is_running = True |             self.app.callable.is_running = True | ||||||
|             trigger_events( |             self._await( | ||||||
|                 self._server_settings.get("after_start", []), self.loop |                 self.app.callable._server_event( | ||||||
|  |                     "init", "after", loop=self.loop | ||||||
|  |                 ) | ||||||
|             ) |             ) | ||||||
|             self.loop.run_until_complete(self._check_alive()) |             self.loop.run_until_complete(self._check_alive()) | ||||||
|             trigger_events( |             self._await( | ||||||
|                 self._server_settings.get("before_stop", []), self.loop |                 self.app.callable._server_event( | ||||||
|  |                     "shutdown", "before", loop=self.loop | ||||||
|  |                 ) | ||||||
|             ) |             ) | ||||||
|             self.loop.run_until_complete(self.close()) |             self.loop.run_until_complete(self.close()) | ||||||
|         except BaseException: |         except BaseException: | ||||||
|             traceback.print_exc() |             traceback.print_exc() | ||||||
|         finally: |         finally: | ||||||
|             try: |             try: | ||||||
|                 trigger_events( |                 self._await( | ||||||
|                     self._server_settings.get("after_stop", []), self.loop |                     self.app.callable._server_event( | ||||||
|  |                         "shutdown", "after", loop=self.loop | ||||||
|  |                     ) | ||||||
|                 ) |                 ) | ||||||
|             except BaseException: |             except BaseException: | ||||||
|                 traceback.print_exc() |                 traceback.print_exc() | ||||||
| @@ -137,14 +142,11 @@ class GunicornWorker(base.Worker): | |||||||
|  |  | ||||||
|             # Force close non-idle connection after waiting for |             # Force close non-idle connection after waiting for | ||||||
|             # graceful_shutdown_timeout |             # graceful_shutdown_timeout | ||||||
|             coros = [] |  | ||||||
|             for conn in self.connections: |             for conn in self.connections: | ||||||
|                 if hasattr(conn, "websocket") and conn.websocket: |                 if hasattr(conn, "websocket") and conn.websocket: | ||||||
|                     coros.append(conn.websocket.close_connection()) |                     conn.websocket.fail_connection(code=1001) | ||||||
|                 else: |                 else: | ||||||
|                     conn.close() |                     conn.abort() | ||||||
|             _shutdown = asyncio.gather(*coros, loop=self.loop) |  | ||||||
|             await _shutdown |  | ||||||
|  |  | ||||||
|     async def _run(self): |     async def _run(self): | ||||||
|         for sock in self.sockets: |         for sock in self.sockets: | ||||||
| @@ -238,3 +240,7 @@ class GunicornWorker(base.Worker): | |||||||
|         self.exit_code = 1 |         self.exit_code = 1 | ||||||
|         self.cfg.worker_abort(self) |         self.cfg.worker_abort(self) | ||||||
|         sys.exit(1) |         sys.exit(1) | ||||||
|  |  | ||||||
|  |     def _await(self, coro): | ||||||
|  |         fut = asyncio.ensure_future(coro, loop=self.loop) | ||||||
|  |         self.loop.run_until_complete(fut) | ||||||
|   | |||||||
							
								
								
									
										35
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								setup.py
									
									
									
									
									
								
							| @@ -81,60 +81,63 @@ env_dependency = ( | |||||||
| ) | ) | ||||||
| ujson = "ujson>=1.35" + env_dependency | ujson = "ujson>=1.35" + env_dependency | ||||||
| uvloop = "uvloop>=0.5.3" + env_dependency | uvloop = "uvloop>=0.5.3" + env_dependency | ||||||
|  | types_ujson = "types-ujson" + env_dependency | ||||||
| requirements = [ | requirements = [ | ||||||
|     "sanic-routing==0.7.0", |     "sanic-routing~=0.7", | ||||||
|     "httptools>=0.0.10", |     "httptools>=0.0.10", | ||||||
|     uvloop, |     uvloop, | ||||||
|     ujson, |     ujson, | ||||||
|     "aiofiles>=0.6.0", |     "aiofiles>=0.6.0", | ||||||
|     "websockets>=9.0", |     "websockets>=10.0", | ||||||
|     "multidict>=5.0,<6.0", |     "multidict>=5.0,<6.0", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| tests_require = [ | tests_require = [ | ||||||
|     "sanic-testing>=0.6.0", |     "sanic-testing>=0.7.0", | ||||||
|     "pytest==5.2.1", |     "pytest==5.2.1", | ||||||
|     "multidict>=5.0,<6.0", |     "coverage==5.3", | ||||||
|     "gunicorn==20.0.4", |     "gunicorn==20.0.4", | ||||||
|     "pytest-cov", |     "pytest-cov", | ||||||
|     "beautifulsoup4", |     "beautifulsoup4", | ||||||
|     uvloop, |  | ||||||
|     ujson, |  | ||||||
|     "pytest-sanic", |     "pytest-sanic", | ||||||
|     "pytest-sugar", |     "pytest-sugar", | ||||||
|     "pytest-benchmark", |     "pytest-benchmark", | ||||||
|  |     "chardet==3.*", | ||||||
|  |     "flake8", | ||||||
|  |     "black", | ||||||
|  |     "isort>=5.0.0", | ||||||
|  |     "bandit", | ||||||
|  |     "mypy>=0.901", | ||||||
|  |     "docutils", | ||||||
|  |     "pygments", | ||||||
|  |     "uvicorn<0.15.0", | ||||||
|  |     types_ujson, | ||||||
| ] | ] | ||||||
|  |  | ||||||
| docs_require = [ | docs_require = [ | ||||||
|     "sphinx>=2.1.2", |     "sphinx>=2.1.2", | ||||||
|     "sphinx_rtd_theme", |     "sphinx_rtd_theme>=0.4.3", | ||||||
|     "recommonmark>=0.5.0", |  | ||||||
|     "docutils", |     "docutils", | ||||||
|     "pygments", |     "pygments", | ||||||
|  |     "m2r2", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| dev_require = tests_require + [ | dev_require = tests_require + [ | ||||||
|     "aiofiles", |  | ||||||
|     "tox", |     "tox", | ||||||
|     "black", |  | ||||||
|     "flake8", |  | ||||||
|     "bandit", |  | ||||||
|     "towncrier", |     "towncrier", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| all_require = dev_require + docs_require | all_require = list(set(dev_require + docs_require)) | ||||||
|  |  | ||||||
| if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): | if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): | ||||||
|     print("Installing without uJSON") |     print("Installing without uJSON") | ||||||
|     requirements.remove(ujson) |     requirements.remove(ujson) | ||||||
|     tests_require.remove(ujson) |     tests_require.remove(types_ujson) | ||||||
|  |  | ||||||
| # 'nt' means windows OS | # 'nt' means windows OS | ||||||
| if strtobool(os.environ.get("SANIC_NO_UVLOOP", "no")): | if strtobool(os.environ.get("SANIC_NO_UVLOOP", "no")): | ||||||
|     print("Installing without uvLoop") |     print("Installing without uvLoop") | ||||||
|     requirements.remove(uvloop) |     requirements.remove(uvloop) | ||||||
|     tests_require.remove(uvloop) |  | ||||||
|  |  | ||||||
| extras_require = { | extras_require = { | ||||||
|     "test": tests_require, |     "test": tests_require, | ||||||
|   | |||||||
| @@ -1,3 +1,5 @@ | |||||||
|  | import asyncio | ||||||
|  | import logging | ||||||
| import random | import random | ||||||
| import re | import re | ||||||
| import string | import string | ||||||
| @@ -9,10 +11,12 @@ from typing import Tuple | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from sanic_routing.exceptions import RouteExists | from sanic_routing.exceptions import RouteExists | ||||||
|  | from sanic_testing.testing import PORT | ||||||
|  |  | ||||||
| from sanic import Sanic | from sanic import Sanic | ||||||
| from sanic.constants import HTTP_METHODS | from sanic.constants import HTTP_METHODS | ||||||
| from sanic.router import Router | from sanic.router import Router | ||||||
|  | from sanic.touchup.service import TouchUp | ||||||
|  |  | ||||||
|  |  | ||||||
| slugify = re.compile(r"[^a-zA-Z0-9_\-]") | slugify = re.compile(r"[^a-zA-Z0-9_\-]") | ||||||
| @@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]: | |||||||
|     collect_ignore = ["test_worker.py"] |     collect_ignore = ["test_worker.py"] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture |  | ||||||
| def caplog(caplog): |  | ||||||
|     yield caplog |  | ||||||
|  |  | ||||||
|  |  | ||||||
| async def _handler(request): | async def _handler(request): | ||||||
|     """ |     """ | ||||||
|     Dummy placeholder method used for route resolver when creating a new |     Dummy placeholder method used for route resolver when creating a new | ||||||
| @@ -41,33 +40,32 @@ async def _handler(request): | |||||||
|  |  | ||||||
|  |  | ||||||
| TYPE_TO_GENERATOR_MAP = { | TYPE_TO_GENERATOR_MAP = { | ||||||
|     "string": lambda: "".join( |     "str": lambda: "".join( | ||||||
|         [random.choice(string.ascii_lowercase) for _ in range(4)] |         [random.choice(string.ascii_lowercase) for _ in range(4)] | ||||||
|     ), |     ), | ||||||
|     "int": lambda: random.choice(range(1000000)), |     "int": lambda: random.choice(range(1000000)), | ||||||
|     "number": lambda: random.random(), |     "float": lambda: random.random(), | ||||||
|     "alpha": lambda: "".join( |     "alpha": lambda: "".join( | ||||||
|         [random.choice(string.ascii_lowercase) for _ in range(4)] |         [random.choice(string.ascii_lowercase) for _ in range(4)] | ||||||
|     ), |     ), | ||||||
|     "uuid": lambda: str(uuid.uuid1()), |     "uuid": lambda: str(uuid.uuid1()), | ||||||
| } | } | ||||||
|  |  | ||||||
|  | CACHE = {} | ||||||
|  |  | ||||||
|  |  | ||||||
| class RouteStringGenerator: | class RouteStringGenerator: | ||||||
|  |  | ||||||
|     ROUTE_COUNT_PER_DEPTH = 100 |     ROUTE_COUNT_PER_DEPTH = 100 | ||||||
|     HTTP_METHODS = HTTP_METHODS |     HTTP_METHODS = HTTP_METHODS | ||||||
|     ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"] |     ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"] | ||||||
|  |  | ||||||
|     def generate_random_direct_route(self, max_route_depth=4): |     def generate_random_direct_route(self, max_route_depth=4): | ||||||
|         routes = [] |         routes = [] | ||||||
|         for depth in range(1, max_route_depth + 1): |         for depth in range(1, max_route_depth + 1): | ||||||
|             for _ in range(self.ROUTE_COUNT_PER_DEPTH): |             for _ in range(self.ROUTE_COUNT_PER_DEPTH): | ||||||
|                 route = "/".join( |                 route = "/".join( | ||||||
|                     [ |                     [TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)] | ||||||
|                         TYPE_TO_GENERATOR_MAP.get("string")() |  | ||||||
|                         for _ in range(depth) |  | ||||||
|                     ] |  | ||||||
|                 ) |                 ) | ||||||
|                 route = route.replace(".", "", -1) |                 route = route.replace(".", "", -1) | ||||||
|                 route_detail = (random.choice(self.HTTP_METHODS), route) |                 route_detail = (random.choice(self.HTTP_METHODS), route) | ||||||
| @@ -83,7 +81,7 @@ class RouteStringGenerator: | |||||||
|             new_route_part = "/".join( |             new_route_part = "/".join( | ||||||
|                 [ |                 [ | ||||||
|                     "<{}:{}>".format( |                     "<{}:{}>".format( | ||||||
|                         TYPE_TO_GENERATOR_MAP.get("string")(), |                         TYPE_TO_GENERATOR_MAP.get("str")(), | ||||||
|                         random.choice(self.ROUTE_PARAM_TYPES), |                         random.choice(self.ROUTE_PARAM_TYPES), | ||||||
|                     ) |                     ) | ||||||
|                     for _ in range(max_route_depth - current_length) |                     for _ in range(max_route_depth - current_length) | ||||||
| @@ -98,7 +96,7 @@ class RouteStringGenerator: | |||||||
|     def generate_url_for_template(template): |     def generate_url_for_template(template): | ||||||
|         url = template |         url = template | ||||||
|         for pattern, param_type in re.findall( |         for pattern, param_type in re.findall( | ||||||
|             re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"), |             re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"), | ||||||
|             template, |             template, | ||||||
|         ): |         ): | ||||||
|             value = TYPE_TO_GENERATOR_MAP.get(param_type)() |             value = TYPE_TO_GENERATOR_MAP.get(param_type)() | ||||||
| @@ -111,6 +109,7 @@ def sanic_router(app): | |||||||
|     # noinspection PyProtectedMember |     # noinspection PyProtectedMember | ||||||
|     def _setup(route_details: tuple) -> Tuple[Router, tuple]: |     def _setup(route_details: tuple) -> Tuple[Router, tuple]: | ||||||
|         router = Router() |         router = Router() | ||||||
|  |         router.ctx.app = app | ||||||
|         added_router = [] |         added_router = [] | ||||||
|         for method, route in route_details: |         for method, route in route_details: | ||||||
|             try: |             try: | ||||||
| @@ -141,5 +140,33 @@ def url_param_generator(): | |||||||
|  |  | ||||||
| @pytest.fixture(scope="function") | @pytest.fixture(scope="function") | ||||||
| def app(request): | def app(request): | ||||||
|  |     if not CACHE: | ||||||
|  |         for target, method_name in TouchUp._registry: | ||||||
|  |             CACHE[method_name] = getattr(target, method_name) | ||||||
|     app = Sanic(slugify.sub("-", request.node.name)) |     app = Sanic(slugify.sub("-", request.node.name)) | ||||||
|     return app |     yield app | ||||||
|  |     for target, method_name in TouchUp._registry: | ||||||
|  |         setattr(target, method_name, CACHE[method_name]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture(scope="function") | ||||||
|  | def run_startup(caplog): | ||||||
|  |     def run(app): | ||||||
|  |         nonlocal caplog | ||||||
|  |         loop = asyncio.new_event_loop() | ||||||
|  |         asyncio.set_event_loop(loop) | ||||||
|  |         with caplog.at_level(logging.DEBUG): | ||||||
|  |             server = app.create_server( | ||||||
|  |                 debug=True, return_asyncio_server=True, port=PORT | ||||||
|  |             ) | ||||||
|  |             loop._stopping = False | ||||||
|  |  | ||||||
|  |             _server = loop.run_until_complete(server) | ||||||
|  |  | ||||||
|  |             _server.close() | ||||||
|  |             loop.run_until_complete(_server.wait_closed()) | ||||||
|  |             app.stop() | ||||||
|  |  | ||||||
|  |         return caplog.record_tuples | ||||||
|  |  | ||||||
|  |     return run | ||||||
|   | |||||||
| @@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable): | |||||||
| @patch("sanic.app.WebSocketProtocol") | @patch("sanic.app.WebSocketProtocol") | ||||||
| def test_app_websocket_parameters(websocket_protocol_mock, app): | def test_app_websocket_parameters(websocket_protocol_mock, app): | ||||||
|     app.config.WEBSOCKET_MAX_SIZE = 44 |     app.config.WEBSOCKET_MAX_SIZE = 44 | ||||||
|     app.config.WEBSOCKET_MAX_QUEUE = 45 |  | ||||||
|     app.config.WEBSOCKET_READ_LIMIT = 46 |  | ||||||
|     app.config.WEBSOCKET_WRITE_LIMIT = 47 |  | ||||||
|     app.config.WEBSOCKET_PING_TIMEOUT = 48 |     app.config.WEBSOCKET_PING_TIMEOUT = 48 | ||||||
|     app.config.WEBSOCKET_PING_INTERVAL = 50 |     app.config.WEBSOCKET_PING_INTERVAL = 50 | ||||||
|  |  | ||||||
| @@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app): | |||||||
|     websocket_protocol_call_args = websocket_protocol_mock.call_args |     websocket_protocol_call_args = websocket_protocol_mock.call_args | ||||||
|     ws_kwargs = websocket_protocol_call_args[1] |     ws_kwargs = websocket_protocol_call_args[1] | ||||||
|     assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE |     assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE | ||||||
|     assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE |  | ||||||
|     assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT |  | ||||||
|     assert ( |  | ||||||
|         ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT |  | ||||||
|     ) |  | ||||||
|     assert ( |     assert ( | ||||||
|         ws_kwargs["websocket_ping_timeout"] |         ws_kwargs["websocket_ping_timeout"] | ||||||
|         == app.config.WEBSOCKET_PING_TIMEOUT |         == app.config.WEBSOCKET_PING_TIMEOUT | ||||||
| @@ -396,7 +388,7 @@ def test_app_set_attribute_warning(app): | |||||||
|     assert len(record) == 1 |     assert len(record) == 1 | ||||||
|     assert record[0].message.args[0] == ( |     assert record[0].message.args[0] == ( | ||||||
|         "Setting variables on Sanic instances is deprecated " |         "Setting variables on Sanic instances is deprecated " | ||||||
|         "and will be removed in version 21.9. You should change your " |         "and will be removed in version 21.12. You should change your " | ||||||
|         "Sanic instance to use instance.ctx.foo instead." |         "Sanic instance to use instance.ctx.foo instead." | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,10 +7,10 @@ import uvicorn | |||||||
|  |  | ||||||
| from sanic import Sanic | from sanic import Sanic | ||||||
| from sanic.asgi import MockTransport | from sanic.asgi import MockTransport | ||||||
| from sanic.exceptions import InvalidUsage | from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable | ||||||
| from sanic.request import Request | from sanic.request import Request | ||||||
| from sanic.response import json, text | from sanic.response import json, text | ||||||
| from sanic.websocket import WebSocketConnection | from sanic.server.websockets.connection import WebSocketConnection | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @@ -346,3 +346,33 @@ async def test_content_type(app): | |||||||
|  |  | ||||||
|     _, response = await app.asgi_client.get("/custom") |     _, response = await app.asgi_client.get("/custom") | ||||||
|     assert response.headers.get("content-type") == "somethingelse" |     assert response.headers.get("content-type") == "somethingelse" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_request_handle_exception(app): | ||||||
|  |     @app.get("/error-prone") | ||||||
|  |     def _request(request): | ||||||
|  |         raise ServiceUnavailable(message="Service unavailable") | ||||||
|  |  | ||||||
|  |     _, response = await app.asgi_client.get("/wrong-path") | ||||||
|  |     assert response.status_code == 404 | ||||||
|  |  | ||||||
|  |     _, response = await app.asgi_client.get("/error-prone") | ||||||
|  |     assert response.status_code == 503 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_request_exception_suppressed_by_middleware(app): | ||||||
|  |     @app.get("/error-prone") | ||||||
|  |     def _request(request): | ||||||
|  |         raise ServiceUnavailable(message="Service unavailable") | ||||||
|  |  | ||||||
|  |     @app.on_request | ||||||
|  |     def forbidden(request): | ||||||
|  |         raise Forbidden(message="forbidden") | ||||||
|  |  | ||||||
|  |     _, response = await app.asgi_client.get("/wrong-path") | ||||||
|  |     assert response.status_code == 403 | ||||||
|  |  | ||||||
|  |     _, response = await app.asgi_client.get("/error-prone") | ||||||
|  |     assert response.status_code == 403 | ||||||
|   | |||||||
| @@ -20,4 +20,4 @@ def test_bad_request_response(app): | |||||||
|  |  | ||||||
|     app.run(host="127.0.0.1", port=42101, debug=False) |     app.run(host="127.0.0.1", port=42101, debug=False) | ||||||
|     assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" |     assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" | ||||||
|     assert b"Bad Request" in lines[-1] |     assert b"Bad Request" in lines[-2] | ||||||
|   | |||||||
							
								
								
									
										70
									
								
								tests/test_blueprint_copy.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								tests/test_blueprint_copy.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,70 @@ | |||||||
|  | from copy import deepcopy | ||||||
|  |  | ||||||
|  | from sanic import Blueprint, Sanic, blueprints, response | ||||||
|  | from sanic.response import text | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_bp_copy(app: Sanic): | ||||||
|  |     bp1 = Blueprint("test_bp1", version=1) | ||||||
|  |     bp1.ctx.test = 1 | ||||||
|  |     assert hasattr(bp1.ctx, "test") | ||||||
|  |  | ||||||
|  |     @bp1.route("/page") | ||||||
|  |     def handle_request(request): | ||||||
|  |         return text("Hello world!") | ||||||
|  |  | ||||||
|  |     bp2 = bp1.copy(name="test_bp2", version=2) | ||||||
|  |     assert id(bp1) != id(bp2) | ||||||
|  |     assert bp1._apps == bp2._apps == set() | ||||||
|  |     assert not hasattr(bp2.ctx, "test") | ||||||
|  |     assert len(bp2._future_exceptions) == len(bp1._future_exceptions) | ||||||
|  |     assert len(bp2._future_listeners) == len(bp1._future_listeners) | ||||||
|  |     assert len(bp2._future_middleware) == len(bp1._future_middleware) | ||||||
|  |     assert len(bp2._future_routes) == len(bp1._future_routes) | ||||||
|  |     assert len(bp2._future_signals) == len(bp1._future_signals) | ||||||
|  |  | ||||||
|  |     app.blueprint(bp1) | ||||||
|  |     app.blueprint(bp2) | ||||||
|  |  | ||||||
|  |     bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True) | ||||||
|  |     assert id(bp1) != id(bp3) | ||||||
|  |     assert bp1._apps == bp3._apps and bp3._apps | ||||||
|  |     assert not hasattr(bp3.ctx, "test") | ||||||
|  |  | ||||||
|  |     bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True) | ||||||
|  |     assert id(bp1) != id(bp4) | ||||||
|  |     assert bp4.ctx.test == 1 | ||||||
|  |  | ||||||
|  |     bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False) | ||||||
|  |     assert id(bp1) != id(bp5) | ||||||
|  |     assert not bp5._apps | ||||||
|  |     assert bp1._apps != set() | ||||||
|  |  | ||||||
|  |     app.blueprint(bp5) | ||||||
|  |  | ||||||
|  |     bp6 = bp1.copy( | ||||||
|  |         name="test_bp6", | ||||||
|  |         version=6, | ||||||
|  |         with_registration=True, | ||||||
|  |         version_prefix="/version", | ||||||
|  |     ) | ||||||
|  |     assert bp6._apps | ||||||
|  |     assert bp6.version_prefix == "/version" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/v1/page") | ||||||
|  |     assert "Hello world!" in response.text | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/v2/page") | ||||||
|  |     assert "Hello world!" in response.text | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/v3/page") | ||||||
|  |     assert "Hello world!" in response.text | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/v4/page") | ||||||
|  |     assert "Hello world!" in response.text | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/v5/page") | ||||||
|  |     assert "Hello world!" in response.text | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/version6/page") | ||||||
|  |     assert "Hello world!" in response.text | ||||||
| @@ -3,6 +3,12 @@ from pytest import raises | |||||||
| from sanic.app import Sanic | from sanic.app import Sanic | ||||||
| from sanic.blueprint_group import BlueprintGroup | from sanic.blueprint_group import BlueprintGroup | ||||||
| from sanic.blueprints import Blueprint | from sanic.blueprints import Blueprint | ||||||
|  | from sanic.exceptions import ( | ||||||
|  |     Forbidden, | ||||||
|  |     InvalidUsage, | ||||||
|  |     SanicException, | ||||||
|  |     ServerError, | ||||||
|  | ) | ||||||
| from sanic.request import Request | from sanic.request import Request | ||||||
| from sanic.response import HTTPResponse, text | from sanic.response import HTTPResponse, text | ||||||
|  |  | ||||||
| @@ -96,16 +102,28 @@ def test_bp_group(app: Sanic): | |||||||
|     def blueprint_1_default_route(request): |     def blueprint_1_default_route(request): | ||||||
|         return text("BP1_OK") |         return text("BP1_OK") | ||||||
|  |  | ||||||
|  |     @blueprint_1.route("/invalid") | ||||||
|  |     def blueprint_1_error(request: Request): | ||||||
|  |         raise InvalidUsage("Invalid") | ||||||
|  |  | ||||||
|     @blueprint_2.route("/") |     @blueprint_2.route("/") | ||||||
|     def blueprint_2_default_route(request): |     def blueprint_2_default_route(request): | ||||||
|         return text("BP2_OK") |         return text("BP2_OK") | ||||||
|  |  | ||||||
|  |     @blueprint_2.route("/error") | ||||||
|  |     def blueprint_2_error(request: Request): | ||||||
|  |         raise ServerError("Error") | ||||||
|  |  | ||||||
|     blueprint_group_1 = Blueprint.group( |     blueprint_group_1 = Blueprint.group( | ||||||
|         blueprint_1, blueprint_2, url_prefix="/bp" |         blueprint_1, blueprint_2, url_prefix="/bp" | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3") |     blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3") | ||||||
|  |  | ||||||
|  |     @blueprint_group_1.exception(InvalidUsage) | ||||||
|  |     def handle_group_exception(request, exception): | ||||||
|  |         return text("BP1_ERR_OK") | ||||||
|  |  | ||||||
|     @blueprint_group_1.middleware("request") |     @blueprint_group_1.middleware("request") | ||||||
|     def blueprint_group_1_middleware(request): |     def blueprint_group_1_middleware(request): | ||||||
|         global MIDDLEWARE_INVOKE_COUNTER |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
| @@ -116,19 +134,47 @@ def test_bp_group(app: Sanic): | |||||||
|         global MIDDLEWARE_INVOKE_COUNTER |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
|         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 |         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 | ||||||
|  |  | ||||||
|  |     @blueprint_group_1.on_request | ||||||
|  |     def blueprint_group_1_convenience_1(request): | ||||||
|  |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
|  |         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 | ||||||
|  |  | ||||||
|  |     @blueprint_group_1.on_request() | ||||||
|  |     def blueprint_group_1_convenience_2(request): | ||||||
|  |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
|  |         MIDDLEWARE_INVOKE_COUNTER["request"] += 1 | ||||||
|  |  | ||||||
|     @blueprint_3.route("/") |     @blueprint_3.route("/") | ||||||
|     def blueprint_3_default_route(request): |     def blueprint_3_default_route(request): | ||||||
|         return text("BP3_OK") |         return text("BP3_OK") | ||||||
|  |  | ||||||
|  |     @blueprint_3.route("/forbidden") | ||||||
|  |     def blueprint_3_forbidden(request: Request): | ||||||
|  |         raise Forbidden("Forbidden") | ||||||
|  |  | ||||||
|     blueprint_group_2 = Blueprint.group( |     blueprint_group_2 = Blueprint.group( | ||||||
|         blueprint_group_1, blueprint_3, url_prefix="/api" |         blueprint_group_1, blueprint_3, url_prefix="/api" | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     @blueprint_group_2.exception(SanicException) | ||||||
|  |     def handle_non_handled_exception(request, exception): | ||||||
|  |         return text("BP2_ERR_OK") | ||||||
|  |  | ||||||
|     @blueprint_group_2.middleware("response") |     @blueprint_group_2.middleware("response") | ||||||
|     def blueprint_group_2_middleware(request, response): |     def blueprint_group_2_middleware(request, response): | ||||||
|         global MIDDLEWARE_INVOKE_COUNTER |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
|         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 |         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 | ||||||
|  |  | ||||||
|  |     @blueprint_group_2.on_response | ||||||
|  |     def blueprint_group_2_middleware_convenience_1(request, response): | ||||||
|  |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
|  |         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 | ||||||
|  |  | ||||||
|  |     @blueprint_group_2.on_response() | ||||||
|  |     def blueprint_group_2_middleware_convenience_2(request, response): | ||||||
|  |         global MIDDLEWARE_INVOKE_COUNTER | ||||||
|  |         MIDDLEWARE_INVOKE_COUNTER["response"] += 1 | ||||||
|  |  | ||||||
|     app.blueprint(blueprint_group_2) |     app.blueprint(blueprint_group_2) | ||||||
|  |  | ||||||
|     @app.route("/") |     @app.route("/") | ||||||
| @@ -141,14 +187,23 @@ def test_bp_group(app: Sanic): | |||||||
|     _, response = app.test_client.get("/api/bp/bp1") |     _, response = app.test_client.get("/api/bp/bp1") | ||||||
|     assert response.text == "BP1_OK" |     assert response.text == "BP1_OK" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/api/bp/bp1/invalid") | ||||||
|  |     assert response.text == "BP1_ERR_OK" | ||||||
|  |  | ||||||
|     _, response = app.test_client.get("/api/bp/bp2") |     _, response = app.test_client.get("/api/bp/bp2") | ||||||
|     assert response.text == "BP2_OK" |     assert response.text == "BP2_OK" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/api/bp/bp2/error") | ||||||
|  |     assert response.text == "BP2_ERR_OK" | ||||||
|  |  | ||||||
|     _, response = app.test_client.get("/api/bp3") |     _, response = app.test_client.get("/api/bp3") | ||||||
|     assert response.text == "BP3_OK" |     assert response.text == "BP3_OK" | ||||||
|  |  | ||||||
|     assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3 |     _, response = app.test_client.get("/api/bp3/forbidden") | ||||||
|     assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4 |     assert response.text == "BP2_ERR_OK" | ||||||
|  |  | ||||||
|  |     assert MIDDLEWARE_INVOKE_COUNTER["response"] == 18 | ||||||
|  |     assert MIDDLEWARE_INVOKE_COUNTER["request"] == 16 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_bp_group_list_operations(app: Sanic): | def test_bp_group_list_operations(app: Sanic): | ||||||
|   | |||||||
| @@ -83,7 +83,6 @@ def test_versioned_routes_get(app, method): | |||||||
|             return text("OK") |             return text("OK") | ||||||
|  |  | ||||||
|     else: |     else: | ||||||
|         print(func) |  | ||||||
|         raise Exception(f"{func} is not callable") |         raise Exception(f"{func} is not callable") | ||||||
|  |  | ||||||
|     app.blueprint(bp) |     app.blueprint(bp) | ||||||
| @@ -477,6 +476,58 @@ def test_bp_exception_handler(app): | |||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_bp_exception_handler_applied(app): | ||||||
|  |     class Error(Exception): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     handled = Blueprint("handled") | ||||||
|  |     nothandled = Blueprint("nothandled") | ||||||
|  |  | ||||||
|  |     @handled.exception(Error) | ||||||
|  |     def handle_error(req, e): | ||||||
|  |         return text("handled {}".format(e)) | ||||||
|  |  | ||||||
|  |     @handled.route("/ok") | ||||||
|  |     def ok(request): | ||||||
|  |         raise Error("uh oh") | ||||||
|  |  | ||||||
|  |     @nothandled.route("/notok") | ||||||
|  |     def notok(request): | ||||||
|  |         raise Error("uh oh") | ||||||
|  |  | ||||||
|  |     app.blueprint(handled) | ||||||
|  |     app.blueprint(nothandled) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/ok") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.text == "handled uh oh" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/notok") | ||||||
|  |     assert response.status == 500 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_bp_exception_handler_not_applied(app): | ||||||
|  |     class Error(Exception): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     handled = Blueprint("handled") | ||||||
|  |     nothandled = Blueprint("nothandled") | ||||||
|  |  | ||||||
|  |     @handled.exception(Error) | ||||||
|  |     def handle_error(req, e): | ||||||
|  |         return text("handled {}".format(e)) | ||||||
|  |  | ||||||
|  |     @nothandled.route("/notok") | ||||||
|  |     def notok(request): | ||||||
|  |         raise Error("uh oh") | ||||||
|  |  | ||||||
|  |     app.blueprint(handled) | ||||||
|  |     app.blueprint(nothandled) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/notok") | ||||||
|  |     assert response.status == 500 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_bp_listeners(app): | def test_bp_listeners(app): | ||||||
|     app.route("/")(lambda x: x) |     app.route("/")(lambda x: x) | ||||||
|     blueprint = Blueprint("test_middleware") |     blueprint = Blueprint("test_middleware") | ||||||
| @@ -1034,6 +1085,6 @@ def test_bp_set_attribute_warning(): | |||||||
|     assert len(record) == 1 |     assert len(record) == 1 | ||||||
|     assert record[0].message.args[0] == ( |     assert record[0].message.args[0] == ( | ||||||
|         "Setting variables on Blueprint instances is deprecated " |         "Setting variables on Blueprint instances is deprecated " | ||||||
|         "and will be removed in version 21.9. You should change your " |         "and will be removed in version 21.12. You should change your " | ||||||
|         "Blueprint instance to use instance.ctx.foo instead." |         "Blueprint instance to use instance.ctx.foo instead." | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -89,7 +89,7 @@ def test_debug(cmd): | |||||||
|     out, err, exitcode = capture(command) |     out, err, exitcode = capture(command) | ||||||
|     lines = out.split(b"\n") |     lines = out.split(b"\n") | ||||||
|  |  | ||||||
|     app_info = lines[9] |     app_info = lines[26] | ||||||
|     info = json.loads(app_info) |     info = json.loads(app_info) | ||||||
|  |  | ||||||
|     assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO |     assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO | ||||||
| @@ -103,7 +103,7 @@ def test_auto_reload(cmd): | |||||||
|     out, err, exitcode = capture(command) |     out, err, exitcode = capture(command) | ||||||
|     lines = out.split(b"\n") |     lines = out.split(b"\n") | ||||||
|  |  | ||||||
|     app_info = lines[9] |     app_info = lines[26] | ||||||
|     info = json.loads(app_info) |     info = json.loads(app_info) | ||||||
|  |  | ||||||
|     assert info["debug"] is False |     assert info["debug"] is False | ||||||
| @@ -118,7 +118,7 @@ def test_access_logs(cmd, expected): | |||||||
|     out, err, exitcode = capture(command) |     out, err, exitcode = capture(command) | ||||||
|     lines = out.split(b"\n") |     lines = out.split(b"\n") | ||||||
|  |  | ||||||
|     app_info = lines[9] |     app_info = lines[26] | ||||||
|     info = json.loads(app_info) |     info = json.loads(app_info) | ||||||
|  |  | ||||||
|     assert info["access_log"] is expected |     assert info["access_log"] is expected | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ from sanic.exceptions import PyFileError | |||||||
|  |  | ||||||
| @contextmanager | @contextmanager | ||||||
| def temp_path(): | def temp_path(): | ||||||
|     """ a simple cross platform replacement for NamedTemporaryFile """ |     """a simple cross platform replacement for NamedTemporaryFile""" | ||||||
|     with TemporaryDirectory() as td: |     with TemporaryDirectory() as td: | ||||||
|         yield Path(td, "file") |         yield Path(td, "file") | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,4 @@ | |||||||
| from crypt import methods | from sanic import Sanic, text | ||||||
|  |  | ||||||
| from sanic import text |  | ||||||
| from sanic.constants import HTTP_METHODS, HTTPMethod | from sanic.constants import HTTP_METHODS, HTTPMethod | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -14,7 +12,7 @@ def test_string_compat(): | |||||||
|     assert HTTPMethod.GET.upper() == "GET" |     assert HTTPMethod.GET.upper() == "GET" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_use_in_routes(app): | def test_use_in_routes(app: Sanic): | ||||||
|     @app.route("/", methods=[HTTPMethod.GET, HTTPMethod.POST]) |     @app.route("/", methods=[HTTPMethod.GET, HTTPMethod.POST]) | ||||||
|     def handler(_): |     def handler(_): | ||||||
|         return text("It works") |         return text("It works") | ||||||
|   | |||||||
| @@ -1,6 +1,5 @@ | |||||||
| import asyncio | import asyncio | ||||||
|  |  | ||||||
| from queue import Queue |  | ||||||
| from threading import Event | from threading import Event | ||||||
|  |  | ||||||
| from sanic.response import text | from sanic.response import text | ||||||
| @@ -13,8 +12,6 @@ def test_create_task(app): | |||||||
|         await asyncio.sleep(0.05) |         await asyncio.sleep(0.05) | ||||||
|         e.set() |         e.set() | ||||||
|  |  | ||||||
|     app.add_task(coro) |  | ||||||
|  |  | ||||||
|     @app.route("/early") |     @app.route("/early") | ||||||
|     def not_set(request): |     def not_set(request): | ||||||
|         return text(str(e.is_set())) |         return text(str(e.is_set())) | ||||||
| @@ -24,24 +21,30 @@ def test_create_task(app): | |||||||
|         await asyncio.sleep(0.1) |         await asyncio.sleep(0.1) | ||||||
|         return text(str(e.is_set())) |         return text(str(e.is_set())) | ||||||
|  |  | ||||||
|  |     app.add_task(coro) | ||||||
|  |  | ||||||
|     request, response = app.test_client.get("/early") |     request, response = app.test_client.get("/early") | ||||||
|     assert response.body == b"False" |     assert response.body == b"False" | ||||||
|  |  | ||||||
|  |     app.signal_router.reset() | ||||||
|  |     app.add_task(coro) | ||||||
|     request, response = app.test_client.get("/late") |     request, response = app.test_client.get("/late") | ||||||
|     assert response.body == b"True" |     assert response.body == b"True" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_create_task_with_app_arg(app): | def test_create_task_with_app_arg(app): | ||||||
|     q = Queue() |     @app.after_server_start | ||||||
|  |     async def setup_q(app, _): | ||||||
|  |         app.ctx.q = asyncio.Queue() | ||||||
|  |  | ||||||
|     @app.route("/") |     @app.route("/") | ||||||
|     def not_set(request): |     async def not_set(request): | ||||||
|         return "hello" |         return text(await request.app.ctx.q.get()) | ||||||
|  |  | ||||||
|     async def coro(app): |     async def coro(app): | ||||||
|         q.put(app.name) |         await app.ctx.q.put(app.name) | ||||||
|  |  | ||||||
|     app.add_task(coro) |     app.add_task(coro) | ||||||
|  |  | ||||||
|     request, response = app.test_client.get("/") |     _, response = app.test_client.get("/") | ||||||
|     assert q.get() == "test_create_task_with_app_arg" |     assert response.text == "test_create_task_with_app_arg" | ||||||
|   | |||||||
| @@ -1,10 +1,10 @@ | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from sanic import Sanic | from sanic import Sanic | ||||||
| from sanic.errorpages import exception_response | from sanic.errorpages import HTMLRenderer, exception_response | ||||||
| from sanic.exceptions import NotFound | from sanic.exceptions import NotFound, SanicException | ||||||
| from sanic.request import Request | from sanic.request import Request | ||||||
| from sanic.response import HTTPResponse | from sanic.response import HTTPResponse, html, json, text | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @@ -20,7 +20,7 @@ def app(): | |||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def fake_request(app): | def fake_request(app): | ||||||
|     return Request(b"/foobar", {}, "1.1", "GET", None, app) |     return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
| @@ -47,7 +47,13 @@ def test_should_return_html_valid_setting( | |||||||
|     try: |     try: | ||||||
|         raise exception("bad stuff") |         raise exception("bad stuff") | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         response = exception_response(fake_request, e, True) |         response = exception_response( | ||||||
|  |             fake_request, | ||||||
|  |             e, | ||||||
|  |             True, | ||||||
|  |             base=HTMLRenderer, | ||||||
|  |             fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     assert isinstance(response, HTTPResponse) |     assert isinstance(response, HTTPResponse) | ||||||
|     assert response.status == status |     assert response.status == status | ||||||
| @@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app): | |||||||
|     app.config.FALLBACK_ERROR_FORMAT = "auto" |     app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||||
|  |  | ||||||
|     _, response = app.test_client.get( |     _, response = app.test_client.get( | ||||||
|         "/error", headers={"content-type": "application/json"} |         "/error", headers={"content-type": "application/json", "accept": "*/*"} | ||||||
|     ) |     ) | ||||||
|     assert response.status == 500 |     assert response.status == 500 | ||||||
|     assert response.content_type == "application/json" |     assert response.content_type == "application/json" | ||||||
|  |  | ||||||
|     _, response = app.test_client.get( |     _, response = app.test_client.get( | ||||||
|         "/error", headers={"content-type": "text/plain"} |         "/error", headers={"content-type": "foo/bar", "accept": "*/*"} | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/html; charset=utf-8" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_route_error_format_set_on_auto(app): | ||||||
|  |     @app.get("/text") | ||||||
|  |     def text_response(request): | ||||||
|  |         return text(request.route.ctx.error_format) | ||||||
|  |  | ||||||
|  |     @app.get("/json") | ||||||
|  |     def json_response(request): | ||||||
|  |         return json({"format": request.route.ctx.error_format}) | ||||||
|  |  | ||||||
|  |     @app.get("/html") | ||||||
|  |     def html_response(request): | ||||||
|  |         return html(request.route.ctx.error_format) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/text") | ||||||
|  |     assert response.text == "text" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/json") | ||||||
|  |     assert response.json["format"] == "json" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/html") | ||||||
|  |     assert response.text == "html" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_route_error_response_from_auto_route(app): | ||||||
|  |     @app.get("/text") | ||||||
|  |     def text_response(request): | ||||||
|  |         raise Exception("oops") | ||||||
|  |         return text("Never gonna see this") | ||||||
|  |  | ||||||
|  |     @app.get("/json") | ||||||
|  |     def json_response(request): | ||||||
|  |         raise Exception("oops") | ||||||
|  |         return json({"message": "Never gonna see this"}) | ||||||
|  |  | ||||||
|  |     @app.get("/html") | ||||||
|  |     def html_response(request): | ||||||
|  |         raise Exception("oops") | ||||||
|  |         return html("<h1>Never gonna see this</h1>") | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/text") | ||||||
|  |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/json") | ||||||
|  |     assert response.content_type == "application/json" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/html") | ||||||
|  |     assert response.content_type == "text/html; charset=utf-8" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_route_error_response_from_explicit_format(app): | ||||||
|  |     @app.get("/text", error_format="json") | ||||||
|  |     def text_response(request): | ||||||
|  |         raise Exception("oops") | ||||||
|  |         return text("Never gonna see this") | ||||||
|  |  | ||||||
|  |     @app.get("/json", error_format="text") | ||||||
|  |     def json_response(request): | ||||||
|  |         raise Exception("oops") | ||||||
|  |         return json({"message": "Never gonna see this"}) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/text") | ||||||
|  |     assert response.content_type == "application/json" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/json") | ||||||
|  |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_unknown_fallback_format(app): | ||||||
|  |     with pytest.raises(SanicException, match="Unknown format: bad"): | ||||||
|  |         app.config.FALLBACK_ERROR_FORMAT = "bad" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_route_error_format_unknown(app): | ||||||
|  |     with pytest.raises(SanicException, match="Unknown format: bad"): | ||||||
|  |  | ||||||
|  |         @app.get("/text", error_format="bad") | ||||||
|  |         def handler(request): | ||||||
|  |             ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_fallback_with_content_type_mismatch_accept(app): | ||||||
|  |     app.config.FALLBACK_ERROR_FORMAT = "auto" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/error", | ||||||
|  |         headers={"content-type": "application/json", "accept": "text/plain"}, | ||||||
|     ) |     ) | ||||||
|     assert response.status == 500 |     assert response.status == 500 | ||||||
|     assert response.content_type == "text/plain; charset=utf-8" |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/error", | ||||||
|  |         headers={"content-type": "text/plain", "accept": "foo/bar"}, | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/html; charset=utf-8" | ||||||
|  |  | ||||||
|  |     app.router.reset() | ||||||
|  |  | ||||||
|  |     @app.route("/alt1") | ||||||
|  |     @app.route("/alt2", error_format="text") | ||||||
|  |     @app.route("/alt3", error_format="html") | ||||||
|  |     def handler(_): | ||||||
|  |         raise Exception("problem here") | ||||||
|  |         # Yes, we know this return value is unreachable. This is on purpose. | ||||||
|  |         return json({}) | ||||||
|  |  | ||||||
|  |     app.router.finalize() | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/alt1", | ||||||
|  |         headers={"accept": "foo/bar"}, | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/html; charset=utf-8" | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/alt1", | ||||||
|  |         headers={"accept": "foo/bar,*/*"}, | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "application/json" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/alt2", | ||||||
|  |         headers={"accept": "foo/bar"}, | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/html; charset=utf-8" | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/alt2", | ||||||
|  |         headers={"accept": "foo/bar,*/*"}, | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/plain; charset=utf-8" | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get( | ||||||
|  |         "/alt3", | ||||||
|  |         headers={"accept": "foo/bar"}, | ||||||
|  |     ) | ||||||
|  |     assert response.status == 500 | ||||||
|  |     assert response.content_type == "text/html; charset=utf-8" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "accept,content_type,expected", | ||||||
|  |     ( | ||||||
|  |         (None, None, "text/plain; charset=utf-8"), | ||||||
|  |         ("foo/bar", None, "text/html; charset=utf-8"), | ||||||
|  |         ("application/json", None, "application/json"), | ||||||
|  |         ("application/json,text/plain", None, "application/json"), | ||||||
|  |         ("text/plain,application/json", None, "application/json"), | ||||||
|  |         ("text/plain,foo/bar", None, "text/plain; charset=utf-8"), | ||||||
|  |         # Following test is valid after v22.3 | ||||||
|  |         # ("text/plain,text/html", None, "text/plain; charset=utf-8"), | ||||||
|  |         ("*/*", "foo/bar", "text/html; charset=utf-8"), | ||||||
|  |         ("*/*", "application/json", "application/json"), | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_combinations_for_auto(fake_request, accept, content_type, expected): | ||||||
|  |     if accept: | ||||||
|  |         fake_request.headers["accept"] = accept | ||||||
|  |     else: | ||||||
|  |         del fake_request.headers["accept"] | ||||||
|  |  | ||||||
|  |     if content_type: | ||||||
|  |         fake_request.headers["content-type"] = content_type | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         raise Exception("bad stuff") | ||||||
|  |     except Exception as e: | ||||||
|  |         response = exception_response( | ||||||
|  |             fake_request, | ||||||
|  |             e, | ||||||
|  |             True, | ||||||
|  |             base=HTMLRenderer, | ||||||
|  |             fallback="auto", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     assert response.content_type == expected | ||||||
|   | |||||||
| @@ -1,8 +1,10 @@ | |||||||
|  | import logging | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from bs4 import BeautifulSoup | from bs4 import BeautifulSoup | ||||||
|  | from websockets.version import version as websockets_version | ||||||
|  |  | ||||||
| from sanic import Sanic | from sanic import Sanic | ||||||
| from sanic.exceptions import ( | from sanic.exceptions import ( | ||||||
| @@ -232,3 +234,41 @@ def test_sanic_exception(exception_app): | |||||||
|         request, response = exception_app.test_client.get("/old_abort") |         request, response = exception_app.test_client.get("/old_abort") | ||||||
|     assert response.status == 500 |     assert response.status == 500 | ||||||
|     assert len(w) == 1 and "deprecated" in w[0].message.args[0] |     assert len(w) == 1 and "deprecated" in w[0].message.args[0] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_custom_exception_default_message(exception_app): | ||||||
|  |     class TeaError(SanicException): | ||||||
|  |         message = "Tempest in a teapot" | ||||||
|  |         status_code = 418 | ||||||
|  |  | ||||||
|  |     exception_app.router.reset() | ||||||
|  |  | ||||||
|  |     @exception_app.get("/tempest") | ||||||
|  |     def tempest(_): | ||||||
|  |         raise TeaError | ||||||
|  |  | ||||||
|  |     _, response = exception_app.test_client.get("/tempest", debug=True) | ||||||
|  |     assert response.status == 418 | ||||||
|  |     assert b"Tempest in a teapot" in response.body | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_exception_in_ws_logged(caplog): | ||||||
|  |     app = Sanic(__file__) | ||||||
|  |  | ||||||
|  |     @app.websocket("/feed") | ||||||
|  |     async def feed(request, ws): | ||||||
|  |         raise Exception("...") | ||||||
|  |  | ||||||
|  |     with caplog.at_level(logging.INFO): | ||||||
|  |         app.test_client.websocket("/feed") | ||||||
|  |     # Websockets v10.0 and above output an additional | ||||||
|  |     # INFO message when a ws connection is accepted | ||||||
|  |     ws_version_parts = websockets_version.split(".") | ||||||
|  |     ws_major = int(ws_version_parts[0]) | ||||||
|  |     record_index = 2 if ws_major >= 10 else 1 | ||||||
|  |     assert caplog.record_tuples[record_index][0] == "sanic.error" | ||||||
|  |     assert caplog.record_tuples[record_index][1] == logging.ERROR | ||||||
|  |     assert ( | ||||||
|  |         "Exception occurred while handling uri:" | ||||||
|  |         in caplog.record_tuples[record_index][2] | ||||||
|  |     ) | ||||||
|   | |||||||
| @@ -1,4 +1,7 @@ | |||||||
| import asyncio | import asyncio | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
| from bs4 import BeautifulSoup | from bs4 import BeautifulSoup | ||||||
|  |  | ||||||
| @@ -8,9 +11,6 @@ from sanic.handlers import ErrorHandler | |||||||
| from sanic.response import stream, text | from sanic.response import stream, text | ||||||
|  |  | ||||||
|  |  | ||||||
| exception_handler_app = Sanic("test_exception_handler") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| async def sample_streaming_fn(response): | async def sample_streaming_fn(response): | ||||||
|     await response.write("foo,") |     await response.write("foo,") | ||||||
|     await asyncio.sleep(0.001) |     await asyncio.sleep(0.001) | ||||||
| @@ -21,113 +21,107 @@ class ErrorWithRequestCtx(ServerError): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
| @exception_handler_app.route("/1") | @pytest.fixture | ||||||
| def handler_1(request): | def exception_handler_app(): | ||||||
|  |     exception_handler_app = Sanic("test_exception_handler") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/1", error_format="html") | ||||||
|  |     def handler_1(request): | ||||||
|         raise InvalidUsage("OK") |         raise InvalidUsage("OK") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/2", error_format="html") | ||||||
| @exception_handler_app.route("/2") |     def handler_2(request): | ||||||
| def handler_2(request): |  | ||||||
|         raise ServerError("OK") |         raise ServerError("OK") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/3", error_format="html") | ||||||
| @exception_handler_app.route("/3") |     def handler_3(request): | ||||||
| def handler_3(request): |  | ||||||
|         raise NotFound("OK") |         raise NotFound("OK") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/4", error_format="html") | ||||||
| @exception_handler_app.route("/4") |     def handler_4(request): | ||||||
| def handler_4(request): |         foo = bar  # noqa -- F821 | ||||||
|     foo = bar  # noqa -- F821 undefined name 'bar' is done to throw exception |  | ||||||
|         return text(foo) |         return text(foo) | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/5", error_format="html") | ||||||
| @exception_handler_app.route("/5") |     def handler_5(request): | ||||||
| def handler_5(request): |  | ||||||
|         class CustomServerError(ServerError): |         class CustomServerError(ServerError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         raise CustomServerError("Custom server error") |         raise CustomServerError("Custom server error") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/6/<arg:int>", error_format="html") | ||||||
| @exception_handler_app.route("/6/<arg:int>") |     def handler_6(request, arg): | ||||||
| def handler_6(request, arg): |  | ||||||
|         try: |         try: | ||||||
|             foo = 1 / arg |             foo = 1 / arg | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise e from ValueError(f"{arg}") |             raise e from ValueError(f"{arg}") | ||||||
|         return text(foo) |         return text(foo) | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/7", error_format="html") | ||||||
| @exception_handler_app.route("/7") |     def handler_7(request): | ||||||
| def handler_7(request): |  | ||||||
|         raise Forbidden("go away!") |         raise Forbidden("go away!") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.route("/8", error_format="html") | ||||||
| @exception_handler_app.route("/8") |     def handler_8(request): | ||||||
| def handler_8(request): |  | ||||||
|  |  | ||||||
|         raise ErrorWithRequestCtx("OK") |         raise ErrorWithRequestCtx("OK") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) | ||||||
| @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) |     def handler_exception_with_ctx(request, exception): | ||||||
| def handler_exception_with_ctx(request, exception): |  | ||||||
|         return text(request.ctx.middleware_ran) |         return text(request.ctx.middleware_ran) | ||||||
|  |  | ||||||
|  |     @exception_handler_app.exception(ServerError) | ||||||
| @exception_handler_app.exception(ServerError) |     def handler_exception(request, exception): | ||||||
| def handler_exception(request, exception): |  | ||||||
|         return text("OK") |         return text("OK") | ||||||
|  |  | ||||||
|  |     @exception_handler_app.exception(Forbidden) | ||||||
| @exception_handler_app.exception(Forbidden) |     async def async_handler_exception(request, exception): | ||||||
| async def async_handler_exception(request, exception): |  | ||||||
|         return stream( |         return stream( | ||||||
|             sample_streaming_fn, |             sample_streaming_fn, | ||||||
|             content_type="text/csv", |             content_type="text/csv", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @exception_handler_app.middleware | ||||||
| @exception_handler_app.middleware |     async def some_request_middleware(request): | ||||||
| async def some_request_middleware(request): |  | ||||||
|         request.ctx.middleware_ran = "Done." |         request.ctx.middleware_ran = "Done." | ||||||
|  |  | ||||||
|  |     return exception_handler_app | ||||||
|  |  | ||||||
| def test_invalid_usage_exception_handler(): |  | ||||||
|  | def test_invalid_usage_exception_handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/1") |     request, response = exception_handler_app.test_client.get("/1") | ||||||
|     assert response.status == 400 |     assert response.status == 400 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_server_error_exception_handler(): | def test_server_error_exception_handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/2") |     request, response = exception_handler_app.test_client.get("/2") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|     assert response.text == "OK" |     assert response.text == "OK" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_not_found_exception_handler(): | def test_not_found_exception_handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/3") |     request, response = exception_handler_app.test_client.get("/3") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_text_exception__handler(): | def test_text_exception__handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/random") |     request, response = exception_handler_app.test_client.get("/random") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|     assert response.text == "Done." |     assert response.text == "Done." | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_async_exception_handler(): | def test_async_exception_handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/7") |     request, response = exception_handler_app.test_client.get("/7") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|     assert response.text == "foo,bar" |     assert response.text == "foo,bar" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_html_traceback_output_in_debug_mode(): | def test_html_traceback_output_in_debug_mode(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/4", debug=True) |     request, response = exception_handler_app.test_client.get("/4", debug=True) | ||||||
|     assert response.status == 500 |     assert response.status == 500 | ||||||
|     soup = BeautifulSoup(response.body, "html.parser") |     soup = BeautifulSoup(response.body, "html.parser") | ||||||
|     html = str(soup) |     html = str(soup) | ||||||
|  |  | ||||||
|     assert "response = handler(request, **kwargs)" in html |  | ||||||
|     assert "handler_4" in html |     assert "handler_4" in html | ||||||
|     assert "foo = bar" in html |     assert "foo = bar" in html | ||||||
|  |  | ||||||
| @@ -137,12 +131,12 @@ def test_html_traceback_output_in_debug_mode(): | |||||||
|     ) == summary_text |     ) == summary_text | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_inherited_exception_handler(): | def test_inherited_exception_handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/5") |     request, response = exception_handler_app.test_client.get("/5") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_chained_exception_handler(): | def test_chained_exception_handler(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get( |     request, response = exception_handler_app.test_client.get( | ||||||
|         "/6/0", debug=True |         "/6/0", debug=True | ||||||
|     ) |     ) | ||||||
| @@ -151,11 +145,9 @@ def test_chained_exception_handler(): | |||||||
|     soup = BeautifulSoup(response.body, "html.parser") |     soup = BeautifulSoup(response.body, "html.parser") | ||||||
|     html = str(soup) |     html = str(soup) | ||||||
|  |  | ||||||
|     assert "response = handler(request, **kwargs)" in html |  | ||||||
|     assert "handler_6" in html |     assert "handler_6" in html | ||||||
|     assert "foo = 1 / arg" in html |     assert "foo = 1 / arg" in html | ||||||
|     assert "ValueError" in html |     assert "ValueError" in html | ||||||
|     assert "The above exception was the direct cause" in html |  | ||||||
|  |  | ||||||
|     summary_text = " ".join(soup.select(".summary")[0].text.split()) |     summary_text = " ".join(soup.select(".summary")[0].text.split()) | ||||||
|     assert ( |     assert ( | ||||||
| @@ -163,7 +155,7 @@ def test_chained_exception_handler(): | |||||||
|     ) == summary_text |     ) == summary_text | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_exception_handler_lookup(): | def test_exception_handler_lookup(exception_handler_app): | ||||||
|     class CustomError(Exception): |     class CustomError(Exception): | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
| @@ -186,26 +178,52 @@ def test_exception_handler_lookup(): | |||||||
|         class ModuleNotFoundError(ImportError): |         class ModuleNotFoundError(ImportError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|     handler = ErrorHandler() |     handler = ErrorHandler("auto") | ||||||
|     handler.add(ImportError, import_error_handler) |     handler.add(ImportError, import_error_handler) | ||||||
|     handler.add(CustomError, custom_error_handler) |     handler.add(CustomError, custom_error_handler) | ||||||
|     handler.add(ServerError, server_error_handler) |     handler.add(ServerError, server_error_handler) | ||||||
|  |  | ||||||
|     assert handler.lookup(ImportError()) == import_error_handler |     assert handler.lookup(ImportError(), None) == import_error_handler | ||||||
|     assert handler.lookup(ModuleNotFoundError()) == import_error_handler |     assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler | ||||||
|     assert handler.lookup(CustomError()) == custom_error_handler |     assert handler.lookup(CustomError(), None) == custom_error_handler | ||||||
|     assert handler.lookup(ServerError("Error")) == server_error_handler |     assert handler.lookup(ServerError("Error"), None) == server_error_handler | ||||||
|     assert handler.lookup(CustomServerError("Error")) == server_error_handler |     assert ( | ||||||
|  |         handler.lookup(CustomServerError("Error"), None) | ||||||
|  |         == server_error_handler | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     # once again to ensure there is no caching bug |     # once again to ensure there is no caching bug | ||||||
|     assert handler.lookup(ImportError()) == import_error_handler |     assert handler.lookup(ImportError(), None) == import_error_handler | ||||||
|     assert handler.lookup(ModuleNotFoundError()) == import_error_handler |     assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler | ||||||
|     assert handler.lookup(CustomError()) == custom_error_handler |     assert handler.lookup(CustomError(), None) == custom_error_handler | ||||||
|     assert handler.lookup(ServerError("Error")) == server_error_handler |     assert handler.lookup(ServerError("Error"), None) == server_error_handler | ||||||
|     assert handler.lookup(CustomServerError("Error")) == server_error_handler |     assert ( | ||||||
|  |         handler.lookup(CustomServerError("Error"), None) | ||||||
|  |         == server_error_handler | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_exception_handler_processed_request_middleware(): | def test_exception_handler_processed_request_middleware(exception_handler_app): | ||||||
|     request, response = exception_handler_app.test_client.get("/8") |     request, response = exception_handler_app.test_client.get("/8") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|     assert response.text == "Done." |     assert response.text == "Done." | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_single_arg_exception_handler_notice(exception_handler_app, caplog): | ||||||
|  |     class CustomErrorHandler(ErrorHandler): | ||||||
|  |         def lookup(self, exception): | ||||||
|  |             return super().lookup(exception, None) | ||||||
|  |  | ||||||
|  |     exception_handler_app.error_handler = CustomErrorHandler() | ||||||
|  |  | ||||||
|  |     with caplog.at_level(logging.WARNING): | ||||||
|  |         _, response = exception_handler_app.test_client.get("/1") | ||||||
|  |  | ||||||
|  |     assert caplog.records[0].message == ( | ||||||
|  |         "You are using a deprecated error handler. The lookup method should " | ||||||
|  |         "accept two positional parameters: (exception, route_name: " | ||||||
|  |         "Optional[str]). Until you upgrade your ErrorHandler.lookup, " | ||||||
|  |         "Blueprint specific exceptions will not work properly. Beginning in " | ||||||
|  |         "v22.3, the legacy style lookup method will not work at all." | ||||||
|  |     ) | ||||||
|  |     assert response.status == 400 | ||||||
|   | |||||||
							
								
								
									
										46
									
								
								tests/test_graceful_shutdown.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/test_graceful_shutdown.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | import asyncio | ||||||
|  | import logging | ||||||
|  | import time | ||||||
|  |  | ||||||
|  | from collections import Counter | ||||||
|  | from multiprocessing import Process | ||||||
|  |  | ||||||
|  | import httpx | ||||||
|  |  | ||||||
|  |  | ||||||
|  | PORT = 42101 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_no_exceptions_when_cancel_pending_request(app, caplog): | ||||||
|  |     app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1 | ||||||
|  |  | ||||||
|  |     @app.get("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         await asyncio.sleep(5) | ||||||
|  |  | ||||||
|  |     @app.after_server_start | ||||||
|  |     def shutdown(app, _): | ||||||
|  |         time.sleep(0.2) | ||||||
|  |         app.stop() | ||||||
|  |  | ||||||
|  |     def ping(): | ||||||
|  |         time.sleep(0.1) | ||||||
|  |         response = httpx.get("http://127.0.0.1:8000") | ||||||
|  |         print(response.status_code) | ||||||
|  |  | ||||||
|  |     p = Process(target=ping) | ||||||
|  |     p.start() | ||||||
|  |  | ||||||
|  |     with caplog.at_level(logging.INFO): | ||||||
|  |         app.run() | ||||||
|  |  | ||||||
|  |     p.kill() | ||||||
|  |  | ||||||
|  |     counter = Counter([r[1] for r in caplog.record_tuples]) | ||||||
|  |  | ||||||
|  |     assert counter[logging.INFO] == 5 | ||||||
|  |     assert logging.ERROR not in counter | ||||||
|  |     assert ( | ||||||
|  |         caplog.record_tuples[3][2] | ||||||
|  |         == "Request: GET http://127.0.0.1:8000/ stopped. Transport is closed." | ||||||
|  |     ) | ||||||
							
								
								
									
										39
									
								
								tests/test_handler_annotations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								tests/test_handler_annotations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | |||||||
|  | from uuid import UUID | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | from sanic import json | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "idx,path,expectation", | ||||||
|  |     ( | ||||||
|  |         (0, "/abc", "str"), | ||||||
|  |         (1, "/123", "int"), | ||||||
|  |         (2, "/123.5", "float"), | ||||||
|  |         (3, "/8af729fe-2b94-4a95-a168-c07068568429", "UUID"), | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_annotated_handlers(app, idx, path, expectation): | ||||||
|  |     def build_response(num, foo): | ||||||
|  |         return json({"num": num, "type": type(foo).__name__}) | ||||||
|  |  | ||||||
|  |     @app.get("/<foo>") | ||||||
|  |     def handler0(_, foo: str): | ||||||
|  |         return build_response(0, foo) | ||||||
|  |  | ||||||
|  |     @app.get("/<foo>") | ||||||
|  |     def handler1(_, foo: int): | ||||||
|  |         return build_response(1, foo) | ||||||
|  |  | ||||||
|  |     @app.get("/<foo>") | ||||||
|  |     def handler2(_, foo: float): | ||||||
|  |         return build_response(2, foo) | ||||||
|  |  | ||||||
|  |     @app.get("/<foo>") | ||||||
|  |     def handler3(_, foo: UUID): | ||||||
|  |         return build_response(3, foo) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get(path) | ||||||
|  |     assert response.json["num"] == idx | ||||||
|  |     assert response.json["type"] == expectation | ||||||
| @@ -3,8 +3,9 @@ from unittest.mock import Mock | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from sanic import headers, text | from sanic import headers, text | ||||||
| from sanic.exceptions import PayloadTooLarge | from sanic.exceptions import InvalidHeader, PayloadTooLarge | ||||||
| from sanic.http import Http | from sanic.http import Http | ||||||
|  | from sanic.request import Request | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @@ -182,3 +183,187 @@ def test_request_line(app): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     assert request.request_line == b"GET / HTTP/1.1" |     assert request.request_line == b"GET / HTTP/1.1" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "raw", | ||||||
|  |     ( | ||||||
|  |         "show/first, show/second", | ||||||
|  |         "show/*, show/first", | ||||||
|  |         "*/*, show/first", | ||||||
|  |         "*/*, show/*", | ||||||
|  |         "other/*; q=0.1, show/*; q=0.2", | ||||||
|  |         "show/first; q=0.5, show/second; q=0.5", | ||||||
|  |         "show/first; foo=bar, show/second; foo=bar", | ||||||
|  |         "show/second, show/first; foo=bar", | ||||||
|  |         "show/second; q=0.5, show/first; foo=bar; q=0.5", | ||||||
|  |         "show/second; q=0.5, show/first; q=1.0", | ||||||
|  |         "show/first, show/second; q=1.0", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_parse_accept_ordered_okay(raw): | ||||||
|  |     ordered = headers.parse_accept(raw) | ||||||
|  |     expected_subtype = ( | ||||||
|  |         "*" if all(q.subtype.is_wildcard for q in ordered) else "first" | ||||||
|  |     ) | ||||||
|  |     assert ordered[0].type_ == "show" | ||||||
|  |     assert ordered[0].subtype == expected_subtype | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "raw", | ||||||
|  |     ( | ||||||
|  |         "missing", | ||||||
|  |         "missing/", | ||||||
|  |         "/missing", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_bad_accept(raw): | ||||||
|  |     with pytest.raises(InvalidHeader): | ||||||
|  |         headers.parse_accept(raw) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_empty_accept(): | ||||||
|  |     assert headers.parse_accept("") == [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_wildcard_accept_set_ok(): | ||||||
|  |     accept = headers.parse_accept("*/*")[0] | ||||||
|  |     assert accept.type_.is_wildcard | ||||||
|  |     assert accept.subtype.is_wildcard | ||||||
|  |  | ||||||
|  |     accept = headers.parse_accept("foo/bar")[0] | ||||||
|  |     assert not accept.type_.is_wildcard | ||||||
|  |     assert not accept.subtype.is_wildcard | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_accept_parsed_against_str(): | ||||||
|  |     accept = headers.Accept.parse("foo/bar") | ||||||
|  |     assert accept > "foo/bar; q=0.1" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_media_type_equality(): | ||||||
|  |     assert headers.MediaType("foo") == headers.MediaType("foo") == "foo" | ||||||
|  |     assert headers.MediaType("foo") == headers.MediaType("*") == "*" | ||||||
|  |     assert headers.MediaType("foo") != headers.MediaType("bar") | ||||||
|  |     assert headers.MediaType("foo") != "bar" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_media_type_matching(): | ||||||
|  |     assert headers.MediaType("foo").match(headers.MediaType("foo")) | ||||||
|  |     assert headers.MediaType("foo").match("foo") | ||||||
|  |  | ||||||
|  |     assert not headers.MediaType("foo").match(headers.MediaType("*")) | ||||||
|  |     assert not headers.MediaType("foo").match("*") | ||||||
|  |  | ||||||
|  |     assert not headers.MediaType("foo").match(headers.MediaType("bar")) | ||||||
|  |     assert not headers.MediaType("foo").match("bar") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "value,other,outcome,allow_type,allow_subtype", | ||||||
|  |     ( | ||||||
|  |         # ALLOW BOTH | ||||||
|  |         ("foo/bar", "foo/bar", True, True, True), | ||||||
|  |         ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), | ||||||
|  |         ("foo/bar", "foo/*", True, True, True), | ||||||
|  |         ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), | ||||||
|  |         ("foo/bar", "*/*", True, True, True), | ||||||
|  |         ("foo/bar", headers.Accept.parse("*/*"), True, True, True), | ||||||
|  |         ("foo/*", "foo/bar", True, True, True), | ||||||
|  |         ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), | ||||||
|  |         ("foo/*", "foo/*", True, True, True), | ||||||
|  |         ("foo/*", headers.Accept.parse("foo/*"), True, True, True), | ||||||
|  |         ("foo/*", "*/*", True, True, True), | ||||||
|  |         ("foo/*", headers.Accept.parse("*/*"), True, True, True), | ||||||
|  |         ("*/*", "foo/bar", True, True, True), | ||||||
|  |         ("*/*", headers.Accept.parse("foo/bar"), True, True, True), | ||||||
|  |         ("*/*", "foo/*", True, True, True), | ||||||
|  |         ("*/*", headers.Accept.parse("foo/*"), True, True, True), | ||||||
|  |         ("*/*", "*/*", True, True, True), | ||||||
|  |         ("*/*", headers.Accept.parse("*/*"), True, True, True), | ||||||
|  |         # ALLOW TYPE | ||||||
|  |         ("foo/bar", "foo/bar", True, True, False), | ||||||
|  |         ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), | ||||||
|  |         ("foo/bar", "foo/*", False, True, False), | ||||||
|  |         ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), | ||||||
|  |         ("foo/bar", "*/*", False, True, False), | ||||||
|  |         ("foo/bar", headers.Accept.parse("*/*"), False, True, False), | ||||||
|  |         ("foo/*", "foo/bar", False, True, False), | ||||||
|  |         ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), | ||||||
|  |         ("foo/*", "foo/*", False, True, False), | ||||||
|  |         ("foo/*", headers.Accept.parse("foo/*"), False, True, False), | ||||||
|  |         ("foo/*", "*/*", False, True, False), | ||||||
|  |         ("foo/*", headers.Accept.parse("*/*"), False, True, False), | ||||||
|  |         ("*/*", "foo/bar", False, True, False), | ||||||
|  |         ("*/*", headers.Accept.parse("foo/bar"), False, True, False), | ||||||
|  |         ("*/*", "foo/*", False, True, False), | ||||||
|  |         ("*/*", headers.Accept.parse("foo/*"), False, True, False), | ||||||
|  |         ("*/*", "*/*", False, True, False), | ||||||
|  |         ("*/*", headers.Accept.parse("*/*"), False, True, False), | ||||||
|  |         # ALLOW SUBTYPE | ||||||
|  |         ("foo/bar", "foo/bar", True, False, True), | ||||||
|  |         ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), | ||||||
|  |         ("foo/bar", "foo/*", True, False, True), | ||||||
|  |         ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), | ||||||
|  |         ("foo/bar", "*/*", False, False, True), | ||||||
|  |         ("foo/bar", headers.Accept.parse("*/*"), False, False, True), | ||||||
|  |         ("foo/*", "foo/bar", True, False, True), | ||||||
|  |         ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), | ||||||
|  |         ("foo/*", "foo/*", True, False, True), | ||||||
|  |         ("foo/*", headers.Accept.parse("foo/*"), True, False, True), | ||||||
|  |         ("foo/*", "*/*", False, False, True), | ||||||
|  |         ("foo/*", headers.Accept.parse("*/*"), False, False, True), | ||||||
|  |         ("*/*", "foo/bar", False, False, True), | ||||||
|  |         ("*/*", headers.Accept.parse("foo/bar"), False, False, True), | ||||||
|  |         ("*/*", "foo/*", False, False, True), | ||||||
|  |         ("*/*", headers.Accept.parse("foo/*"), False, False, True), | ||||||
|  |         ("*/*", "*/*", False, False, True), | ||||||
|  |         ("*/*", headers.Accept.parse("*/*"), False, False, True), | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_accept_matching(value, other, outcome, allow_type, allow_subtype): | ||||||
|  |     assert ( | ||||||
|  |         headers.Accept.parse(value).match( | ||||||
|  |             other, | ||||||
|  |             allow_type_wildcard=allow_type, | ||||||
|  |             allow_subtype_wildcard=allow_subtype, | ||||||
|  |         ) | ||||||
|  |         is outcome | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) | ||||||
|  | def test_value_in_accept(value): | ||||||
|  |     acceptable = headers.parse_accept(value) | ||||||
|  |     assert "foo/bar" in acceptable | ||||||
|  |     assert "foo/*" in acceptable | ||||||
|  |     assert "*/*" in acceptable | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("value", ("foo/bar", "foo/*")) | ||||||
|  | def test_value_not_in_accept(value): | ||||||
|  |     acceptable = headers.parse_accept(value) | ||||||
|  |     assert "no/match" not in acceptable | ||||||
|  |     assert "no/*" not in acceptable | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "header,expected", | ||||||
|  |     ( | ||||||
|  |         ( | ||||||
|  |             "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",  # noqa: E501 | ||||||
|  |             [ | ||||||
|  |                 "text/html", | ||||||
|  |                 "application/xhtml+xml", | ||||||
|  |                 "image/avif", | ||||||
|  |                 "image/webp", | ||||||
|  |                 "application/xml;q=0.9", | ||||||
|  |                 "*/*;q=0.8", | ||||||
|  |             ], | ||||||
|  |         ), | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_browser_headers(header, expected): | ||||||
|  |     request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) | ||||||
|  |     assert request.accept == expected | ||||||
|   | |||||||
							
								
								
									
										137
									
								
								tests/test_http.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								tests/test_http.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,137 @@ | |||||||
|  | import asyncio | ||||||
|  | import json as stdjson | ||||||
|  |  | ||||||
|  | from collections import namedtuple | ||||||
|  | from textwrap import dedent | ||||||
|  | from typing import AnyStr | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | from sanic_testing.reusable import ReusableClient | ||||||
|  |  | ||||||
|  | from sanic import json, text | ||||||
|  | from sanic.app import Sanic | ||||||
|  |  | ||||||
|  |  | ||||||
|  | PORT = 1234 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RawClient: | ||||||
|  |     CRLF = b"\r\n" | ||||||
|  |  | ||||||
|  |     def __init__(self, host: str, port: int): | ||||||
|  |         self.reader = None | ||||||
|  |         self.writer = None | ||||||
|  |         self.host = host | ||||||
|  |         self.port = port | ||||||
|  |  | ||||||
|  |     async def connect(self): | ||||||
|  |         self.reader, self.writer = await asyncio.open_connection( | ||||||
|  |             self.host, self.port | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     async def close(self): | ||||||
|  |         self.writer.close() | ||||||
|  |         await self.writer.wait_closed() | ||||||
|  |  | ||||||
|  |     async def send(self, message: AnyStr): | ||||||
|  |         if isinstance(message, str): | ||||||
|  |             msg = self._clean(message).encode("utf-8") | ||||||
|  |         else: | ||||||
|  |             msg = message | ||||||
|  |         await self._send(msg) | ||||||
|  |  | ||||||
|  |     async def _send(self, message: bytes): | ||||||
|  |         if not self.writer: | ||||||
|  |             raise Exception("No open write stream") | ||||||
|  |         self.writer.write(message) | ||||||
|  |  | ||||||
|  |     async def recv(self, nbytes: int = -1) -> bytes: | ||||||
|  |         if not self.reader: | ||||||
|  |             raise Exception("No open read stream") | ||||||
|  |         return await self.reader.read(nbytes) | ||||||
|  |  | ||||||
|  |     def _clean(self, message: str) -> str: | ||||||
|  |         return ( | ||||||
|  |             dedent(message) | ||||||
|  |             .lstrip("\n") | ||||||
|  |             .replace("\n", self.CRLF.decode("utf-8")) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def test_app(app: Sanic): | ||||||
|  |     app.config.KEEP_ALIVE_TIMEOUT = 1 | ||||||
|  |  | ||||||
|  |     @app.get("/") | ||||||
|  |     async def base_handler(request): | ||||||
|  |         return text("111122223333444455556666777788889999") | ||||||
|  |  | ||||||
|  |     @app.post("/upload", stream=True) | ||||||
|  |     async def upload_handler(request): | ||||||
|  |         data = [part.decode("utf-8") async for part in request.stream] | ||||||
|  |         return json(data) | ||||||
|  |  | ||||||
|  |     return app | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def runner(test_app): | ||||||
|  |     client = ReusableClient(test_app, port=PORT) | ||||||
|  |     client.run() | ||||||
|  |     yield client | ||||||
|  |     client.stop() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def client(runner): | ||||||
|  |     client = namedtuple("Client", ("raw", "send", "recv")) | ||||||
|  |  | ||||||
|  |     raw = RawClient(runner.host, runner.port) | ||||||
|  |     runner._run(raw.connect()) | ||||||
|  |  | ||||||
|  |     def send(msg): | ||||||
|  |         nonlocal runner | ||||||
|  |         nonlocal raw | ||||||
|  |         runner._run(raw.send(msg)) | ||||||
|  |  | ||||||
|  |     def recv(**kwargs): | ||||||
|  |         nonlocal runner | ||||||
|  |         nonlocal raw | ||||||
|  |         method = raw.recv_until if "until" in kwargs else raw.recv | ||||||
|  |         return runner._run(method(**kwargs)) | ||||||
|  |  | ||||||
|  |     yield client(raw, send, recv) | ||||||
|  |  | ||||||
|  |     runner._run(raw.close()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_full_message(client): | ||||||
|  |     client.send( | ||||||
|  |         """ | ||||||
|  |         GET / HTTP/1.1 | ||||||
|  |         host: localhost:7777 | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |     ) | ||||||
|  |     response = client.recv() | ||||||
|  |     assert len(response) == 140 | ||||||
|  |     assert b"200 OK" in response | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_transfer_chunked(client): | ||||||
|  |     client.send( | ||||||
|  |         """ | ||||||
|  |         POST /upload HTTP/1.1 | ||||||
|  |         transfer-encoding: chunked | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |     ) | ||||||
|  |     client.send(b"3\r\nfoo\r\n") | ||||||
|  |     client.send(b"3\r\nbar\r\n") | ||||||
|  |     client.send(b"0\r\n\r\n") | ||||||
|  |     response = client.recv() | ||||||
|  |     _, body = response.rsplit(b"\r\n\r\n", 1) | ||||||
|  |     data = stdjson.loads(body) | ||||||
|  |  | ||||||
|  |     assert data == ["foo", "bar"] | ||||||
| @@ -2,16 +2,13 @@ import asyncio | |||||||
| import platform | import platform | ||||||
|  |  | ||||||
| from asyncio import sleep as aio_sleep | from asyncio import sleep as aio_sleep | ||||||
| from json import JSONDecodeError |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| import httpcore |  | ||||||
| import httpx |  | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from sanic_testing.testing import HOST, SanicTestClient | from sanic_testing.reusable import ReusableClient | ||||||
|  |  | ||||||
| from sanic import Sanic, server | from sanic import Sanic | ||||||
| from sanic.compat import OS_IS_WINDOWS | from sanic.compat import OS_IS_WINDOWS | ||||||
| from sanic.response import text | from sanic.response import text | ||||||
|  |  | ||||||
| @@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} | |||||||
| PORT = 42101  # test_keep_alive_timeout_reuse doesn't work with random port | PORT = 42101  # test_keep_alive_timeout_reuse doesn't work with random port | ||||||
|  |  | ||||||
|  |  | ||||||
| class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool): |  | ||||||
|     last_reused_connection = None |  | ||||||
|  |  | ||||||
|     async def _get_connection_from_pool(self, *args, **kwargs): |  | ||||||
|         conn = await super()._get_connection_from_pool(*args, **kwargs) |  | ||||||
|         self.__class__.last_reused_connection = conn |  | ||||||
|         return conn |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResusableSanicSession(httpx.AsyncClient): |  | ||||||
|     def __init__(self, *args, **kwargs) -> None: |  | ||||||
|         transport = ReusableSanicConnectionPool() |  | ||||||
|         super().__init__(transport=transport, *args, **kwargs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ReuseableSanicTestClient(SanicTestClient): |  | ||||||
|     def __init__(self, app, loop=None): |  | ||||||
|         super().__init__(app) |  | ||||||
|         if loop is None: |  | ||||||
|             loop = asyncio.get_event_loop() |  | ||||||
|         self._loop = loop |  | ||||||
|         self._server = None |  | ||||||
|         self._tcp_connector = None |  | ||||||
|         self._session = None |  | ||||||
|  |  | ||||||
|     def get_new_session(self): |  | ||||||
|         return ResusableSanicSession() |  | ||||||
|  |  | ||||||
|     # Copied from SanicTestClient, but with some changes to reuse the |  | ||||||
|     # same loop for the same app. |  | ||||||
|     def _sanic_endpoint_test( |  | ||||||
|         self, |  | ||||||
|         method="get", |  | ||||||
|         uri="/", |  | ||||||
|         gather_request=True, |  | ||||||
|         debug=False, |  | ||||||
|         server_kwargs=None, |  | ||||||
|         *request_args, |  | ||||||
|         **request_kwargs, |  | ||||||
|     ): |  | ||||||
|         loop = self._loop |  | ||||||
|         results = [None, None] |  | ||||||
|         exceptions = [] |  | ||||||
|         server_kwargs = server_kwargs or {"return_asyncio_server": True} |  | ||||||
|         if gather_request: |  | ||||||
|  |  | ||||||
|             def _collect_request(request): |  | ||||||
|                 if results[0] is None: |  | ||||||
|                     results[0] = request |  | ||||||
|  |  | ||||||
|             self.app.request_middleware.appendleft(_collect_request) |  | ||||||
|  |  | ||||||
|         if uri.startswith( |  | ||||||
|             ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") |  | ||||||
|         ): |  | ||||||
|             url = uri |  | ||||||
|         else: |  | ||||||
|             uri = uri if uri.startswith("/") else f"/{uri}" |  | ||||||
|             scheme = "http" |  | ||||||
|             url = f"{scheme}://{HOST}:{PORT}{uri}" |  | ||||||
|  |  | ||||||
|         @self.app.listener("after_server_start") |  | ||||||
|         async def _collect_response(loop): |  | ||||||
|             try: |  | ||||||
|                 response = await self._local_request( |  | ||||||
|                     method, url, *request_args, **request_kwargs |  | ||||||
|                 ) |  | ||||||
|                 results[-1] = response |  | ||||||
|             except Exception as e2: |  | ||||||
|                 exceptions.append(e2) |  | ||||||
|  |  | ||||||
|         if self._server is not None: |  | ||||||
|             _server = self._server |  | ||||||
|         else: |  | ||||||
|             _server_co = self.app.create_server( |  | ||||||
|                 host=HOST, debug=debug, port=PORT, **server_kwargs |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             server.trigger_events( |  | ||||||
|                 self.app.listeners["before_server_start"], loop |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             try: |  | ||||||
|                 loop._stopping = False |  | ||||||
|                 _server = loop.run_until_complete(_server_co) |  | ||||||
|             except Exception as e1: |  | ||||||
|                 raise e1 |  | ||||||
|             self._server = _server |  | ||||||
|         server.trigger_events(self.app.listeners["after_server_start"], loop) |  | ||||||
|         self.app.listeners["after_server_start"].pop() |  | ||||||
|  |  | ||||||
|         if exceptions: |  | ||||||
|             raise ValueError(f"Exception during request: {exceptions}") |  | ||||||
|  |  | ||||||
|         if gather_request: |  | ||||||
|             self.app.request_middleware.pop() |  | ||||||
|             try: |  | ||||||
|                 request, response = results |  | ||||||
|                 return request, response |  | ||||||
|             except Exception: |  | ||||||
|                 raise ValueError( |  | ||||||
|                     f"Request and response object expected, got ({results})" |  | ||||||
|                 ) |  | ||||||
|         else: |  | ||||||
|             try: |  | ||||||
|                 return results[-1] |  | ||||||
|             except Exception: |  | ||||||
|                 raise ValueError(f"Request object expected, got ({results})") |  | ||||||
|  |  | ||||||
|     def kill_server(self): |  | ||||||
|         try: |  | ||||||
|             if self._server: |  | ||||||
|                 self._server.close() |  | ||||||
|                 self._loop.run_until_complete(self._server.wait_closed()) |  | ||||||
|                 self._server = None |  | ||||||
|  |  | ||||||
|             if self._session: |  | ||||||
|                 self._loop.run_until_complete(self._session.aclose()) |  | ||||||
|                 self._session = None |  | ||||||
|  |  | ||||||
|         except Exception as e3: |  | ||||||
|             raise e3 |  | ||||||
|  |  | ||||||
|     # Copied from SanicTestClient, but with some changes to reuse the |  | ||||||
|     # same TCPConnection and the sane ClientSession more than once. |  | ||||||
|     # Note, you cannot use the same session if you are in a _different_ |  | ||||||
|     # loop, so the changes above are required too. |  | ||||||
|     async def _local_request(self, method, url, *args, **kwargs): |  | ||||||
|         raw_cookies = kwargs.pop("raw_cookies", None) |  | ||||||
|         request_keepalive = kwargs.pop( |  | ||||||
|             "request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"] |  | ||||||
|         ) |  | ||||||
|         if not self._session: |  | ||||||
|             self._session = self.get_new_session() |  | ||||||
|         try: |  | ||||||
|             response = await getattr(self._session, method.lower())( |  | ||||||
|                 url, timeout=request_keepalive, *args, **kwargs |  | ||||||
|             ) |  | ||||||
|         except NameError: |  | ||||||
|             raise Exception(response.status_code) |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             response.json = response.json() |  | ||||||
|         except (JSONDecodeError, UnicodeDecodeError): |  | ||||||
|             response.json = None |  | ||||||
|  |  | ||||||
|         response.body = await response.aread() |  | ||||||
|         response.status = response.status_code |  | ||||||
|         response.content_type = response.headers.get("content-type") |  | ||||||
|  |  | ||||||
|         if raw_cookies: |  | ||||||
|             response.raw_cookies = {} |  | ||||||
|             for cookie in response.cookies: |  | ||||||
|                 response.raw_cookies[cookie.name] = cookie |  | ||||||
|  |  | ||||||
|         return response |  | ||||||
|  |  | ||||||
|  |  | ||||||
| keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") | keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") | ||||||
| keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") | keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") | ||||||
| keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") | keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") | ||||||
| @@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse(): | |||||||
|     """If the server keep-alive timeout and client keep-alive timeout are |     """If the server keep-alive timeout and client keep-alive timeout are | ||||||
|     both longer than the delay, the client _and_ server will successfully |     both longer than the delay, the client _and_ server will successfully | ||||||
|     reuse the existing connection.""" |     reuse the existing connection.""" | ||||||
|     try: |  | ||||||
|     loop = asyncio.new_event_loop() |     loop = asyncio.new_event_loop() | ||||||
|     asyncio.set_event_loop(loop) |     asyncio.set_event_loop(loop) | ||||||
|         client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop) |     client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT) | ||||||
|  |     with client: | ||||||
|         headers = {"Connection": "keep-alive"} |         headers = {"Connection": "keep-alive"} | ||||||
|         request, response = client.get("/1", headers=headers) |         request, response = client.get("/1", headers=headers) | ||||||
|         assert response.status == 200 |         assert response.status == 200 | ||||||
|         assert response.text == "OK" |         assert response.text == "OK" | ||||||
|  |         assert request.protocol.state["requests_count"] == 1 | ||||||
|  |  | ||||||
|         loop.run_until_complete(aio_sleep(1)) |         loop.run_until_complete(aio_sleep(1)) | ||||||
|  |  | ||||||
|         request, response = client.get("/1") |         request, response = client.get("/1") | ||||||
|         assert response.status == 200 |         assert response.status == 200 | ||||||
|         assert response.text == "OK" |         assert response.text == "OK" | ||||||
|         assert ReusableSanicConnectionPool.last_reused_connection |         assert request.protocol.state["requests_count"] == 2 | ||||||
|     finally: |  | ||||||
|         client.kill_server() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||||
| @@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse(): | |||||||
| def test_keep_alive_client_timeout(): | def test_keep_alive_client_timeout(): | ||||||
|     """If the server keep-alive timeout is longer than the client |     """If the server keep-alive timeout is longer than the client | ||||||
|     keep-alive timeout, client will try to create a new connection here.""" |     keep-alive timeout, client will try to create a new connection here.""" | ||||||
|     try: |  | ||||||
|     loop = asyncio.new_event_loop() |     loop = asyncio.new_event_loop() | ||||||
|     asyncio.set_event_loop(loop) |     asyncio.set_event_loop(loop) | ||||||
|         client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) |     client = ReusableClient( | ||||||
|  |         keep_alive_app_client_timeout, loop=loop, port=PORT | ||||||
|  |     ) | ||||||
|  |     with client: | ||||||
|         headers = {"Connection": "keep-alive"} |         headers = {"Connection": "keep-alive"} | ||||||
|         _, response = client.get("/1", headers=headers, request_keepalive=1) |         request, response = client.get("/1", headers=headers, timeout=1) | ||||||
|  |  | ||||||
|         assert response.status == 200 |         assert response.status == 200 | ||||||
|         assert response.text == "OK" |         assert response.text == "OK" | ||||||
|  |         assert request.protocol.state["requests_count"] == 1 | ||||||
|  |  | ||||||
|         loop.run_until_complete(aio_sleep(2)) |         loop.run_until_complete(aio_sleep(2)) | ||||||
|         _, response = client.get("/1", request_keepalive=1) |         request, response = client.get("/1", timeout=1) | ||||||
|  |         assert request.protocol.state["requests_count"] == 1 | ||||||
|         assert ReusableSanicConnectionPool.last_reused_connection is None |  | ||||||
|     finally: |  | ||||||
|         client.kill_server() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||||
| @@ -277,22 +117,23 @@ def test_keep_alive_server_timeout(): | |||||||
|     keep-alive timeout, the client will either a 'Connection reset' error |     keep-alive timeout, the client will either a 'Connection reset' error | ||||||
|     _or_ a new connection. Depending on how the event-loop handles the |     _or_ a new connection. Depending on how the event-loop handles the | ||||||
|     broken server connection.""" |     broken server connection.""" | ||||||
|     try: |  | ||||||
|     loop = asyncio.new_event_loop() |     loop = asyncio.new_event_loop() | ||||||
|     asyncio.set_event_loop(loop) |     asyncio.set_event_loop(loop) | ||||||
|         client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) |     client = ReusableClient( | ||||||
|  |         keep_alive_app_server_timeout, loop=loop, port=PORT | ||||||
|  |     ) | ||||||
|  |     with client: | ||||||
|         headers = {"Connection": "keep-alive"} |         headers = {"Connection": "keep-alive"} | ||||||
|         _, response = client.get("/1", headers=headers, request_keepalive=60) |         request, response = client.get("/1", headers=headers, timeout=60) | ||||||
|  |  | ||||||
|         assert response.status == 200 |         assert response.status == 200 | ||||||
|         assert response.text == "OK" |         assert response.text == "OK" | ||||||
|  |         assert request.protocol.state["requests_count"] == 1 | ||||||
|  |  | ||||||
|         loop.run_until_complete(aio_sleep(3)) |         loop.run_until_complete(aio_sleep(3)) | ||||||
|         _, response = client.get("/1", request_keepalive=60) |         request, response = client.get("/1", timeout=60) | ||||||
|  |  | ||||||
|         assert ReusableSanicConnectionPool.last_reused_connection is None |         assert request.protocol.state["requests_count"] == 1 | ||||||
|     finally: |  | ||||||
|         client.kill_server() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||||
| @@ -300,10 +141,10 @@ def test_keep_alive_server_timeout(): | |||||||
|     reason="Not testable with current client", |     reason="Not testable with current client", | ||||||
| ) | ) | ||||||
| def test_keep_alive_connection_context(): | def test_keep_alive_connection_context(): | ||||||
|     try: |  | ||||||
|     loop = asyncio.new_event_loop() |     loop = asyncio.new_event_loop() | ||||||
|     asyncio.set_event_loop(loop) |     asyncio.set_event_loop(loop) | ||||||
|         client = ReuseableSanicTestClient(keep_alive_app_context, loop) |     client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT) | ||||||
|  |     with client: | ||||||
|         headers = {"Connection": "keep-alive"} |         headers = {"Connection": "keep-alive"} | ||||||
|         request1, _ = client.post("/ctx", headers=headers) |         request1, _ = client.post("/ctx", headers=headers) | ||||||
|  |  | ||||||
| @@ -315,5 +156,4 @@ def test_keep_alive_connection_context(): | |||||||
|         assert ( |         assert ( | ||||||
|             request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" |             request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" | ||||||
|         ) |         ) | ||||||
|     finally: |         assert request2.protocol.state["requests_count"] == 2 | ||||||
|         client.kill_server() |  | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import uuid | |||||||
|  |  | ||||||
| from importlib import reload | from importlib import reload | ||||||
| from io import StringIO | from io import StringIO | ||||||
|  | from unittest.mock import Mock | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| @@ -51,7 +52,7 @@ def test_log(app): | |||||||
|  |  | ||||||
| def test_logging_defaults(): | def test_logging_defaults(): | ||||||
|     # reset_logging() |     # reset_logging() | ||||||
|     app = Sanic("test_logging") |     Sanic("test_logging") | ||||||
|  |  | ||||||
|     for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: |     for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: | ||||||
|         assert ( |         assert ( | ||||||
| @@ -87,7 +88,7 @@ def test_logging_pass_customer_logconfig(): | |||||||
|         "format" |         "format" | ||||||
|     ] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s" |     ] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s" | ||||||
|  |  | ||||||
|     app = Sanic("test_logging", log_config=modified_config) |     Sanic("test_logging", log_config=modified_config) | ||||||
|  |  | ||||||
|     for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: |     for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: | ||||||
|         assert fmt._fmt == modified_config["formatters"]["generic"]["format"] |         assert fmt._fmt == modified_config["formatters"]["generic"]["format"] | ||||||
| @@ -111,11 +112,13 @@ def test_logging_pass_customer_logconfig(): | |||||||
|     ), |     ), | ||||||
| ) | ) | ||||||
| def test_log_connection_lost(app, debug, monkeypatch): | def test_log_connection_lost(app, debug, monkeypatch): | ||||||
|     """ Should not log Connection lost exception on non debug """ |     """Should not log Connection lost exception on non debug""" | ||||||
|     stream = StringIO() |     stream = StringIO() | ||||||
|     error = logging.getLogger("sanic.error") |     error = logging.getLogger("sanic.error") | ||||||
|     error.addHandler(logging.StreamHandler(stream)) |     error.addHandler(logging.StreamHandler(stream)) | ||||||
|     monkeypatch.setattr(sanic.server, "error_logger", error) |     monkeypatch.setattr( | ||||||
|  |         sanic.server.protocols.http_protocol, "error_logger", error | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     @app.route("/conn_lost") |     @app.route("/conn_lost") | ||||||
|     async def conn_lost(request): |     async def conn_lost(request): | ||||||
| @@ -208,6 +211,56 @@ def test_logging_modified_root_logger_config(): | |||||||
|     modified_config = LOGGING_CONFIG_DEFAULTS |     modified_config = LOGGING_CONFIG_DEFAULTS | ||||||
|     modified_config["loggers"]["sanic.root"]["level"] = "DEBUG" |     modified_config["loggers"]["sanic.root"]["level"] = "DEBUG" | ||||||
|  |  | ||||||
|     app = Sanic("test_logging", log_config=modified_config) |     Sanic("test_logging", log_config=modified_config) | ||||||
|  |  | ||||||
|     assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG |     assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_access_log_client_ip_remote_addr(monkeypatch): | ||||||
|  |     access = Mock() | ||||||
|  |     monkeypatch.setattr(sanic.http, "access_logger", access) | ||||||
|  |  | ||||||
|  |     app = Sanic("test_logging") | ||||||
|  |     app.config.PROXIES_COUNT = 2 | ||||||
|  |  | ||||||
|  |     @app.route("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         return text(request.remote_addr) | ||||||
|  |  | ||||||
|  |     headers = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"} | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get("/", headers=headers) | ||||||
|  |  | ||||||
|  |     assert request.remote_addr == "1.1.1.1" | ||||||
|  |     access.info.assert_called_with( | ||||||
|  |         "", | ||||||
|  |         extra={ | ||||||
|  |             "status": 200, | ||||||
|  |             "byte": len(response.content), | ||||||
|  |             "host": f"{request.remote_addr}:{request.port}", | ||||||
|  |             "request": f"GET {request.scheme}://{request.host}/", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_access_log_client_ip_reqip(monkeypatch): | ||||||
|  |     access = Mock() | ||||||
|  |     monkeypatch.setattr(sanic.http, "access_logger", access) | ||||||
|  |  | ||||||
|  |     app = Sanic("test_logging") | ||||||
|  |  | ||||||
|  |     @app.route("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         return text(request.ip) | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get("/") | ||||||
|  |  | ||||||
|  |     access.info.assert_called_with( | ||||||
|  |         "", | ||||||
|  |         extra={ | ||||||
|  |             "status": 200, | ||||||
|  |             "byte": len(response.content), | ||||||
|  |             "host": f"{request.ip}:{request.port}", | ||||||
|  |             "request": f"GET {request.scheme}://{request.host}/", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|   | |||||||
| @@ -6,85 +6,37 @@ from sanic_testing.testing import PORT | |||||||
| from sanic.config import BASE_LOGO | from sanic.config import BASE_LOGO | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_logo_base(app, caplog): | def test_logo_base(app, run_startup): | ||||||
|     server = app.create_server( |     logs = run_startup(app) | ||||||
|         debug=True, return_asyncio_server=True, port=PORT |  | ||||||
|     ) |  | ||||||
|     loop = asyncio.new_event_loop() |  | ||||||
|     asyncio.set_event_loop(loop) |  | ||||||
|     loop._stopping = False |  | ||||||
|  |  | ||||||
|     with caplog.at_level(logging.DEBUG): |     assert logs[0][1] == logging.DEBUG | ||||||
|         _server = loop.run_until_complete(server) |     assert logs[0][2] == BASE_LOGO | ||||||
|  |  | ||||||
|     _server.close() |  | ||||||
|     loop.run_until_complete(_server.wait_closed()) |  | ||||||
|     app.stop() |  | ||||||
|  |  | ||||||
|     assert caplog.record_tuples[0][1] == logging.DEBUG |  | ||||||
|     assert caplog.record_tuples[0][2] == BASE_LOGO |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_logo_false(app, caplog): | def test_logo_false(app, caplog, run_startup): | ||||||
|     app.config.LOGO = False |     app.config.LOGO = False | ||||||
|  |  | ||||||
|     server = app.create_server( |     logs = run_startup(app) | ||||||
|         debug=True, return_asyncio_server=True, port=PORT |  | ||||||
|     ) |  | ||||||
|     loop = asyncio.new_event_loop() |  | ||||||
|     asyncio.set_event_loop(loop) |  | ||||||
|     loop._stopping = False |  | ||||||
|  |  | ||||||
|     with caplog.at_level(logging.DEBUG): |     banner, port = logs[0][2].rsplit(":", 1) | ||||||
|         _server = loop.run_until_complete(server) |     assert logs[0][1] == logging.INFO | ||||||
|  |  | ||||||
|     _server.close() |  | ||||||
|     loop.run_until_complete(_server.wait_closed()) |  | ||||||
|     app.stop() |  | ||||||
|  |  | ||||||
|     banner, port = caplog.record_tuples[0][2].rsplit(":", 1) |  | ||||||
|     assert caplog.record_tuples[0][1] == logging.INFO |  | ||||||
|     assert banner == "Goin' Fast @ http://127.0.0.1" |     assert banner == "Goin' Fast @ http://127.0.0.1" | ||||||
|     assert int(port) > 0 |     assert int(port) > 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_logo_true(app, caplog): | def test_logo_true(app, run_startup): | ||||||
|     app.config.LOGO = True |     app.config.LOGO = True | ||||||
|  |  | ||||||
|     server = app.create_server( |     logs = run_startup(app) | ||||||
|         debug=True, return_asyncio_server=True, port=PORT |  | ||||||
|     ) |  | ||||||
|     loop = asyncio.new_event_loop() |  | ||||||
|     asyncio.set_event_loop(loop) |  | ||||||
|     loop._stopping = False |  | ||||||
|  |  | ||||||
|     with caplog.at_level(logging.DEBUG): |     assert logs[0][1] == logging.DEBUG | ||||||
|         _server = loop.run_until_complete(server) |     assert logs[0][2] == BASE_LOGO | ||||||
|  |  | ||||||
|     _server.close() |  | ||||||
|     loop.run_until_complete(_server.wait_closed()) |  | ||||||
|     app.stop() |  | ||||||
|  |  | ||||||
|     assert caplog.record_tuples[0][1] == logging.DEBUG |  | ||||||
|     assert caplog.record_tuples[0][2] == BASE_LOGO |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_logo_custom(app, caplog): | def test_logo_custom(app, run_startup): | ||||||
|     app.config.LOGO = "My Custom Logo" |     app.config.LOGO = "My Custom Logo" | ||||||
|  |  | ||||||
|     server = app.create_server( |     logs = run_startup(app) | ||||||
|         debug=True, return_asyncio_server=True, port=PORT |  | ||||||
|     ) |  | ||||||
|     loop = asyncio.new_event_loop() |  | ||||||
|     asyncio.set_event_loop(loop) |  | ||||||
|     loop._stopping = False |  | ||||||
|  |  | ||||||
|     with caplog.at_level(logging.DEBUG): |     assert logs[0][1] == logging.DEBUG | ||||||
|         _server = loop.run_until_complete(server) |     assert logs[0][2] == "My Custom Logo" | ||||||
|  |  | ||||||
|     _server.close() |  | ||||||
|     loop.run_until_complete(_server.wait_closed()) |  | ||||||
|     app.stop() |  | ||||||
|  |  | ||||||
|     assert caplog.record_tuples[0][1] == logging.DEBUG |  | ||||||
|     assert caplog.record_tuples[0][2] == "My Custom Logo" |  | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ from itertools import count | |||||||
|  |  | ||||||
| from sanic.exceptions import NotFound | from sanic.exceptions import NotFound | ||||||
| from sanic.request import Request | from sanic.request import Request | ||||||
| from sanic.response import HTTPResponse, text | from sanic.response import HTTPResponse, json, text | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------------------------------------ # | # ------------------------------------------------------------ # | ||||||
| @@ -37,14 +37,19 @@ def test_middleware_request_as_convenience(app): | |||||||
|     async def handler1(request): |     async def handler1(request): | ||||||
|         results.append(request) |         results.append(request) | ||||||
|  |  | ||||||
|     @app.route("/") |     @app.on_request() | ||||||
|     async def handler2(request): |     async def handler2(request): | ||||||
|  |         results.append(request) | ||||||
|  |  | ||||||
|  |     @app.route("/") | ||||||
|  |     async def handler3(request): | ||||||
|         return text("OK") |         return text("OK") | ||||||
|  |  | ||||||
|     request, response = app.test_client.get("/") |     request, response = app.test_client.get("/") | ||||||
|  |  | ||||||
|     assert response.text == "OK" |     assert response.text == "OK" | ||||||
|     assert type(results[0]) is Request |     assert type(results[0]) is Request | ||||||
|  |     assert type(results[1]) is Request | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_middleware_response(app): | def test_middleware_response(app): | ||||||
| @@ -79,7 +84,12 @@ def test_middleware_response_as_convenience(app): | |||||||
|         results.append(request) |         results.append(request) | ||||||
|  |  | ||||||
|     @app.on_response |     @app.on_response | ||||||
|     async def process_response(request, response): |     async def process_response_1(request, response): | ||||||
|  |         results.append(request) | ||||||
|  |         results.append(response) | ||||||
|  |  | ||||||
|  |     @app.on_response() | ||||||
|  |     async def process_response_2(request, response): | ||||||
|         results.append(request) |         results.append(request) | ||||||
|         results.append(response) |         results.append(response) | ||||||
|  |  | ||||||
| @@ -93,6 +103,8 @@ def test_middleware_response_as_convenience(app): | |||||||
|     assert type(results[0]) is Request |     assert type(results[0]) is Request | ||||||
|     assert type(results[1]) is Request |     assert type(results[1]) is Request | ||||||
|     assert isinstance(results[2], HTTPResponse) |     assert isinstance(results[2], HTTPResponse) | ||||||
|  |     assert type(results[3]) is Request | ||||||
|  |     assert isinstance(results[4], HTTPResponse) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_middleware_response_as_convenience_called(app): | def test_middleware_response_as_convenience_called(app): | ||||||
| @@ -271,3 +283,17 @@ def test_request_middleware_executes_once(app): | |||||||
|  |  | ||||||
|     request, response = app.test_client.get("/") |     request, response = app.test_client.get("/") | ||||||
|     assert next(i) == 3 |     assert next(i) == 3 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_middleware_added_response(app): | ||||||
|  |     @app.on_response | ||||||
|  |     def display(_, response): | ||||||
|  |         response["foo"] = "bar" | ||||||
|  |         return json(response) | ||||||
|  |  | ||||||
|  |     @app.get("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         return {} | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/") | ||||||
|  |     assert response.json["foo"] == "bar" | ||||||
|   | |||||||
							
								
								
									
										105
									
								
								tests/test_pipelining.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								tests/test_pipelining.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | |||||||
|  | from httpx import AsyncByteStream | ||||||
|  | from sanic_testing.reusable import ReusableClient | ||||||
|  |  | ||||||
|  | from sanic.response import json, text | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_no_body_requests(app): | ||||||
|  |     @app.get("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         return json( | ||||||
|  |             { | ||||||
|  |                 "request_id": str(request.id), | ||||||
|  |                 "connection_id": id(request.conn_info), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     client = ReusableClient(app, port=1234) | ||||||
|  |  | ||||||
|  |     with client: | ||||||
|  |         _, response1 = client.get("/") | ||||||
|  |         _, response2 = client.get("/") | ||||||
|  |  | ||||||
|  |     assert response1.status == response2.status == 200 | ||||||
|  |     assert response1.json["request_id"] != response2.json["request_id"] | ||||||
|  |     assert response1.json["connection_id"] == response2.json["connection_id"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_json_body_requests(app): | ||||||
|  |     @app.post("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         return json( | ||||||
|  |             { | ||||||
|  |                 "request_id": str(request.id), | ||||||
|  |                 "connection_id": id(request.conn_info), | ||||||
|  |                 "foo": request.json.get("foo"), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     client = ReusableClient(app, port=1234) | ||||||
|  |  | ||||||
|  |     with client: | ||||||
|  |         _, response1 = client.post("/", json={"foo": True}) | ||||||
|  |         _, response2 = client.post("/", json={"foo": True}) | ||||||
|  |  | ||||||
|  |     assert response1.status == response2.status == 200 | ||||||
|  |     assert response1.json["foo"] is response2.json["foo"] is True | ||||||
|  |     assert response1.json["request_id"] != response2.json["request_id"] | ||||||
|  |     assert response1.json["connection_id"] == response2.json["connection_id"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_streaming_body_requests(app): | ||||||
|  |     @app.post("/", stream=True) | ||||||
|  |     async def handler(request): | ||||||
|  |         data = [part.decode("utf-8") async for part in request.stream] | ||||||
|  |         return json( | ||||||
|  |             { | ||||||
|  |                 "request_id": str(request.id), | ||||||
|  |                 "connection_id": id(request.conn_info), | ||||||
|  |                 "data": data, | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     data = ["hello", "world"] | ||||||
|  |  | ||||||
|  |     class Data(AsyncByteStream): | ||||||
|  |         def __init__(self, data): | ||||||
|  |             self.data = data | ||||||
|  |  | ||||||
|  |         async def __aiter__(self): | ||||||
|  |             for value in self.data: | ||||||
|  |                 yield value.encode("utf-8") | ||||||
|  |  | ||||||
|  |     client = ReusableClient(app, port=1234) | ||||||
|  |  | ||||||
|  |     with client: | ||||||
|  |         _, response1 = client.post("/", data=Data(data)) | ||||||
|  |         _, response2 = client.post("/", data=Data(data)) | ||||||
|  |  | ||||||
|  |     assert response1.status == response2.status == 200 | ||||||
|  |     assert response1.json["data"] == response2.json["data"] == data | ||||||
|  |     assert response1.json["request_id"] != response2.json["request_id"] | ||||||
|  |     assert response1.json["connection_id"] == response2.json["connection_id"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_bad_headers(app): | ||||||
|  |     @app.get("/") | ||||||
|  |     async def handler(request): | ||||||
|  |         return text("") | ||||||
|  |  | ||||||
|  |     @app.on_response | ||||||
|  |     async def reqid(request, response): | ||||||
|  |         response.headers["x-request-id"] = request.id | ||||||
|  |  | ||||||
|  |     client = ReusableClient(app, port=1234) | ||||||
|  |     bad_headers = {"bad": "bad" * 5_000} | ||||||
|  |  | ||||||
|  |     with client: | ||||||
|  |         _, response1 = client.get("/") | ||||||
|  |         _, response2 = client.get("/", headers=bad_headers) | ||||||
|  |  | ||||||
|  |     assert response1.status == 200 | ||||||
|  |     assert response2.status == 413 | ||||||
|  |     assert ( | ||||||
|  |         response1.headers["x-request-id"] != response2.headers["x-request-id"] | ||||||
|  |     ) | ||||||
| @@ -140,3 +140,39 @@ def test_ipv6_address_is_not_wrapped(app): | |||||||
|     assert resp.json["client"] == "[::1]" |     assert resp.json["client"] == "[::1]" | ||||||
|     assert resp.json["client_ip"] == "::1" |     assert resp.json["client_ip"] == "::1" | ||||||
|     assert request.ip == "::1" |     assert request.ip == "::1" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_request_accept(): | ||||||
|  |     app = Sanic("req-generator") | ||||||
|  |  | ||||||
|  |     @app.get("/") | ||||||
|  |     async def get(request): | ||||||
|  |         return response.empty() | ||||||
|  |  | ||||||
|  |     request, _ = app.test_client.get( | ||||||
|  |         "/", | ||||||
|  |         headers={ | ||||||
|  |             "Accept": "text/*, text/plain, text/plain;format=flowed, */*" | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     assert request.accept == [ | ||||||
|  |         "text/plain;format=flowed", | ||||||
|  |         "text/plain", | ||||||
|  |         "text/*", | ||||||
|  |         "*/*", | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     request, _ = app.test_client.get( | ||||||
|  |         "/", | ||||||
|  |         headers={ | ||||||
|  |             "Accept": ( | ||||||
|  |                 "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c" | ||||||
|  |             ) | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     assert request.accept == [ | ||||||
|  |         "text/html", | ||||||
|  |         "text/x-c", | ||||||
|  |         "text/x-dvi; q=0.8", | ||||||
|  |         "text/plain; q=0.5", | ||||||
|  |     ] | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ import asyncio | |||||||
|  |  | ||||||
| import httpcore | import httpcore | ||||||
| import httpx | import httpx | ||||||
|  | import pytest | ||||||
|  |  | ||||||
| from sanic_testing.testing import SanicTestClient | from sanic_testing.testing import SanicTestClient | ||||||
|  |  | ||||||
| @@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient): | |||||||
|         return DelayableSanicSession(request_delay=self._request_delay) |         return DelayableSanicSession(request_delay=self._request_delay) | ||||||
|  |  | ||||||
|  |  | ||||||
| request_timeout_default_app = Sanic("test_request_timeout_default") | @pytest.fixture | ||||||
| request_no_timeout_app = Sanic("test_request_no_timeout") | def request_no_timeout_app(): | ||||||
| request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6 |     app = Sanic("test_request_no_timeout") | ||||||
| request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6 |     app.config.REQUEST_TIMEOUT = 0.6 | ||||||
|  |  | ||||||
|  |     @app.route("/1") | ||||||
| @request_timeout_default_app.route("/1") |     async def handler2(request): | ||||||
| async def handler1(request): |  | ||||||
|         return text("OK") |         return text("OK") | ||||||
|  |  | ||||||
|  |     return app | ||||||
|  |  | ||||||
| @request_no_timeout_app.route("/1") |  | ||||||
| async def handler2(request): | @pytest.fixture | ||||||
|  | def request_timeout_default_app(): | ||||||
|  |     app = Sanic("test_request_timeout_default") | ||||||
|  |     app.config.REQUEST_TIMEOUT = 0.6 | ||||||
|  |  | ||||||
|  |     @app.route("/1") | ||||||
|  |     async def handler1(request): | ||||||
|         return text("OK") |         return text("OK") | ||||||
|  |  | ||||||
|  |     @app.websocket("/ws1") | ||||||
| @request_timeout_default_app.websocket("/ws1") |     async def ws_handler1(request, ws): | ||||||
| async def ws_handler1(request, ws): |  | ||||||
|         await ws.send("OK") |         await ws.send("OK") | ||||||
|  |  | ||||||
|  |     return app | ||||||
|  |  | ||||||
| def test_default_server_error_request_timeout(): |  | ||||||
|  | def test_default_server_error_request_timeout(request_timeout_default_app): | ||||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) |     client = DelayableSanicTestClient(request_timeout_default_app, 2) | ||||||
|     request, response = client.get("/1") |     _, response = client.get("/1") | ||||||
|     assert response.status == 408 |     assert response.status == 408 | ||||||
|     assert "Request Timeout" in response.text |     assert "Request Timeout" in response.text | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_default_server_error_request_dont_timeout(): | def test_default_server_error_request_dont_timeout(request_no_timeout_app): | ||||||
|     client = DelayableSanicTestClient(request_no_timeout_app, 0.2) |     client = DelayableSanicTestClient(request_no_timeout_app, 0.2) | ||||||
|     request, response = client.get("/1") |     _, response = client.get("/1") | ||||||
|     assert response.status == 200 |     assert response.status == 200 | ||||||
|     assert response.text == "OK" |     assert response.text == "OK" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_default_server_error_websocket_request_timeout(): | def test_default_server_error_websocket_request_timeout( | ||||||
|  |     request_timeout_default_app, | ||||||
|  | ): | ||||||
|  |  | ||||||
|     headers = { |     headers = { | ||||||
|         "Upgrade": "websocket", |         "Upgrade": "websocket", | ||||||
| @@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout(): | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     client = DelayableSanicTestClient(request_timeout_default_app, 2) |     client = DelayableSanicTestClient(request_timeout_default_app, 2) | ||||||
|     request, response = client.get("/ws1", headers=headers) |     _, response = client.get("/ws1", headers=headers) | ||||||
|  |  | ||||||
|     assert response.status == 408 |     assert response.status == 408 | ||||||
|     assert "Request Timeout" in response.text |     assert "Request Timeout" in response.text | ||||||
|   | |||||||
| @@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app): | |||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.parametrize("url", ["/ws", "ws"]) | @pytest.mark.parametrize("url", ["/ws", "ws"]) | ||||||
| async def test_websocket_route_asgi(app, url): | async def test_websocket_route_asgi(app, url): | ||||||
|     ev = asyncio.Event() |     @app.after_server_start | ||||||
|  |     async def setup_ev(app, _): | ||||||
|  |         app.ctx.ev = asyncio.Event() | ||||||
|  |  | ||||||
|     @app.websocket(url) |     @app.websocket(url) | ||||||
|     async def handler(request, ws): |     async def handler(request, ws): | ||||||
|         ev.set() |         request.app.ctx.ev.set() | ||||||
|  |  | ||||||
|     request, response = await app.asgi_client.websocket(url) |     @app.get("/ev") | ||||||
|     assert ev.is_set() |     async def check(request): | ||||||
|  |         return json({"set": request.app.ctx.ev.is_set()}) | ||||||
|  |  | ||||||
|  |     _, response = await app.asgi_client.websocket(url) | ||||||
|  |     _, response = await app.asgi_client.get("/") | ||||||
|  |     assert response.json["set"] | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_websocket_route_with_subprotocols(app): | @pytest.mark.parametrize( | ||||||
|  |     "subprotocols,expected", | ||||||
|  |     ( | ||||||
|  |         (["one"], "one"), | ||||||
|  |         (["three", "one"], "one"), | ||||||
|  |         (["tree"], None), | ||||||
|  |         (None, None), | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_websocket_route_with_subprotocols(app, subprotocols, expected): | ||||||
|     results = [] |     results = [] | ||||||
|  |  | ||||||
|     @app.websocket("/ws", subprotocols=["foo", "bar"]) |     @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) | ||||||
|     async def handler(request, ws): |     async def handler(request, ws): | ||||||
|         results.append(ws.subprotocol) |         nonlocal results | ||||||
|  |         results = ws.subprotocol | ||||||
|         assert ws.subprotocol is not None |         assert ws.subprotocol is not None | ||||||
|  |  | ||||||
|     _, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"]) |  | ||||||
|     assert response.opened is True |  | ||||||
|     assert results == ["bar"] |  | ||||||
|  |  | ||||||
|     _, response = SanicTestClient(app).websocket( |     _, response = SanicTestClient(app).websocket( | ||||||
|         "/ws", subprotocols=["bar", "foo"] |         "/ws", subprotocols=subprotocols | ||||||
|     ) |     ) | ||||||
|     assert response.opened is True |     assert response.opened is True | ||||||
|     assert results == ["bar", "bar"] |     assert results == expected | ||||||
|  |  | ||||||
|     _, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"]) |  | ||||||
|     assert response.opened is True |  | ||||||
|     assert results == ["bar", "bar", None] |  | ||||||
|  |  | ||||||
|     _, response = SanicTestClient(app).websocket("/ws") |  | ||||||
|     assert response.opened is True |  | ||||||
|     assert results == ["bar", "bar", None, None] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("strict_slashes", [True, False, None]) | @pytest.mark.parametrize("strict_slashes", [True, False, None]) | ||||||
|   | |||||||
| @@ -8,7 +8,7 @@ import pytest | |||||||
|  |  | ||||||
| from sanic_testing.testing import HOST, PORT | from sanic_testing.testing import HOST, PORT | ||||||
|  |  | ||||||
| from sanic.exceptions import InvalidUsage | from sanic.exceptions import InvalidUsage, SanicException | ||||||
|  |  | ||||||
|  |  | ||||||
| AVAILABLE_LISTENERS = [ | AVAILABLE_LISTENERS = [ | ||||||
| @@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app): | |||||||
|     async def init_db(app, loop): |     async def init_db(app, loop): | ||||||
|         app.db = MySanicDb() |         app.db = MySanicDb() | ||||||
|  |  | ||||||
|     await app.create_server(debug=True, return_asyncio_server=True, port=PORT) |     srv = await app.create_server( | ||||||
|  |         debug=True, return_asyncio_server=True, port=PORT | ||||||
|  |     ) | ||||||
|  |     await srv.startup() | ||||||
|  |     await srv.before_start() | ||||||
|  |  | ||||||
|     assert hasattr(app, "db") |     assert hasattr(app, "db") | ||||||
|     assert isinstance(app.db, MySanicDb) |     assert isinstance(app.db, MySanicDb) | ||||||
| @@ -157,14 +161,15 @@ def test_create_server_trigger_events(app): | |||||||
|         serv_coro = app.create_server(return_asyncio_server=True, sock=sock) |         serv_coro = app.create_server(return_asyncio_server=True, sock=sock) | ||||||
|         serv_task = asyncio.ensure_future(serv_coro, loop=loop) |         serv_task = asyncio.ensure_future(serv_coro, loop=loop) | ||||||
|         server = loop.run_until_complete(serv_task) |         server = loop.run_until_complete(serv_task) | ||||||
|         server.after_start() |         loop.run_until_complete(server.startup()) | ||||||
|  |         loop.run_until_complete(server.after_start()) | ||||||
|         try: |         try: | ||||||
|             loop.run_forever() |             loop.run_forever() | ||||||
|         except KeyboardInterrupt as e: |         except KeyboardInterrupt: | ||||||
|             loop.stop() |             loop.stop() | ||||||
|         finally: |         finally: | ||||||
|             # Run the on_stop function if provided |             # Run the on_stop function if provided | ||||||
|             server.before_stop() |             loop.run_until_complete(server.before_stop()) | ||||||
|  |  | ||||||
|             # Wait for server to close |             # Wait for server to close | ||||||
|             close_task = server.close() |             close_task = server.close() | ||||||
| @@ -174,5 +179,19 @@ def test_create_server_trigger_events(app): | |||||||
|             signal.stopped = True |             signal.stopped = True | ||||||
|             for connection in server.connections: |             for connection in server.connections: | ||||||
|                 connection.close_if_idle() |                 connection.close_if_idle() | ||||||
|             server.after_stop() |             loop.run_until_complete(server.after_stop()) | ||||||
|         assert flag1 and flag2 and flag3 |         assert flag1 and flag2 and flag3 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_missing_startup_raises_exception(app): | ||||||
|  |     @app.listener("before_server_start") | ||||||
|  |     async def init_db(app, loop): | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     srv = await app.create_server( | ||||||
|  |         debug=True, return_asyncio_server=True, port=PORT | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     with pytest.raises(SanicException): | ||||||
|  |         await srv.before_start() | ||||||
|   | |||||||
| @@ -95,7 +95,7 @@ def test_windows_workaround(): | |||||||
|         os.kill(os.getpid(), signal.SIGINT) |         os.kill(os.getpid(), signal.SIGINT) | ||||||
|         await asyncio.sleep(0.2) |         await asyncio.sleep(0.2) | ||||||
|         assert app.is_stopping |         assert app.is_stopping | ||||||
|         assert app.stay_active_task.result() == None |         assert app.stay_active_task.result() is None | ||||||
|         # Second Ctrl+C should raise |         # Second Ctrl+C should raise | ||||||
|         with pytest.raises(KeyboardInterrupt): |         with pytest.raises(KeyboardInterrupt): | ||||||
|             os.kill(os.getpid(), signal.SIGINT) |             os.kill(os.getpid(), signal.SIGINT) | ||||||
|   | |||||||
| @@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app): | |||||||
|  |  | ||||||
|     app.signal_router.finalize() |     app.signal_router.finalize() | ||||||
|  |  | ||||||
|  |     assert len(app.signal_router.routes) == 3 | ||||||
|     await app.dispatch("foo.bar.baz") |     await app.dispatch("foo.bar.baz") | ||||||
|     assert counter == 2 |     assert counter == 2 | ||||||
|  |  | ||||||
| @@ -331,7 +332,8 @@ def test_event_on_bp_not_registered(): | |||||||
|     "event,expected", |     "event,expected", | ||||||
|     ( |     ( | ||||||
|         ("foo.bar.baz", True), |         ("foo.bar.baz", True), | ||||||
|         ("server.init.before", False), |         ("server.init.before", True), | ||||||
|  |         ("server.init.somethingelse", False), | ||||||
|         ("http.request.start", False), |         ("http.request.start", False), | ||||||
|         ("sanic.notice.anything", True), |         ("sanic.notice.anything", True), | ||||||
|     ), |     ), | ||||||
|   | |||||||
| @@ -461,6 +461,22 @@ def test_nested_dir(app, static_file_directory): | |||||||
|     assert response.text == "foo\n" |     assert response.text == "foo\n" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_handle_is_a_directory_error(app, static_file_directory): | ||||||
|  |     error_text = "Is a directory. Access denied" | ||||||
|  |     app.static("/static", static_file_directory) | ||||||
|  |  | ||||||
|  |     @app.exception(Exception) | ||||||
|  |     async def handleStaticDirError(request, exception): | ||||||
|  |         if isinstance(exception, IsADirectoryError): | ||||||
|  |             return text(error_text, status=403) | ||||||
|  |         raise exception | ||||||
|  |  | ||||||
|  |     request, response = app.test_client.get("/static/") | ||||||
|  |  | ||||||
|  |     assert response.status == 403 | ||||||
|  |     assert response.text == error_text | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_stack_trace_on_not_found(app, static_file_directory, caplog): | def test_stack_trace_on_not_found(app, static_file_directory, caplog): | ||||||
|     app.static("/static", static_file_directory) |     app.static("/static", static_file_directory) | ||||||
|  |  | ||||||
| @@ -471,7 +487,7 @@ def test_stack_trace_on_not_found(app, static_file_directory, caplog): | |||||||
|  |  | ||||||
|     assert response.status == 404 |     assert response.status == 404 | ||||||
|     assert counter[logging.INFO] == 5 |     assert counter[logging.INFO] == 5 | ||||||
|     assert counter[logging.ERROR] == 1 |     assert counter[logging.ERROR] == 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_no_stack_trace_on_not_found(app, static_file_directory, caplog): | def test_no_stack_trace_on_not_found(app, static_file_directory, caplog): | ||||||
| @@ -507,3 +523,56 @@ def test_multiple_statics(app, static_file_directory): | |||||||
|     assert response.body == get_file_content( |     assert response.body == get_file_content( | ||||||
|         static_file_directory, "python.png" |         static_file_directory, "python.png" | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_resource_type_default(app, static_file_directory): | ||||||
|  |     app.static("/static", static_file_directory) | ||||||
|  |     app.static("/file", get_file_path(static_file_directory, "test.file")) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/static") | ||||||
|  |     assert response.status == 404 | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/file") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.body == get_file_content( | ||||||
|  |         static_file_directory, "test.file" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_resource_type_file(app, static_file_directory): | ||||||
|  |     app.static( | ||||||
|  |         "/file", | ||||||
|  |         get_file_path(static_file_directory, "test.file"), | ||||||
|  |         resource_type="file", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/file") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.body == get_file_content( | ||||||
|  |         static_file_directory, "test.file" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeError): | ||||||
|  |         app.static("/static", static_file_directory, resource_type="file") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_resource_type_dir(app, static_file_directory): | ||||||
|  |     app.static("/static", static_file_directory, resource_type="dir") | ||||||
|  |  | ||||||
|  |     _, response = app.test_client.get("/static/test.file") | ||||||
|  |     assert response.status == 200 | ||||||
|  |     assert response.body == get_file_content( | ||||||
|  |         static_file_directory, "test.file" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeError): | ||||||
|  |         app.static( | ||||||
|  |             "/file", | ||||||
|  |             get_file_path(static_file_directory, "test.file"), | ||||||
|  |             resource_type="dir", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_resource_type_unknown(app, static_file_directory, caplog): | ||||||
|  |     with pytest.raises(ValueError): | ||||||
|  |         app.static("/static", static_file_directory, resource_type="unknown") | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								tests/test_touchup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								tests/test_touchup.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | import logging | ||||||
|  |  | ||||||
|  | from sanic.signals import RESERVED_NAMESPACES | ||||||
|  | from sanic.touchup import TouchUp | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_touchup_methods(app): | ||||||
|  |     assert len(TouchUp._registry) == 9 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_ode_removes_dispatch_events(app, caplog): | ||||||
|  |     with caplog.at_level(logging.DEBUG, logger="sanic.root"): | ||||||
|  |         await app._startup() | ||||||
|  |     logs = caplog.record_tuples | ||||||
|  |  | ||||||
|  |     for signal in RESERVED_NAMESPACES["http"]: | ||||||
|  |         assert ( | ||||||
|  |             "sanic.root", | ||||||
|  |             logging.DEBUG, | ||||||
|  |             f"Disabling event: {signal}", | ||||||
|  |         ) in logs | ||||||
| @@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_websocket_bp_route_name(app): | @pytest.mark.parametrize( | ||||||
|  |     "name,expected", | ||||||
|  |     ( | ||||||
|  |         ("test_route", "/bp/route"), | ||||||
|  |         ("test_route2", "/bp/route2"), | ||||||
|  |         ("foobar_3", "/bp/route3"), | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | def test_websocket_bp_route_name(app, name, expected): | ||||||
|     """Tests that blueprint websocket route is named.""" |     """Tests that blueprint websocket route is named.""" | ||||||
|     event = asyncio.Event() |     event = asyncio.Event() | ||||||
|     bp = Blueprint("test_bp", url_prefix="/bp") |     bp = Blueprint("test_bp", url_prefix="/bp") | ||||||
| @@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app): | |||||||
|     uri = app.url_for("test_bp.main") |     uri = app.url_for("test_bp.main") | ||||||
|     assert uri == "/bp/main" |     assert uri == "/bp/main" | ||||||
|  |  | ||||||
|     uri = app.url_for("test_bp.test_route") |     uri = app.url_for(f"test_bp.{name}") | ||||||
|     assert uri == "/bp/route" |     assert uri == expected | ||||||
|     request, response = SanicTestClient(app).websocket(uri) |     request, response = SanicTestClient(app).websocket(uri) | ||||||
|     assert response.opened is True |     assert response.opened is True | ||||||
|     assert event.is_set() |     assert event.is_set() | ||||||
|  |  | ||||||
|     event.clear() |  | ||||||
|     uri = app.url_for("test_bp.test_route2") |  | ||||||
|     assert uri == "/bp/route2" |  | ||||||
|     request, response = SanicTestClient(app).websocket(uri) |  | ||||||
|     assert response.opened is True |  | ||||||
|     assert event.is_set() |  | ||||||
|  |  | ||||||
|     uri = app.url_for("test_bp.foobar_3") |  | ||||||
|     assert uri == "/bp/route3" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # TODO: add test with a route with multiple hosts | # TODO: add test with a route with multiple hosts | ||||||
| # TODO: add test with a route with _host in url_for | # TODO: add test with a route with _host in url_for | ||||||
|   | |||||||
| @@ -175,7 +175,7 @@ def test_worker_close(worker): | |||||||
|     worker.wsgi = mock.Mock() |     worker.wsgi = mock.Mock() | ||||||
|     conn = mock.Mock() |     conn = mock.Mock() | ||||||
|     conn.websocket = mock.Mock() |     conn.websocket = mock.Mock() | ||||||
|     conn.websocket.close_connection = mock.Mock(wraps=_a_noop) |     conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) | ||||||
|     worker.connections = set([conn]) |     worker.connections = set([conn]) | ||||||
|     worker.log = mock.Mock() |     worker.log = mock.Mock() | ||||||
|     worker.loop = loop |     worker.loop = loop | ||||||
| @@ -190,5 +190,5 @@ def test_worker_close(worker): | |||||||
|     loop.run_until_complete(_close) |     loop.run_until_complete(_close) | ||||||
|  |  | ||||||
|     assert worker.signal.stopped |     assert worker.signal.stopped | ||||||
|     assert conn.websocket.close_connection.called |     assert conn.websocket.fail_connection.called | ||||||
|     assert len(worker.servers) == 0 |     assert len(worker.servers) == 0 | ||||||
|   | |||||||
							
								
								
									
										55
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										55
									
								
								tox.ini
									
									
									
									
									
								
							| @@ -2,53 +2,28 @@ | |||||||
| envlist = py37, py38, py39, pyNightly, pypy37, {py37,py38,py39,pyNightly,pypy37}-no-ext, lint, check, security, docs, type-checking | envlist = py37, py38, py39, pyNightly, pypy37, {py37,py38,py39,pyNightly,pypy37}-no-ext, lint, check, security, docs, type-checking | ||||||
|  |  | ||||||
| [testenv] | [testenv] | ||||||
| usedevelop = True | usedevelop = true | ||||||
| setenv = | setenv = | ||||||
|     {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UJSON=1 |     {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UJSON=1 | ||||||
|     {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 |     {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 | ||||||
| deps = | extras = test | ||||||
|     sanic-testing>=0.6.0 |  | ||||||
|     coverage==5.3 |  | ||||||
|     pytest==5.2.1 |  | ||||||
|     pytest-cov |  | ||||||
|     pytest-sanic |  | ||||||
|     pytest-sugar |  | ||||||
|     pytest-benchmark |  | ||||||
|     chardet==3.* |  | ||||||
|     beautifulsoup4 |  | ||||||
|     gunicorn==20.0.4 |  | ||||||
|     uvicorn |  | ||||||
|     websockets>=9.0 |  | ||||||
| commands = | commands = | ||||||
|     pytest {posargs:tests --cov sanic} |     pytest {posargs:tests --cov sanic} | ||||||
|     - coverage combine --append |     - coverage combine --append | ||||||
|     coverage report -m |     coverage report -m -i | ||||||
|     coverage html -i |     coverage html -i | ||||||
|  |  | ||||||
| [testenv:lint] | [testenv:lint] | ||||||
| deps = |  | ||||||
|     flake8 |  | ||||||
|     black |  | ||||||
|     isort>=5.0.0 |  | ||||||
|     bandit |  | ||||||
|  |  | ||||||
| commands = | commands = | ||||||
|     flake8 sanic |     flake8 sanic | ||||||
|     black --config ./.black.toml --check --verbose sanic/ |     black --config ./.black.toml --check --verbose sanic/ | ||||||
|     isort --check-only sanic --profile=black |     isort --check-only sanic --profile=black | ||||||
|  |  | ||||||
| [testenv:type-checking] | [testenv:type-checking] | ||||||
| deps = |  | ||||||
|     mypy>=0.901 |  | ||||||
|     types-ujson |  | ||||||
|  |  | ||||||
| commands = | commands = | ||||||
|     mypy sanic |     mypy sanic | ||||||
|  |  | ||||||
| [testenv:check] | [testenv:check] | ||||||
| deps = |  | ||||||
|     docutils |  | ||||||
|     pygments |  | ||||||
| commands = | commands = | ||||||
|     python setup.py check -r -s |     python setup.py check -r -s | ||||||
|  |  | ||||||
| @@ -60,8 +35,6 @@ markers = | |||||||
|     asyncio |     asyncio | ||||||
|  |  | ||||||
| [testenv:security] | [testenv:security] | ||||||
| deps = |  | ||||||
|     bandit |  | ||||||
|  |  | ||||||
| commands = | commands = | ||||||
|     bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py |     bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py | ||||||
| @@ -69,30 +42,10 @@ commands = | |||||||
| [testenv:docs] | [testenv:docs] | ||||||
| platform = linux|linux2|darwin | platform = linux|linux2|darwin | ||||||
| whitelist_externals = make | whitelist_externals = make | ||||||
| deps = | extras = docs | ||||||
|     sphinx>=2.1.2 |  | ||||||
|     sphinx_rtd_theme>=0.4.3 |  | ||||||
|     recommonmark>=0.5.0 |  | ||||||
|     docutils |  | ||||||
|     pygments |  | ||||||
|     gunicorn==20.0.4 |  | ||||||
| commands = | commands = | ||||||
|     make docs-test |     make docs-test | ||||||
|  |  | ||||||
| [testenv:coverage] | [testenv:coverage] | ||||||
| usedevelop = True |  | ||||||
| deps = |  | ||||||
|     sanic-testing>=0.6.0 |  | ||||||
|     coverage==5.3 |  | ||||||
|     pytest==5.2.1 |  | ||||||
|     pytest-cov |  | ||||||
|     pytest-sanic |  | ||||||
|     pytest-sugar |  | ||||||
|     pytest-benchmark |  | ||||||
|     chardet==3.* |  | ||||||
|     beautifulsoup4 |  | ||||||
|     gunicorn==20.0.4 |  | ||||||
|     uvicorn |  | ||||||
|     websockets>=9.0 |  | ||||||
| commands = | commands = | ||||||
|     pytest tests --cov=./sanic --cov-report=xml |     pytest tests --cov=./sanic --cov-report=xml | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user