From 5308fec3543abd91541ab28624f7061888c9ff03 Mon Sep 17 00:00:00 2001 From: Can Sarigol <56863826+cansarigol3megawatt@users.noreply.github.com> Date: Mon, 2 Aug 2021 18:12:12 +0200 Subject: [PATCH 1/2] Fixed for handling exceptions of asgi app call. (#2211) @cansarigol3megawatt Thanks for looking into this and getting the quick turnaround on this. I will :cherries: pick this into the 21.6 branch and get it out a little later tonight. --- sanic/__version__.py | 2 +- sanic/asgi.py | 5 ++++- tests/test_asgi.py | 31 ++++++++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/sanic/__version__.py b/sanic/__version__.py index 79344950..74a495e2 100644 --- a/sanic/__version__.py +++ b/sanic/__version__.py @@ -1 +1 @@ -__version__ = "21.6.1" +__version__ = "21.6.2" diff --git a/sanic/asgi.py b/sanic/asgi.py index 5765a5cd..330ced5a 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -207,4 +207,7 @@ class ASGIApp: """ Handle the incoming request. """ - await self.sanic_app.handle_request(self.request) + try: + await self.sanic_app.handle_request(self.request) + except Exception as e: + await self.sanic_app.handle_exception(self.request, e) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 5be3fd26..c707c12a 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -7,7 +7,7 @@ import uvicorn from sanic import Sanic from sanic.asgi import MockTransport -from sanic.exceptions import InvalidUsage +from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.request import Request from sanic.response import json, text from sanic.websocket import WebSocketConnection @@ -346,3 +346,32 @@ async def test_content_type(app): _, response = await app.asgi_client.get("/custom") assert response.headers.get("content-type") == "somethingelse" + + +@pytest.mark.asyncio +async def test_request_handle_exception(app): + @app.get("/error-prone") + def _request(request): + raise ServiceUnavailable(message="Service unavailable") + + _, response = await app.asgi_client.get("/wrong-path") + assert response.status_code == 404 + + _, response = await app.asgi_client.get("/error-prone") + assert response.status_code == 503 + +@pytest.mark.asyncio +async def test_request_exception_suppressed_by_middleware(app): + @app.get("/error-prone") + def _request(request): + raise ServiceUnavailable(message="Service unavailable") + + @app.on_request + def forbidden(request): + raise Forbidden(message="forbidden") + + _, response = await app.asgi_client.get("/wrong-path") + assert response.status_code == 403 + + _, response = await app.asgi_client.get("/error-prone") + assert response.status_code == 403 \ No newline at end of file From bc08383acd914c380ba9ac5112c9e61284e31ab9 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sat, 2 Oct 2021 21:55:23 +0300 Subject: [PATCH 2/2] Merge in main to current-release (#2254) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove unnecessary import in test_constants.py, which also fixes an error on win (#2180) Co-authored-by: Adam Hopkins * Manually reset the buffer when streaming request body (#2183) * Remove Duplicated Dependencies and PEP 517 Support (#2173) * Remove duplicated dependencies * Specify setuptools as the tool for generating distribution (PEP 517) * Add `isort` to `dev_require` * manage all dependencies in setup.py * Execute `make pretty` * Set usedevelop to true (revert previous change) * Fix the handling of the end of a chunked request. (#2188) * Fix the handling of the end of a chunked request. * Avoid hardcoding final chunk header size. * Add some unit tests for pipeline body reading * Decode bytes for json serialization Co-authored-by: L. Kärkkäinen Co-authored-by: Adam Hopkins * Resolve regressions in exceptions (#2181) * Update sanic-routing to fix path issues plus lookahead / lookbehind support (#2178) * Update sanic-routing to fix path issues plus lookahead / lookbehind support * Update setup.py Co-authored-by: Adam Hopkins Co-authored-by: Adam Hopkins * style(app,blueprints): add some type hints (#2196) * style(app,blueprints): add some type hints * style(app): option is Any * style(blueprints): url prefix default value is ``""`` * style(app): backward compatible * style(app): backward compatible * style(blueprints): defult is None * style(app): apply code style (black) * Update some CC config (#2199) * Update README.rst * raise exception for `_static_request_handler` unknown exception; add test with custom error (#2195) Co-authored-by: n.feofanov Co-authored-by: Adam Hopkins * Change dumps to AnyStr (#2193) * HTTP tests (#2194) * Fix issues with after request handling in HTTP pipelining (#2201) * Clean up after a request is complete, before the next pipelined request. * Limit the size of request body consumed after handler has finished. * Linter error. * Add unit test re: bad headers Co-authored-by: L. Kärkkäinen Co-authored-by: Adam Hopkins Co-authored-by: Adam Hopkins * Update CHANGELOG * Log remote address if available (#2207) * Log remote address if available * Add tests * Fix testing version Co-authored-by: Adam Hopkins * Fixed for handling exceptions of asgi app call. (#2211) @cansarigol3megawatt Thanks for looking into this and getting the quick turnaround on this. I will :cherries: pick this into the 21.6 branch and get it out a little later tonight. * Signals Integration (#2160) * Update some tests * Resolve #2122 route decorator returning tuple * Use rc sanic-routing version * Update unit tests to <:str> * Minimal working version with some signals implemented * Add more http signals * Update ASGI and change listeners to signals * Allow for dynamic ODE signals * Allow signals to be stacked * Begin tests * Prioritize match_info on keyword argument injection * WIP on tests * Compat with signals * Work through some test coverage * Passing tests * Post linting * Setup proper resets * coverage reporting * Fixes from vltr comments * clear delayed tasks * Fix bad test * rm pycache * uncomment windows tests (#2214) * Add convenience methods to BP groups (#2209) * Fix bug where ws exceptions not being logged (#2213) * Fix bug where ws exceptions not being logged * Fix t\est * Style: add type hints (#2217) * style(routes): add_route argument, return typing * style(listeners): typing * style(views): typing as_view * style(routes): change type hint * style(listeners): change type hint * style(routes): change type hint * add some more types * Change as_view typing * Add some cleaner type annotations Co-authored-by: Adam Hopkins * Add default messages to SanicExceptions (#2216) * Add default messages to SanicExceptions * Cleaner exception message setting * Copy Blueprints Implementation (#2184) * Accept header parsing (#2200) * Add some tests * docstring * Add accept matching * Add some more tests on matching * Add matching flags for wildcards * Add mathing controls to accept * Limit uvicorn 14 in testing * Add convenience for annotated handlers (#2225) * Split HttpProtocol parts into base SanicProtocol and HTTPProtocol subclass (#2229) * Split HttpProtocol parts into base SanicProtocol and HTTPProtocol subclass. * lint fixes * re-black server.py * Move server.py into its own module (#2230) * Move server.py into its own module * Change monkeypatch path on test_logging.py * Blueprint specific exception handlers (#2208) * Call abort() on sockets after close() to prevent dangling sockets (#2231) * Add ability to return Falsey but not-None from handlers (#2236) * Adds Blueprint Group exception decorator (#2238) * Add exception decorator * Added tests * Fix line too long * Static DIR and FILE resource types (#2244) * Explicit static directive for serving file or dir Co-authored-by: anbuhckr <36891836+anbuhckr@users.noreply.github.com> Co-authored-by: anbuhckr * Close HTTP loop when connection task cancelled (#2245) * Terminate loop when no transport exists * Add log when closing HTTP loop because of shutdown * Add unit test * New websockets (#2158) * First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work. * Update sanic/websocket.py Co-authored-by: Adam Hopkins * Update sanic/websocket.py Co-authored-by: Adam Hopkins * Update sanic/websocket.py Co-authored-by: Adam Hopkins * wip, update websockets code to new Sans/IO API * Refactored new websockets impl into own modules Incorporated other suggestions made by team * Another round of work on the new websockets impl * Added websocket_timeout support (matching previous/legacy support) * Lots more comments * Incorporated suggested changes from previous round of review * Changed RuntimeError usage to ServerError * Changed SanicException usage to ServerError * Removed some redundant asserts * Change remaining asserts to ServerErrors * Fixed some timeout handling issues * Fixed websocket.close() handling, and made it more robust * Made auto_close task smarter and more error-resilient * Made fail_connection routine smarter and more error-resilient * Further new websockets impl fixes * Update compatibility with Websockets v10 * Track server connection state in a more precise way * Try to handle the shutdown process more gracefully * Add a new end_connection() helper, to use as an alterative to close() or fail_connection() * Kill the auto-close task and keepalive-timeout task when sanic is shutdown * Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation. * Change a warning message to debug level Remove default values for deprecated websocket parameters * Fix flake8 errors * Fix a couple of missed failing tests * remove websocket bench from examples * Integrate suggestions from code reviews Use Optional[T] instead of union[T,None] Fix mypy type logic errors change "is not None" to truthy checks where appropriate change "is None" to falsy checks were appropriate Add more debug logging when debug mode is on Change to using sanic.logger for debug logging rather than error_logger. * Fix long line lengths of debug messages Add some new debug messages when websocket IO is paused and unpaused for flow control Fix websocket example to use app.static() * remove unused import in websocket example app * re-run isort after Flake8 fixes Co-authored-by: Adam Hopkins Co-authored-by: Adam Hopkins * Account for BP with exception handler but no routes (#2246) * Don't log "enabled" if auto-reload disabled (#2247) Fixes #2240 Co-authored-by: Adam Hopkins * Smarter auto fallback (#2162) * Smarter auto fallback * remove config from blueprints * Add tests for error formatting * Add check for proper format * Fix some tests * Add some tests * docstring * Add accept matching * Add some more tests on matching * Fix contains bug, earlier return on MediaType eq * Add matching flags for wildcards * Add mathing controls to accept * Cleanup dev cruft * Add cleanup and resolve OSError relating to test implementation * Fix test * Fix some typos * Some fixes to the new Websockets impl (#2248) * First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work. * Update sanic/websocket.py Co-authored-by: Adam Hopkins * Update sanic/websocket.py Co-authored-by: Adam Hopkins * Update sanic/websocket.py Co-authored-by: Adam Hopkins * wip, update websockets code to new Sans/IO API * Refactored new websockets impl into own modules Incorporated other suggestions made by team * Another round of work on the new websockets impl * Added websocket_timeout support (matching previous/legacy support) * Lots more comments * Incorporated suggested changes from previous round of review * Changed RuntimeError usage to ServerError * Changed SanicException usage to ServerError * Removed some redundant asserts * Change remaining asserts to ServerErrors * Fixed some timeout handling issues * Fixed websocket.close() handling, and made it more robust * Made auto_close task smarter and more error-resilient * Made fail_connection routine smarter and more error-resilient * Further new websockets impl fixes * Update compatibility with Websockets v10 * Track server connection state in a more precise way * Try to handle the shutdown process more gracefully * Add a new end_connection() helper, to use as an alterative to close() or fail_connection() * Kill the auto-close task and keepalive-timeout task when sanic is shutdown * Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation. * Change a warning message to debug level Remove default values for deprecated websocket parameters * Fix flake8 errors * Fix a couple of missed failing tests * remove websocket bench from examples * Integrate suggestions from code reviews Use Optional[T] instead of union[T,None] Fix mypy type logic errors change "is not None" to truthy checks where appropriate change "is None" to falsy checks were appropriate Add more debug logging when debug mode is on Change to using sanic.logger for debug logging rather than error_logger. * Fix long line lengths of debug messages Add some new debug messages when websocket IO is paused and unpaused for flow control Fix websocket example to use app.static() * remove unused import in websocket example app * re-run isort after Flake8 fixes * Some fixes to the new Websockets impl Will throw WebsocketClosed exception instead of ServerException now when attempting to read or write to closed websocket, this makes it easier to catch The various ws.recv() methods now have the ability to raise CancelledError into your websocket handler Fix a niche close-socket negotiation bug Fix bug where http protocol thought the websocket never sent any response. Allow data to still send in some cases after websocket enters CLOSING state. Fix some badly formatted and badly placed comments * allow eof_received to send back data too, if the connection is in CLOSING state Co-authored-by: Adam Hopkins Co-authored-by: Adam Hopkins * 21.9 release docs (#2218) * Beging 21.9 release docs * Add PRs to changelog * Change deprecation version * Update logging tests * Bump version * Update changelog * Change dev install command (#2251) Co-authored-by: Zhiwei Co-authored-by: L. Kärkkäinen <98187+Tronic@users.noreply.github.com> Co-authored-by: L. Kärkkäinen Co-authored-by: Robert Palmer Co-authored-by: Ryu JuHeon Co-authored-by: gluhar2006 <49654448+gluhar2006@users.noreply.github.com> Co-authored-by: n.feofanov Co-authored-by: Néstor Pérez <25409753+prryplatypus@users.noreply.github.com> Co-authored-by: Can Sarigol <56863826+cansarigol3megawatt@users.noreply.github.com> Co-authored-by: Zhiwei Co-authored-by: YongChan Cho Co-authored-by: Zhiwei Co-authored-by: Ashley Sommer Co-authored-by: anbuhckr <36891836+anbuhckr@users.noreply.github.com> Co-authored-by: anbuhckr --- .codeclimate.yml | 12 + .github/workflows/pr-windows.yml | 62 +- CHANGELOG.rst | 19 + CONTRIBUTING.rst | 2 +- README.rst | 12 +- docs/conf.py | 20 +- docs/sanic/changelog.rst | 2 + docs/sanic/contributing.rst | 2 +- docs/sanic/releases/21.9.md | 40 + examples/run_async_advanced.py | 29 +- examples/websocket.py | 9 +- pyproject.toml | 3 + sanic/__version__.py | 2 +- sanic/app.py | 315 +++++-- sanic/asgi.py | 41 +- sanic/base.py | 2 +- sanic/blueprint_group.py | 34 + sanic/blueprints.py | 111 ++- sanic/config.py | 16 +- sanic/errorpages.py | 123 ++- sanic/exceptions.py | 24 +- sanic/handlers.py | 82 +- sanic/headers.py | 200 +++++ sanic/helpers.py | 14 + sanic/http.py | 47 +- sanic/mixins/listeners.py | 27 +- sanic/mixins/routes.py | 139 +++- sanic/mixins/signals.py | 4 +- sanic/models/asgi.py | 2 +- sanic/models/futures.py | 2 + sanic/models/handler_types.py | 2 +- sanic/models/server_types.py | 52 ++ sanic/request.py | 15 + sanic/router.py | 34 + sanic/server.py | 793 ------------------ sanic/server/__init__.py | 26 + sanic/server/async_server.py | 115 +++ sanic/server/events.py | 16 + sanic/server/protocols/__init__.py | 0 sanic/server/protocols/base_protocol.py | 143 ++++ sanic/server/protocols/http_protocol.py | 238 ++++++ sanic/server/protocols/websocket_protocol.py | 164 ++++ sanic/server/runners.py | 280 +++++++ sanic/server/socket.py | 87 ++ sanic/server/websockets/__init__.py | 0 sanic/server/websockets/connection.py | 82 ++ sanic/server/websockets/frame.py | 294 +++++++ sanic/server/websockets/impl.py | 834 +++++++++++++++++++ sanic/signals.py | 97 ++- sanic/touchup/__init__.py | 8 + sanic/touchup/meta.py | 22 + sanic/touchup/schemes/__init__.py | 5 + sanic/touchup/schemes/base.py | 20 + sanic/touchup/schemes/ode.py | 67 ++ sanic/touchup/service.py | 33 + sanic/views.py | 5 +- sanic/websocket.py | 205 ----- sanic/worker.py | 42 +- setup.py | 33 +- tests/conftest.py | 57 +- tests/test_app.py | 10 +- tests/test_asgi.py | 5 +- tests/test_bad_request.py | 2 +- tests/test_blueprint_copy.py | 70 ++ tests/test_blueprint_group.py | 59 +- tests/test_blueprints.py | 55 +- tests/test_cli.py | 6 +- tests/test_config.py | 2 +- tests/test_constants.py | 6 +- tests/test_create_task.py | 21 +- tests/test_errorpages.py | 201 ++++- tests/test_exceptions.py | 40 + tests/test_exceptions_handler.py | 191 +++-- tests/test_graceful_shutdown.py | 46 + tests/test_handler_annotations.py | 39 + tests/test_headers.py | 187 ++++- tests/test_http.py | 137 +++ tests/test_keep_alive_timeout.py | 230 +---- tests/test_logging.py | 63 +- tests/test_logo.py | 80 +- tests/test_middleware.py | 32 +- tests/test_request.py | 36 + tests/test_request_timeout.py | 56 +- tests/test_routes.py | 47 +- tests/test_server_events.py | 31 +- tests/test_signal_handlers.py | 2 +- tests/test_signals.py | 4 +- tests/test_static.py | 69 ++ tests/test_touchup.py | 21 + tests/test_url_for.py | 24 +- tests/test_worker.py | 4 +- tox.ini | 55 +- 92 files changed, 5109 insertions(+), 1888 deletions(-) create mode 100644 docs/sanic/releases/21.9.md create mode 100644 pyproject.toml create mode 100644 sanic/models/server_types.py delete mode 100644 sanic/server.py create mode 100644 sanic/server/__init__.py create mode 100644 sanic/server/async_server.py create mode 100644 sanic/server/events.py create mode 100644 sanic/server/protocols/__init__.py create mode 100644 sanic/server/protocols/base_protocol.py create mode 100644 sanic/server/protocols/http_protocol.py create mode 100644 sanic/server/protocols/websocket_protocol.py create mode 100644 sanic/server/runners.py create mode 100644 sanic/server/socket.py create mode 100644 sanic/server/websockets/__init__.py create mode 100644 sanic/server/websockets/connection.py create mode 100644 sanic/server/websockets/frame.py create mode 100644 sanic/server/websockets/impl.py create mode 100644 sanic/touchup/__init__.py create mode 100644 sanic/touchup/meta.py create mode 100644 sanic/touchup/schemes/__init__.py create mode 100644 sanic/touchup/schemes/base.py create mode 100644 sanic/touchup/schemes/ode.py create mode 100644 sanic/touchup/service.py delete mode 100644 sanic/websocket.py create mode 100644 tests/test_blueprint_copy.py create mode 100644 tests/test_graceful_shutdown.py create mode 100644 tests/test_handler_annotations.py create mode 100644 tests/test_http.py create mode 100644 tests/test_touchup.py diff --git a/.codeclimate.yml b/.codeclimate.yml index 506005ac..08c10e26 100644 --- a/.codeclimate.yml +++ b/.codeclimate.yml @@ -10,3 +10,15 @@ exclude_patterns: - "examples/" - "hack/" - "scripts/" + - "tests/" +checks: + argument-count: + enabled: false + file-lines: + config: + threshold: 1000 + method-count: + config: + threshold: 40 + complex-logic: + enabled: false diff --git a/.github/workflows/pr-windows.yml b/.github/workflows/pr-windows.yml index a47a5006..e3a32e5d 100644 --- a/.github/workflows/pr-windows.yml +++ b/.github/workflows/pr-windows.yml @@ -1,34 +1,34 @@ -# name: Run Unit Tests on Windows -# on: -# pull_request: -# branches: -# - main +name: Run Unit Tests on Windows +on: + pull_request: + branches: + - main -# jobs: -# testsOnWindows: -# name: ut-${{ matrix.config.tox-env }} -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# config: -# - { python-version: 3.7, tox-env: py37-no-ext } -# - { python-version: 3.8, tox-env: py38-no-ext } -# - { python-version: 3.9, tox-env: py39-no-ext } -# - { python-version: pypy-3.7, tox-env: pypy37-no-ext } +jobs: + testsOnWindows: + name: ut-${{ matrix.config.tox-env }} + runs-on: windows-latest + strategy: + fail-fast: false + matrix: + config: + - { python-version: 3.7, tox-env: py37-no-ext } + - { python-version: 3.8, tox-env: py38-no-ext } + - { python-version: 3.9, tox-env: py39-no-ext } + - { python-version: pypy-3.7, tox-env: pypy37-no-ext } -# steps: -# - name: Checkout Repository -# uses: actions/checkout@v2 + steps: + - name: Checkout Repository + uses: actions/checkout@v2 -# - name: Run Unit Tests -# uses: ahopkins/custom-actions@pip-extra-args -# with: -# python-version: ${{ matrix.config.python-version }} -# test-infra-tool: tox -# test-infra-version: latest -# action: tests -# test-additional-args: "-e=${{ matrix.config.tox-env }}" -# experimental-ignore-error: "true" -# command-timeout: "600000" -# pip-extra-args: "--user" + - name: Run Unit Tests + uses: ahopkins/custom-actions@pip-extra-args + with: + python-version: ${{ matrix.config.python-version }} + test-infra-tool: tox + test-infra-version: latest + action: tests + test-additional-args: "-e=${{ matrix.config.tox-env }}" + experimental-ignore-error: "true" + command-timeout: "600000" + pip-extra-args: "--user" diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5e99452b..a9940da1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,22 @@ +.. note:: + + From v21.9, CHANGELOG files are maintained in ``./docs/sanic/releases`` + +Version 21.6.1 +-------------- + +Bugfixes +******** + + * `#2178 `_ + Update sanic-routing to allow for better splitting of complex URI templates + * `#2183 `_ + Proper handling of chunked request bodies to resolve phantom 503 in logs + * `#2181 `_ + Resolve regression in exception logging + * `#2201 `_ + Cleanup request info in pipelined requests + Version 21.6.0 -------------- diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index c87f2355..74dee22f 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -19,7 +19,7 @@ a virtual environment already set up, then run: .. code-block:: bash - pip3 install -e . ".[dev]" + pip install -e ".[dev]" Dependency Changes ------------------ diff --git a/README.rst b/README.rst index c3623bb8..c6616f16 100644 --- a/README.rst +++ b/README.rst @@ -77,17 +77,7 @@ The goal of the project is to provide a simple way to get up and running a highl Sponsor ------- -|Try CodeStream| - -.. |Try CodeStream| image:: https://alt-images.codestream.com/codestream_logo_sanicorg.png - :target: https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner - :alt: Try CodeStream - -Manage pull requests and conduct code reviews in your IDE with full source-tree context. Comment on any line, not just the diffs. Use jump-to-definition, your favorite keybindings, and code intelligence with more of your workflow. - -`Learn More `_ - -Thank you to our sponsor. Check out `open collective `_ to learn more about helping to fund Sanic. +Check out `open collective `_ to learn more about helping to fund Sanic. Installation ------------ diff --git a/docs/conf.py b/docs/conf.py index 62f6ae4a..30a01e4c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,10 +10,8 @@ import os import sys -# Add support for auto-doc -import recommonmark -from recommonmark.transform import AutoStructify +# Add support for auto-doc # Ensure that sanic is present in the path, to allow sphinx-apidoc to @@ -26,7 +24,7 @@ import sanic # -- General configuration ------------------------------------------------ -extensions = ["sphinx.ext.autodoc", "recommonmark"] +extensions = ["sphinx.ext.autodoc", "m2r2"] templates_path = ["_templates"] @@ -162,20 +160,6 @@ autodoc_default_options = { "member-order": "groupwise", } - -# app setup hook -def setup(app): - app.add_config_value( - "recommonmark_config", - { - "enable_eval_rst": True, - "enable_auto_doc_ref": False, - }, - True, - ) - app.add_transform(AutoStructify) - - html_theme_options = { "style_external_links": False, } diff --git a/docs/sanic/changelog.rst b/docs/sanic/changelog.rst index fb389e05..516b8587 100644 --- a/docs/sanic/changelog.rst +++ b/docs/sanic/changelog.rst @@ -1,4 +1,6 @@ 📜 Changelog ============ +.. mdinclude:: ./releases/21.9.md + .. include:: ../../CHANGELOG.rst diff --git a/docs/sanic/contributing.rst b/docs/sanic/contributing.rst index 91cfd11e..5d21caa2 100644 --- a/docs/sanic/contributing.rst +++ b/docs/sanic/contributing.rst @@ -1,4 +1,4 @@ ♥️ Contributing -============== +=============== .. include:: ../../CONTRIBUTING.rst diff --git a/docs/sanic/releases/21.9.md b/docs/sanic/releases/21.9.md new file mode 100644 index 00000000..8900340d --- /dev/null +++ b/docs/sanic/releases/21.9.md @@ -0,0 +1,40 @@ +## Version 21.9 + +### Features +- [#2158](https://github.com/sanic-org/sanic/pull/2158), [#2248](https://github.com/sanic-org/sanic/pull/2248) Complete overhaul of I/O to websockets +- [#2160](https://github.com/sanic-org/sanic/pull/2160) Add new 17 signals into server and request lifecycles +- [#2162](https://github.com/sanic-org/sanic/pull/2162) Smarter `auto` fallback formatting upon exception +- [#2184](https://github.com/sanic-org/sanic/pull/2184) Introduce implementation for copying a Blueprint +- [#2200](https://github.com/sanic-org/sanic/pull/2200) Accept header parsing +- [#2207](https://github.com/sanic-org/sanic/pull/2207) Log remote address if available +- [#2209](https://github.com/sanic-org/sanic/pull/2209) Add convenience methods to BP groups +- [#2216](https://github.com/sanic-org/sanic/pull/2216) Add default messages to SanicExceptions +- [#2225](https://github.com/sanic-org/sanic/pull/2225) Type annotation convenience for annotated handlers with path parameters +- [#2236](https://github.com/sanic-org/sanic/pull/2236) Allow Falsey (but not-None) responses from route handlers +- [#2238](https://github.com/sanic-org/sanic/pull/2238) Add `exception` decorator to Blueprint Groups +- [#2244](https://github.com/sanic-org/sanic/pull/2244) Explicit static directive for serving file or dir (ex: `static(..., resource_type="file")`) +- [#2245](https://github.com/sanic-org/sanic/pull/2245) Close HTTP loop when connection task cancelled + +### Bugfixes +- [#2188](https://github.com/sanic-org/sanic/pull/2188) Fix the handling of the end of a chunked request +- [#2195](https://github.com/sanic-org/sanic/pull/2195) Resolve unexpected error handling on static requests +- [#2208](https://github.com/sanic-org/sanic/pull/2208) Make blueprint-based exceptions attach and trigger in a more intuitive manner +- [#2211](https://github.com/sanic-org/sanic/pull/2211) Fixed for handling exceptions of asgi app call +- [#2213](https://github.com/sanic-org/sanic/pull/2213) Fix bug where ws exceptions not being logged +- [#2231](https://github.com/sanic-org/sanic/pull/2231) Cleaner closing of tasks by using `abort()` in strategic places to avoid dangling sockets +- [#2247](https://github.com/sanic-org/sanic/pull/2247) Fix logging of auto-reload status in debug mode +- [#2246](https://github.com/sanic-org/sanic/pull/2246) Account for BP with exception handler but no routes + +### Developer infrastructure +- [#2194](https://github.com/sanic-org/sanic/pull/2194) HTTP unit tests with raw client +- [#2199](https://github.com/sanic-org/sanic/pull/2199) Switch to codeclimate +- [#2214](https://github.com/sanic-org/sanic/pull/2214) Try Reopening Windows Tests +- [#2229](https://github.com/sanic-org/sanic/pull/2229) Refactor `HttpProtocol` into a base class +- [#2230](https://github.com/sanic-org/sanic/pull/2230) Refactor `server.py` into multi-file module + +### Miscellaneous +- [#2173](https://github.com/sanic-org/sanic/pull/2173) Remove Duplicated Dependencies and PEP 517 Support +- [#2193](https://github.com/sanic-org/sanic/pull/2193), [#2196](https://github.com/sanic-org/sanic/pull/2196), [#2217](https://github.com/sanic-org/sanic/pull/2217) Type annotation changes + + + diff --git a/examples/run_async_advanced.py b/examples/run_async_advanced.py index 36027c2f..27f86f3f 100644 --- a/examples/run_async_advanced.py +++ b/examples/run_async_advanced.py @@ -1,29 +1,44 @@ -from sanic import Sanic -from sanic import response -from signal import signal, SIGINT import asyncio + +from signal import SIGINT, signal + import uvloop +from sanic import Sanic, response +from sanic.server import AsyncioServer + + app = Sanic(__name__) -@app.listener('after_server_start') + +@app.listener("after_server_start") async def after_start_test(app, loop): print("Async Server Started!") + @app.route("/") async def test(request): return response.json({"answer": "42"}) + asyncio.set_event_loop(uvloop.new_event_loop()) -serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True) +serv_coro = app.create_server( + host="0.0.0.0", port=8000, return_asyncio_server=True +) loop = asyncio.get_event_loop() serv_task = asyncio.ensure_future(serv_coro, loop=loop) signal(SIGINT, lambda s, f: loop.stop()) -server = loop.run_until_complete(serv_task) +server: AsyncioServer = loop.run_until_complete(serv_task) # type: ignore +server.startup() + +# When using app.run(), this actually triggers before the serv_coro. +# But, in this example, we are using the convenience method, even if it is +# out of order. +server.before_start() server.after_start() try: loop.run_forever() -except KeyboardInterrupt as e: +except KeyboardInterrupt: loop.stop() finally: server.before_stop() diff --git a/examples/websocket.py b/examples/websocket.py index 9cba083c..92f71375 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -1,13 +1,14 @@ from sanic import Sanic -from sanic.response import file +from sanic.response import redirect app = Sanic(__name__) -@app.route('/') -async def index(request): - return await file('websocket.html') +app.static('index.html', "websocket.html") +@app.route('/') +def index(request): + return redirect("index.html") @app.websocket('/feed') async def feed(request, ws): diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..9787c3bd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/sanic/__version__.py b/sanic/__version__.py index 74a495e2..32566438 100644 --- a/sanic/__version__.py +++ b/sanic/__version__.py @@ -1 +1 @@ -__version__ = "21.6.2" +__version__ = "21.9.0" diff --git a/sanic/app.py b/sanic/app.py index ec9027b5..01aa07cb 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import logging import logging.config import os import re from asyncio import ( + AbstractEventLoop, CancelledError, Protocol, ensure_future, @@ -21,6 +24,7 @@ from traceback import format_exc from types import SimpleNamespace from typing import ( Any, + AnyStr, Awaitable, Callable, Coroutine, @@ -30,6 +34,7 @@ from typing import ( List, Optional, Set, + Tuple, Type, Union, ) @@ -69,20 +74,29 @@ from sanic.router import Router from sanic.server import AsyncioServer, HttpProtocol from sanic.server import Signal as ServerSignal from sanic.server import serve, serve_multiple, serve_single +from sanic.server.protocols.websocket_protocol import WebSocketProtocol +from sanic.server.websockets.impl import ConnectionClosed from sanic.signals import Signal, SignalRouter -from sanic.websocket import ConnectionClosed, WebSocketProtocol +from sanic.touchup import TouchUp, TouchUpMeta -class Sanic(BaseSanic): +class Sanic(BaseSanic, metaclass=TouchUpMeta): """ The main application instance """ + __touchup__ = ( + "handle_request", + "handle_exception", + "_run_response_middleware", + "_run_request_middleware", + ) __fake_slots__ = ( "_asgi_app", "_app_registry", "_asgi_client", "_blueprint_order", + "_delayed_tasks", "_future_routes", "_future_statics", "_future_middleware", @@ -137,7 +151,7 @@ class Sanic(BaseSanic): log_config: Optional[Dict[str, Any]] = None, configure_logging: bool = True, register: Optional[bool] = None, - dumps: Optional[Callable[..., str]] = None, + dumps: Optional[Callable[..., AnyStr]] = None, ) -> None: super().__init__(name=name) @@ -153,6 +167,7 @@ class Sanic(BaseSanic): self._asgi_client = None self._blueprint_order: List[Blueprint] = [] + self._delayed_tasks: List[str] = [] self._test_client = None self._test_manager = None self.asgi = False @@ -164,7 +179,9 @@ class Sanic(BaseSanic): self.configure_logging = configure_logging self.ctx = ctx or SimpleNamespace() self.debug = None - self.error_handler = error_handler or ErrorHandler() + self.error_handler = error_handler or ErrorHandler( + fallback=self.config.FALLBACK_ERROR_FORMAT, + ) self.is_running = False self.is_stopping = False self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) @@ -190,9 +207,10 @@ class Sanic(BaseSanic): self.__class__.register_app(self) self.router.ctx.app = self + self.signal_router.ctx.app = self if dumps: - BaseHTTPResponse._dumps = dumps + BaseHTTPResponse._dumps = dumps # type: ignore @property def loop(self): @@ -230,9 +248,12 @@ class Sanic(BaseSanic): loop = self.loop # Will raise SanicError if loop is not started self._loop_add_task(task, self, loop) except SanicException: - self.listener("before_server_start")( - partial(self._loop_add_task, task) - ) + task_name = f"sanic.delayed_task.{hash(task)}" + if not self._delayed_tasks: + self.after_server_start(partial(self.dispatch_delayed_tasks)) + + self.signal(task_name)(partial(self.run_delayed_task, task=task)) + self._delayed_tasks.append(task_name) def register_listener(self, listener: Callable, event: str) -> Any: """ @@ -244,12 +265,20 @@ class Sanic(BaseSanic): """ try: - _event = ListenerEvent(event) - except ValueError: - valid = ", ".join(ListenerEvent.__members__.values()) + _event = ListenerEvent[event.upper()] + except (ValueError, AttributeError): + valid = ", ".join( + map(lambda x: x.lower(), ListenerEvent.__members__.keys()) + ) raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") - self.listeners[_event].append(listener) + if "." in _event: + self.signal(_event.value)( + partial(self._listener, listener=listener) + ) + else: + self.listeners[_event.value].append(listener) + return listener def register_middleware(self, middleware, attach_to: str = "request"): @@ -308,7 +337,11 @@ class Sanic(BaseSanic): self.named_response_middleware[_rn].appendleft(middleware) return middleware - def _apply_exception_handler(self, handler: FutureException): + def _apply_exception_handler( + self, + handler: FutureException, + route_names: Optional[List[str]] = None, + ): """Decorate a function to be registered as a handler for exceptions :param exceptions: exceptions @@ -318,9 +351,9 @@ class Sanic(BaseSanic): for exception in handler.exceptions: if isinstance(exception, (tuple, list)): for e in exception: - self.error_handler.add(e, handler.handler) + self.error_handler.add(e, handler.handler, route_names) else: - self.error_handler.add(exception, handler.handler) + self.error_handler.add(exception, handler.handler, route_names) return handler.handler def _apply_listener(self, listener: FutureListener): @@ -377,11 +410,17 @@ class Sanic(BaseSanic): *, condition: Optional[Dict[str, str]] = None, context: Optional[Dict[str, Any]] = None, + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, ) -> Coroutine[Any, Any, Awaitable[Any]]: return self.signal_router.dispatch( event, context=context, condition=condition, + inline=inline, + reverse=reverse, + fail_not_found=fail_not_found, ) async def event( @@ -411,7 +450,13 @@ class Sanic(BaseSanic): self.websocket_enabled = enable - def blueprint(self, blueprint, **options): + def blueprint( + self, + blueprint: Union[ + Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup + ], + **options: Any, + ): """Register a blueprint on the application. :param blueprint: Blueprint object or (list, tuple) thereof @@ -651,7 +696,7 @@ class Sanic(BaseSanic): async def handle_exception( self, request: Request, exception: BaseException - ): + ): # no cov """ A handler that catches specific exceptions and outputs a response. @@ -661,6 +706,12 @@ class Sanic(BaseSanic): :type exception: BaseException :raises ServerError: response 500 """ + await self.dispatch( + "http.lifecycle.exception", + inline=True, + context={"request": request, "exception": exception}, + ) + # -------------------------------------------- # # Request Middleware # -------------------------------------------- # @@ -707,7 +758,7 @@ class Sanic(BaseSanic): f"Invalid response type {response!r} (need HTTPResponse)" ) - async def handle_request(self, request: Request): + async def handle_request(self, request: Request): # no cov """Take a request from the HTTP Server and return a response object to be sent back The HTTP Server only expects a response object, so exception handling must be done here @@ -715,10 +766,22 @@ class Sanic(BaseSanic): :param request: HTTP Request object :return: Nothing """ + await self.dispatch( + "http.lifecycle.handle", + inline=True, + context={"request": request}, + ) + # Define `response` var here to remove warnings about # allocation before assignment below. response = None try: + + await self.dispatch( + "http.routing.before", + inline=True, + context={"request": request}, + ) # Fetch handler from router route, handler, kwargs = self.router.get( request.path, @@ -726,19 +789,29 @@ class Sanic(BaseSanic): request.headers.getone("host", None), ) - request._match_info = kwargs + request._match_info = {**kwargs} request.route = route + await self.dispatch( + "http.routing.after", + inline=True, + context={ + "request": request, + "route": route, + "kwargs": kwargs, + "handler": handler, + }, + ) + if ( - request.stream.request_body # type: ignore + request.stream + and request.stream.request_body and not route.ctx.ignore_body ): if hasattr(handler, "is_stream"): # Streaming handler: lift the size limit - request.stream.request_max_size = float( # type: ignore - "inf" - ) + request.stream.request_max_size = float("inf") else: # Non-streaming handler: preload body await request.receive_body() @@ -765,17 +838,25 @@ class Sanic(BaseSanic): ) # Run response handler - response = handler(request, **kwargs) + response = handler(request, **request.match_info) if isawaitable(response): response = await response - if response: + if response is not None: response = await request.respond(response) elif not hasattr(handler, "is_websocket"): response = request.stream.response # type: ignore # Make sure that response is finished / run StreamingHTTP callback if isinstance(response, BaseHTTPResponse): + await self.dispatch( + "http.lifecycle.response", + inline=True, + context={ + "request": request, + "response": response, + }, + ) await response.send(end_stream=True) else: if not hasattr(handler, "is_websocket"): @@ -793,23 +874,11 @@ class Sanic(BaseSanic): async def _websocket_handler( self, handler, request, *args, subprotocols=None, **kwargs ): - request.app = self - if not getattr(handler, "__blueprintname__", False): - request._name = handler.__name__ - else: - request._name = ( - getattr(handler, "__blueprintname__", "") + handler.__name__ - ) - - pass - if self.asgi: ws = request.transport.get_websocket_connection() await ws.accept(subprotocols) else: protocol = request.transport.get_protocol() - protocol.app = self - ws = await protocol.websocket_handshake(request, subprotocols) # schedule the application handler @@ -817,13 +886,19 @@ class Sanic(BaseSanic): # needs to be cancelled due to the server being stopped fut = ensure_future(handler(request, ws, *args, **kwargs)) self.websocket_tasks.add(fut) + cancelled = False try: await fut + except Exception as e: + self.error_handler.log(request, e) except (CancelledError, ConnectionClosed): - pass + cancelled = True finally: self.websocket_tasks.remove(fut) - await ws.close() + if cancelled: + ws.end_connection(1000) + else: + await ws.close() # -------------------------------------------------------------------- # # Testing @@ -869,7 +944,7 @@ class Sanic(BaseSanic): *, debug: bool = False, auto_reload: Optional[bool] = None, - ssl: Union[dict, SSLContext, None] = None, + ssl: Union[Dict[str, str], SSLContext, None] = None, sock: Optional[socket] = None, workers: int = 1, protocol: Optional[Type[Protocol]] = None, @@ -999,7 +1074,7 @@ class Sanic(BaseSanic): port: Optional[int] = None, *, debug: bool = False, - ssl: Union[dict, SSLContext, None] = None, + ssl: Union[Dict[str, str], SSLContext, None] = None, sock: Optional[socket] = None, protocol: Type[Protocol] = None, backlog: int = 100, @@ -1071,11 +1146,6 @@ class Sanic(BaseSanic): run_async=return_asyncio_server, ) - # Trigger before_start events - await self.trigger_events( - server_settings.get("before_start", []), - server_settings.get("loop"), - ) main_start = server_settings.pop("main_start", None) main_stop = server_settings.pop("main_stop", None) if main_start or main_stop: @@ -1088,17 +1158,9 @@ class Sanic(BaseSanic): asyncio_server_kwargs=asyncio_server_kwargs, **server_settings ) - async def trigger_events(self, events, loop): - """Trigger events (functions or async) - :param events: one or more sync or async functions to execute - :param loop: event loop - """ - for event in events: - result = event(loop) - if isawaitable(result): - await result - - async def _run_request_middleware(self, request, request_name=None): + async def _run_request_middleware( + self, request, request_name=None + ): # no cov # The if improves speed. I don't know why named_middleware = self.named_request_middleware.get( request_name, deque() @@ -1111,25 +1173,67 @@ class Sanic(BaseSanic): request.request_middleware_started = True for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + response = middleware(request) if isawaitable(response): response = await response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + if response: return response return None async def _run_response_middleware( self, request, response, request_name=None - ): + ): # no cov named_middleware = self.named_response_middleware.get( request_name, deque() ) applicable_middleware = self.response_middleware + named_middleware if applicable_middleware: for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": response, + }, + condition={"attach_to": "response"}, + ) + _response = middleware(request, response) if isawaitable(_response): _response = await _response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": _response if _response else response, + }, + condition={"attach_to": "response"}, + ) + if _response: response = _response if isinstance(response, BaseHTTPResponse): @@ -1155,10 +1259,6 @@ class Sanic(BaseSanic): ): """Helper function used by `run` and `create_server`.""" - self.listeners["before_server_start"] = [ - self.finalize - ] + self.listeners["before_server_start"] - if isinstance(ssl, dict): # try common aliaseses cert = ssl.get("cert") or ssl.get("certificate") @@ -1195,10 +1295,6 @@ class Sanic(BaseSanic): # Register start/stop events for event_name, settings_name, reverse in ( - ("before_server_start", "before_start", False), - ("after_server_start", "after_start", False), - ("before_server_stop", "before_stop", True), - ("after_server_stop", "after_stop", True), ("main_process_start", "main_start", False), ("main_process_stop", "main_stop", True), ): @@ -1236,7 +1332,8 @@ class Sanic(BaseSanic): logger.info(f"Goin' Fast @ {proto}://{host}:{port}") debug_mode = "enabled" if self.debug else "disabled" - logger.debug("Sanic auto-reload: enabled") + reload_mode = "enabled" if auto_reload else "disabled" + logger.debug(f"Sanic auto-reload: {reload_mode}") logger.debug(f"Sanic debug mode: {debug_mode}") return server_settings @@ -1246,20 +1343,44 @@ class Sanic(BaseSanic): return ".".join(parts) @classmethod - def _loop_add_task(cls, task, app, loop): + def _prep_task(cls, task, app, loop): if callable(task): try: - loop.create_task(task(app)) + task = task(app) except TypeError: - loop.create_task(task()) - else: - loop.create_task(task) + task = task() + + return task + + @classmethod + def _loop_add_task(cls, task, app, loop): + prepped = cls._prep_task(task, app, loop) + loop.create_task(prepped) @classmethod def _cancel_websocket_tasks(cls, app, loop): for task in app.websocket_tasks: task.cancel() + @staticmethod + async def dispatch_delayed_tasks(app, loop): + for name in app._delayed_tasks: + await app.dispatch(name, context={"app": app, "loop": loop}) + app._delayed_tasks.clear() + + @staticmethod + async def run_delayed_task(app, loop, task): + prepped = app._prep_task(task, app, loop) + await prepped + + @staticmethod + async def _listener( + app: Sanic, loop: AbstractEventLoop, listener: ListenerType + ): + maybe_coro = listener(app, loop) + if maybe_coro and isawaitable(maybe_coro): + await maybe_coro + # -------------------------------------------------------------------- # # ASGI # -------------------------------------------------------------------- # @@ -1333,15 +1454,51 @@ class Sanic(BaseSanic): raise SanicException(f'Sanic app name "{name}" not found.') # -------------------------------------------------------------------- # - # Static methods + # Lifecycle # -------------------------------------------------------------------- # - @staticmethod - async def finalize(app, _): + def finalize(self): try: - app.router.finalize() - if app.signal_router.routes: - app.signal_router.finalize() # noqa + self.router.finalize() except FinalizationError as e: if not Sanic.test_mode: - raise e # noqa + raise e + + def signalize(self): + try: + self.signal_router.finalize() + except FinalizationError as e: + if not Sanic.test_mode: + raise e + + async def _startup(self): + self.signalize() + self.finalize() + TouchUp.run(self) + + async def _server_event( + self, + concern: str, + action: str, + loop: Optional[AbstractEventLoop] = None, + ) -> None: + event = f"server.{concern}.{action}" + if action not in ("before", "after") or concern not in ( + "init", + "shutdown", + ): + raise SanicException(f"Invalid server event: {event}") + logger.debug(f"Triggering server events: {event}") + reverse = concern == "shutdown" + if loop is None: + loop = self.loop + await self.dispatch( + event, + fail_not_found=False, + reverse=reverse, + inline=True, + context={ + "app": self, + "loop": loop, + }, + ) diff --git a/sanic/asgi.py b/sanic/asgi.py index 330ced5a..55c18d5c 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,6 +1,5 @@ import warnings -from inspect import isawaitable from typing import Optional from urllib.parse import quote @@ -11,21 +10,27 @@ from sanic.exceptions import ServerError from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request from sanic.server import ConnInfo -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection class Lifespan: def __init__(self, asgi_app: "ASGIApp") -> None: self.asgi_app = asgi_app - if "before_server_start" in self.asgi_app.sanic_app.listeners: + if ( + "server.init.before" + in self.asgi_app.sanic_app.signal_router.name_index + ): warnings.warn( 'You have set a listener for "before_server_start" ' "in ASGI mode. " "It will be executed as early as possible, but not before " "the ASGI server is started." ) - if "after_server_stop" in self.asgi_app.sanic_app.listeners: + if ( + "server.shutdown.after" + in self.asgi_app.sanic_app.signal_router.name_index + ): warnings.warn( 'You have set a listener for "after_server_stop" ' "in ASGI mode. " @@ -42,19 +47,9 @@ class Lifespan: in sequence since the ASGI lifespan protocol only supports a single startup event. """ - self.asgi_app.sanic_app.router.finalize() - if self.asgi_app.sanic_app.signal_router.routes: - self.asgi_app.sanic_app.signal_router.finalize() - listeners = self.asgi_app.sanic_app.listeners.get( - "before_server_start", [] - ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) - - for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if response and isawaitable(response): - await response + await self.asgi_app.sanic_app._startup() + await self.asgi_app.sanic_app._server_event("init", "before") + await self.asgi_app.sanic_app._server_event("init", "after") async def shutdown(self) -> None: """ @@ -65,16 +60,8 @@ class Lifespan: in sequence since the ASGI lifespan protocol only supports a single shutdown event. """ - listeners = self.asgi_app.sanic_app.listeners.get( - "before_server_stop", [] - ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) - - for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if response and isawaitable(response): - await response + await self.asgi_app.sanic_app._server_event("shutdown", "before") + await self.asgi_app.sanic_app._server_event("shutdown", "after") async def __call__( self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend diff --git a/sanic/base.py b/sanic/base.py index ff4833fa..5d1358d8 100644 --- a/sanic/base.py +++ b/sanic/base.py @@ -58,7 +58,7 @@ class BaseSanic( if name not in self.__fake_slots__: warn( f"Setting variables on {self.__class__.__name__} instances is " - "deprecated and will be removed in version 21.9. You should " + "deprecated and will be removed in version 21.12. You should " f"change your {self.__class__.__name__} instance to use " f"instance.ctx.{name} instead.", DeprecationWarning, diff --git a/sanic/blueprint_group.py b/sanic/blueprint_group.py index 45f30894..8bec376d 100644 --- a/sanic/blueprint_group.py +++ b/sanic/blueprint_group.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import MutableSequence +from functools import partial from typing import TYPE_CHECKING, List, Optional, Union @@ -196,6 +197,27 @@ class BlueprintGroup(MutableSequence): """ self._blueprints.append(value) + def exception(self, *exceptions, **kwargs): + """ + A decorator that can be used to implement a global exception handler + for all the Blueprints that belong to this Blueprint Group. + + In case of nested Blueprint Groups, the same handler is applied + across each of the Blueprints recursively. + + :param args: List of Python exceptions to be caught by the handler + :param kwargs: Additional optional arguments to be passed to the + exception handler + :return a decorated method to handle global exceptions for any + blueprint registered under this group. + """ + + def register_exception_handler_for_blueprints(fn): + for blueprint in self.blueprints: + blueprint.exception(*exceptions, **kwargs)(fn) + + return register_exception_handler_for_blueprints + def insert(self, index: int, item: Blueprint) -> None: """ The Abstract class `MutableSequence` leverages this insert method to @@ -229,3 +251,15 @@ class BlueprintGroup(MutableSequence): args = list(args)[1:] return register_middleware_for_blueprints(fn) return register_middleware_for_blueprints + + def on_request(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "request") + else: + return partial(self.middleware, attach_to="request") + + def on_response(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "response") + else: + return partial(self.middleware, attach_to="response") diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 2431f849..617ec606 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections import defaultdict +from copy import deepcopy from types import SimpleNamespace from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union @@ -12,6 +13,7 @@ from sanic_routing.route import Route # type: ignore from sanic.base import BaseSanic from sanic.blueprint_group import BlueprintGroup from sanic.exceptions import SanicException +from sanic.helpers import Default, _default from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.handler_types import ( ListenerType, @@ -40,7 +42,7 @@ class Blueprint(BaseSanic): :param host: IP Address of FQDN for the sanic server to use. :param version: Blueprint Version :param strict_slashes: Enforce the API urls are requested with a - training */* + trailing */* """ __fake_slots__ = ( @@ -76,15 +78,9 @@ class Blueprint(BaseSanic): version_prefix: str = "/v", ): super().__init__(name=name) - - self._apps: Set[Sanic] = set() + self.reset() self.ctx = SimpleNamespace() - self.exceptions: List[RouteHandler] = [] self.host = host - self.listeners: Dict[str, List[ListenerType]] = {} - self.middlewares: List[MiddlewareType] = [] - self.routes: List[Route] = [] - self.statics: List[RouteHandler] = [] self.strict_slashes = strict_slashes self.url_prefix = ( url_prefix[:-1] @@ -93,7 +89,6 @@ class Blueprint(BaseSanic): ) self.version = version self.version_prefix = version_prefix - self.websocket_routes: List[Route] = [] def __repr__(self) -> str: args = ", ".join( @@ -144,12 +139,87 @@ class Blueprint(BaseSanic): kwargs["apply"] = False return super().signal(event, *args, **kwargs) + def reset(self): + self._apps: Set[Sanic] = set() + self.exceptions: List[RouteHandler] = [] + self.listeners: Dict[str, List[ListenerType]] = {} + self.middlewares: List[MiddlewareType] = [] + self.routes: List[Route] = [] + self.statics: List[RouteHandler] = [] + self.websocket_routes: List[Route] = [] + + def copy( + self, + name: str, + url_prefix: Optional[Union[str, Default]] = _default, + version: Optional[Union[int, str, float, Default]] = _default, + version_prefix: Union[str, Default] = _default, + strict_slashes: Optional[Union[bool, Default]] = _default, + with_registration: bool = True, + with_ctx: bool = False, + ): + """ + Copy a blueprint instance with some optional parameters to + override the values of attributes in the old instance. + + :param name: unique name of the blueprint + :param url_prefix: URL to be prefixed before all route URLs + :param version: Blueprint Version + :param version_prefix: the prefix of the version number shown in the + URL. + :param strict_slashes: Enforce the API urls are requested with a + trailing */* + :param with_registration: whether register new blueprint instance with + sanic apps that were registered with the old instance or not. + :param with_ctx: whether ``ctx`` will be copied or not. + """ + + attrs_backup = { + "_apps": self._apps, + "routes": self.routes, + "websocket_routes": self.websocket_routes, + "middlewares": self.middlewares, + "exceptions": self.exceptions, + "listeners": self.listeners, + "statics": self.statics, + } + + self.reset() + new_bp = deepcopy(self) + new_bp.name = name + + if not isinstance(url_prefix, Default): + new_bp.url_prefix = url_prefix + if not isinstance(version, Default): + new_bp.version = version + if not isinstance(strict_slashes, Default): + new_bp.strict_slashes = strict_slashes + if not isinstance(version_prefix, Default): + new_bp.version_prefix = version_prefix + + for key, value in attrs_backup.items(): + setattr(self, key, value) + + if with_registration and self._apps: + if new_bp._future_statics: + raise SanicException( + "Static routes registered with the old blueprint instance," + " cannot be registered again." + ) + for app in self._apps: + app.blueprint(new_bp) + + if not with_ctx: + new_bp.ctx = SimpleNamespace() + + return new_bp + @staticmethod def group( - *blueprints, - url_prefix="", - version=None, - strict_slashes=None, + *blueprints: Union[Blueprint, BlueprintGroup], + url_prefix: Optional[str] = None, + version: Optional[Union[int, str, float]] = None, + strict_slashes: Optional[bool] = None, version_prefix: str = "/v", ): """ @@ -196,6 +266,9 @@ class Blueprint(BaseSanic): opt_version = options.get("version", None) opt_strict_slashes = options.get("strict_slashes", None) opt_version_prefix = options.get("version_prefix", self.version_prefix) + error_format = options.get( + "error_format", app.config.FALLBACK_ERROR_FORMAT + ) routes = [] middleware = [] @@ -243,6 +316,7 @@ class Blueprint(BaseSanic): future.unquote, future.static, version_prefix, + error_format, ) route = app._apply_route(apply_route) @@ -261,19 +335,22 @@ class Blueprint(BaseSanic): route_names = [route.name for route in routes if route] - # Middleware if route_names: + # Middleware for future in self._future_middleware: middleware.append(app._apply_middleware(future, route_names)) - # Exceptions - for future in self._future_exceptions: - exception_handlers.append(app._apply_exception_handler(future)) + # Exceptions + for future in self._future_exceptions: + exception_handlers.append( + app._apply_exception_handler(future, route_names) + ) # Event listeners for listener in self._future_listeners: listeners[listener.event].append(app._apply_listener(listener)) + # Signals for signal in self._future_signals: signal.condition.update({"blueprint": self.name}) app._apply_signal(signal) diff --git a/sanic/config.py b/sanic/config.py index 27699f80..649d9414 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Union from warnings import warn +from sanic.errorpages import check_error_format from sanic.http import Http from .utils import load_module_from_file_location, str_to_bool @@ -20,7 +21,7 @@ BASE_LOGO = """ DEFAULT_CONFIG = { "ACCESS_LOG": True, "EVENT_AUTOREGISTER": False, - "FALLBACK_ERROR_FORMAT": "html", + "FALLBACK_ERROR_FORMAT": "auto", "FORWARDED_FOR_HEADER": "X-Forwarded-For", "FORWARDED_SECRET": None, "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec @@ -35,12 +36,9 @@ DEFAULT_CONFIG = { "REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds - "WEBSOCKET_MAX_QUEUE": 32, "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte "WEBSOCKET_PING_INTERVAL": 20, "WEBSOCKET_PING_TIMEOUT": 20, - "WEBSOCKET_READ_LIMIT": 2 ** 16, - "WEBSOCKET_WRITE_LIMIT": 2 ** 16, } @@ -62,12 +60,10 @@ class Config(dict): REQUEST_MAX_SIZE: int REQUEST_TIMEOUT: int RESPONSE_TIMEOUT: int - WEBSOCKET_MAX_QUEUE: int + SERVER_NAME: str WEBSOCKET_MAX_SIZE: int WEBSOCKET_PING_INTERVAL: int WEBSOCKET_PING_TIMEOUT: int - WEBSOCKET_READ_LIMIT: int - WEBSOCKET_WRITE_LIMIT: int def __init__( self, @@ -100,6 +96,7 @@ class Config(dict): self.load_environment_vars(SANIC_PREFIX) self._configure_header_size() + self._check_error_format() def __getattr__(self, attr): try: @@ -115,6 +112,8 @@ class Config(dict): "REQUEST_MAX_SIZE", ): self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() def _configure_header_size(self): Http.set_header_max_size( @@ -123,6 +122,9 @@ class Config(dict): self.REQUEST_MAX_SIZE, ) + def _check_error_format(self): + check_error_format(self.FALLBACK_ERROR_FORMAT) + def load_environment_vars(self, prefix=SANIC_PREFIX): """ Looks for prefixed environment variables and applies diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 5fc10de1..82cdd57a 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = { } RENDERERS_BY_CONTENT_TYPE = { - "multipart/form-data": HTMLRenderer, - "application/json": JSONRenderer, "text/plain": TextRenderer, + "application/json": JSONRenderer, + "multipart/form-data": HTMLRenderer, + "text/html": HTMLRenderer, } +CONTENT_TYPE_BY_RENDERERS = { + v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items() +} + +RESPONSE_MAPPING = { + "empty": "html", + "json": "json", + "text": "text", + "raw": "text", + "html": "html", + "file": "html", + "file_stream": "text", + "stream": "text", + "redirect": "html", + "text/plain": "text", + "text/html": "html", + "application/json": "json", +} + + +def check_error_format(format): + if format not in RENDERERS_BY_CONFIG and format != "auto": + raise SanicException(f"Unknown format: {format}") def exception_response( request: Request, exception: Exception, debug: bool, + fallback: str, + base: t.Type[BaseRenderer], renderer: t.Type[t.Optional[BaseRenderer]] = None, ) -> HTTPResponse: """ Render a response for the default FALLBACK exception handler. """ + content_type = None if not renderer: - renderer = HTMLRenderer + # Make sure we have something set + renderer = base + render_format = fallback if request: - if request.app.config.FALLBACK_ERROR_FORMAT == "auto": + # If there is a request, try and get the format + # from the route + if request.route: try: - renderer = JSONRenderer if request.json else HTMLRenderer - except InvalidUsage: + render_format = request.route.ctx.error_format + except AttributeError: + ... + + content_type = request.headers.getone("content-type", "").split( + ";" + )[0] + + acceptable = request.accept + + # If the format is auto still, make a guess + if render_format == "auto": + # First, if there is an Accept header, check if text/html + # is the first option + # According to MDN Web Docs, all major browsers use text/html + # as the primary value in Accept (with the exception of IE 8, + # and, well, if you are supporting IE 8, then you have bigger + # problems to concern yourself with than what default exception + # renderer is used) + # Source: + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values + + if acceptable and acceptable[0].match( + "text/html", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ): renderer = HTMLRenderer - content_type, *_ = request.headers.getone( - "content-type", "" - ).split(";") - renderer = RENDERERS_BY_CONTENT_TYPE.get( - content_type, renderer - ) + # Second, if there is an Accept header, check if + # application/json is an option, or if the content-type + # is application/json + elif ( + acceptable + and acceptable.match( + "application/json", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ) + or content_type == "application/json" + ): + renderer = JSONRenderer + + # Third, if there is no Accept header, assume we want text. + # The likely use case here is a raw socket. + elif not acceptable: + renderer = TextRenderer + else: + # Fourth, look to see if there was a JSON body + # When in this situation, the request is probably coming + # from curl, an API client like Postman or Insomnia, or a + # package like requests or httpx + try: + # Give them the benefit of the doubt if they did: + # $ curl localhost:8000 -d '{"foo": "bar"}' + # And provide them with JSONRenderer + renderer = JSONRenderer if request.json else base + except InvalidUsage: + renderer = base else: - render_format = request.app.config.FALLBACK_ERROR_FORMAT renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) + # Lastly, if there is an Accept header, make sure + # our choice is okay + if acceptable: + type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore + if type_ and type_ not in acceptable: + # If the renderer selected is not in the Accept header + # look through what is in the Accept header, and select + # the first option that matches. Otherwise, just drop back + # to the original default + for accept in acceptable: + mtype = f"{accept.type_}/{accept.subtype}" + maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) + if maybe: + renderer = maybe + break + else: + renderer = base + renderer = t.cast(t.Type[BaseRenderer], renderer) return renderer(request, exception, debug).render() diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 16cd684d..1bb06f1d 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -4,16 +4,20 @@ from sanic.helpers import STATUS_CODES class SanicException(Exception): + message: str = "" + def __init__( self, message: Optional[Union[str, bytes]] = None, status_code: Optional[int] = None, quiet: Optional[bool] = None, ) -> None: - - if message is None and status_code is not None: - msg: bytes = STATUS_CODES.get(status_code, b"") - message = msg.decode("utf8") + if message is None: + if self.message: + message = self.message + elif status_code is not None: + msg: bytes = STATUS_CODES.get(status_code, b"") + message = msg.decode("utf8") super().__init__(message) @@ -122,8 +126,11 @@ class HeaderNotFound(InvalidUsage): **Status**: 400 Bad Request """ - status_code = 400 - quiet = True + +class InvalidHeader(InvalidUsage): + """ + **Status**: 400 Bad Request + """ class ContentRangeError(SanicException): @@ -230,6 +237,11 @@ class InvalidSignal(SanicException): pass +class WebsocketClosed(SanicException): + quiet = True + message = "Client has closed the websocket connection" + + def abort(status_code: int, message: Optional[Union[str, bytes]] = None): """ Raise an exception based on SanicException. Returns the HTTP response diff --git a/sanic/handlers.py b/sanic/handlers.py index dd1fbac1..ffeb76b8 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,12 +1,13 @@ -from traceback import format_exc +from typing import Dict, List, Optional, Tuple, Type -from sanic.errorpages import exception_response +from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response from sanic.exceptions import ( ContentRangeError, HeaderNotFound, InvalidRangeType, ) from sanic.log import error_logger +from sanic.models.handler_types import RouteHandler from sanic.response import text @@ -23,15 +24,17 @@ class ErrorHandler: """ - handlers = None - cached_handlers = None - - def __init__(self): - self.handlers = [] - self.cached_handlers = {} + # Beginning in v22.3, the base renderer will be TextRenderer + def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): + self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] + self.cached_handlers: Dict[ + Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] + ] = {} self.debug = False + self.fallback = fallback + self.base = base - def add(self, exception, handler): + def add(self, exception, handler, route_names: Optional[List[str]] = None): """ Add a new exception handler to an already existing handler object. @@ -44,11 +47,16 @@ class ErrorHandler: :return: None """ - # self.handlers to be deprecated and removed in version 21.12 + # self.handlers is deprecated and will be removed in version 22.3 self.handlers.append((exception, handler)) - self.cached_handlers[exception] = handler - def lookup(self, exception): + if route_names: + for route in route_names: + self.cached_handlers[(exception, route)] = handler + else: + self.cached_handlers[(exception, None)] = handler + + def lookup(self, exception, route_name: Optional[str]): """ Lookup the existing instance of :class:`ErrorHandler` and fetch the registered handler for a specific type of exception. @@ -63,17 +71,26 @@ class ErrorHandler: :return: Registered function if found ``None`` otherwise """ exception_class = type(exception) - if exception_class in self.cached_handlers: - return self.cached_handlers[exception_class] - for ancestor in type.mro(exception_class): - if ancestor in self.cached_handlers: - handler = self.cached_handlers[ancestor] - self.cached_handlers[exception_class] = handler + for name in (route_name, None): + exception_key = (exception_class, name) + handler = self.cached_handlers.get(exception_key) + if handler: return handler - if ancestor is BaseException: - break - self.cached_handlers[exception_class] = None + + for name in (route_name, None): + for ancestor in type.mro(exception_class): + exception_key = (ancestor, name) + if exception_key in self.cached_handlers: + handler = self.cached_handlers[exception_key] + self.cached_handlers[ + (exception_class, route_name) + ] = handler + return handler + + if ancestor is BaseException: + break + self.cached_handlers[(exception_class, route_name)] = None handler = None return handler @@ -91,7 +108,8 @@ class ErrorHandler: :return: Wrap the return value obtained from :func:`default` or registered handler for that type of exception. """ - handler = self.lookup(exception) + route_name = request.name if request else None + handler = self.lookup(exception, route_name) response = None try: if handler: @@ -99,7 +117,6 @@ class ErrorHandler: if response is None: response = self.default(request, exception) except Exception: - self.log(format_exc()) try: url = repr(request.url) except AttributeError: @@ -115,11 +132,6 @@ class ErrorHandler: return text("An error occurred while handling an error", 500) return response - def log(self, message, level="error"): - """ - Deprecated, do not use. - """ - def default(self, request, exception): """ Provide a default behavior for the objects of :class:`ErrorHandler`. @@ -135,6 +147,17 @@ class ErrorHandler: :class:`Exception` :return: """ + self.log(request, exception) + return exception_response( + request, + exception, + debug=self.debug, + base=self.base, + fallback=self.fallback, + ) + + @staticmethod + def log(request, exception): quiet = getattr(exception, "quiet", False) if quiet is False: try: @@ -142,13 +165,10 @@ class ErrorHandler: except AttributeError: url = "unknown" - self.log(format_exc()) error_logger.exception( "Exception occurred while handling uri: %s", url ) - return exception_response(request, exception, self.debug) - class ContentRangeHandler: """ diff --git a/sanic/headers.py b/sanic/headers.py index 66427442..dbb8720f 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import re from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import unquote +from sanic.exceptions import InvalidHeader from sanic.helpers import STATUS_CODES @@ -30,6 +33,175 @@ _host_re = re.compile( # For more information, consult ../tests/test_requests.py +def parse_arg_as_accept(f): + def func(self, other, *args, **kwargs): + if not isinstance(other, Accept) and other: + other = Accept.parse(other) + return f(self, other, *args, **kwargs) + + return func + + +class MediaType(str): + def __new__(cls, value: str): + return str.__new__(cls, value) + + def __init__(self, value: str) -> None: + self.value = value + self.is_wildcard = self.check_if_wildcard(value) + + def __eq__(self, other): + if self.is_wildcard: + return True + + if self.match(other): + return True + + other_is_wildcard = ( + other.is_wildcard + if isinstance(other, MediaType) + else self.check_if_wildcard(other) + ) + + return other_is_wildcard + + def match(self, other): + other_value = other.value if isinstance(other, MediaType) else other + return self.value == other_value + + @staticmethod + def check_if_wildcard(value): + return value == "*" + + +class Accept(str): + def __new__(cls, value: str, *args, **kwargs): + return str.__new__(cls, value) + + def __init__( + self, + value: str, + type_: MediaType, + subtype: MediaType, + *, + q: str = "1.0", + **kwargs: str, + ): + qvalue = float(q) + if qvalue > 1 or qvalue < 0: + raise InvalidHeader( + f"Accept header qvalue must be between 0 and 1, not: {qvalue}" + ) + self.value = value + self.type_ = type_ + self.subtype = subtype + self.qvalue = qvalue + self.params = kwargs + + def _compare(self, other, method): + try: + return method(self.qvalue, other.qvalue) + except (AttributeError, TypeError): + return NotImplemented + + @parse_arg_as_accept + def __lt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s < o) + + @parse_arg_as_accept + def __le__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s <= o) + + @parse_arg_as_accept + def __eq__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s == o) + + @parse_arg_as_accept + def __ge__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s >= o) + + @parse_arg_as_accept + def __gt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s > o) + + @parse_arg_as_accept + def __ne__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s != o) + + @parse_arg_as_accept + def match( + self, + other, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + type_match = ( + self.type_ == other.type_ + if allow_type_wildcard + else ( + self.type_.match(other.type_) + and not self.type_.is_wildcard + and not other.type_.is_wildcard + ) + ) + subtype_match = ( + self.subtype == other.subtype + if allow_subtype_wildcard + else ( + self.subtype.match(other.subtype) + and not self.subtype.is_wildcard + and not other.subtype.is_wildcard + ) + ) + + return type_match and subtype_match + + @classmethod + def parse(cls, raw: str) -> Accept: + invalid = False + mtype = raw.strip() + + try: + media, *raw_params = mtype.split(";") + type_, subtype = media.split("/") + except ValueError: + invalid = True + + if invalid or not type_ or not subtype: + raise InvalidHeader(f"Header contains invalid Accept value: {raw}") + + params = dict( + [ + (key.strip(), value.strip()) + for key, value in (param.split("=", 1) for param in raw_params) + ] + ) + + return cls(mtype, MediaType(type_), MediaType(subtype), **params) + + +class AcceptContainer(list): + def __contains__(self, o: object) -> bool: + return any(item.match(o) for item in self) + + def match( + self, + o: object, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + return any( + item.match( + o, + allow_type_wildcard=allow_type_wildcard, + allow_subtype_wildcard=allow_subtype_wildcard, + ) + for item in self + ) + + def parse_content_header(value: str) -> Tuple[str, Options]: """Parse content-type and content-disposition header values. @@ -194,3 +366,31 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: ret += b"%b: %b\r\n" % h ret += b"\r\n" return ret + + +def _sort_accept_value(accept: Accept): + return ( + accept.qvalue, + len(accept.params), + accept.subtype != "*", + accept.type_ != "*", + ) + + +def parse_accept(accept: str) -> AcceptContainer: + """Parse an Accept header and order the acceptable media types in + accorsing to RFC 7231, s. 5.3.2 + https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + """ + media_types = accept.split(",") + accept_list: List[Accept] = [] + + for mtype in media_types: + if not mtype: + continue + + accept_list.append(Accept.parse(mtype)) + + return AcceptContainer( + sorted(accept_list, key=_sort_accept_value, reverse=True) + ) diff --git a/sanic/helpers.py b/sanic/helpers.py index 15ae7bf2..87d51b53 100644 --- a/sanic/helpers.py +++ b/sanic/helpers.py @@ -155,3 +155,17 @@ def import_string(module_name, package=None): if ismodule(obj): return obj return obj() + + +class Default: + """ + It is used to replace `None` or `object()` as a sentinel + that represents a default value. Sometimes we want to set + a value to `None` so we cannot use `None` to represent the + default value, and `object()` is hard to be typed. + """ + + pass + + +_default = Default() diff --git a/sanic/http.py b/sanic/http.py index 402238cb..d30e4c82 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -21,6 +21,7 @@ from sanic.exceptions import ( from sanic.headers import format_http1_response from sanic.helpers import has_message_body from sanic.log import access_logger, error_logger, logger +from sanic.touchup import TouchUpMeta class Stage(Enum): @@ -45,7 +46,7 @@ class Stage(Enum): HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" -class Http: +class Http(metaclass=TouchUpMeta): """ Internal helper for managing the HTTP request/response cycle @@ -67,9 +68,15 @@ class Http: HEADER_CEILING = 16_384 HEADER_MAX_SIZE = 0 + __touchup__ = ( + "http1_request_header", + "http1_response_header", + "read", + ) __slots__ = [ "_send", "_receive_more", + "dispatch", "recv_buffer", "protocol", "expecting_continue", @@ -97,6 +104,7 @@ class Http: self.protocol = protocol self.keep_alive = True self.stage: Stage = Stage.IDLE + self.dispatch = self.protocol.app.dispatch self.init_for_request() def init_for_request(self): @@ -140,6 +148,12 @@ class Http: await self.response.send(end_stream=True) except CancelledError: # Write an appropriate response before exiting + if not self.protocol.transport: + logger.info( + f"Request: {self.request.method} {self.request.url} " + "stopped. Transport is closed." + ) + return e = self.exception or ServiceUnavailable("Cancelled") self.exception = None self.keep_alive = False @@ -173,17 +187,17 @@ class Http: if self.response: self.response.stream = None - self.init_for_request() - # Exit and disconnect if no more requests can be taken if self.stage is not Stage.IDLE or not self.keep_alive: break + self.init_for_request() + # Wait for the next request if not self.recv_buffer: await self._receive_more() - async def http1_request_header(self): + async def http1_request_header(self): # no cov """ Receive and parse request header into self.request. """ @@ -212,6 +226,12 @@ class Http: reqline, *split_headers = raw_headers.split("\r\n") method, self.url, protocol = reqline.split(" ") + await self.dispatch( + "http.lifecycle.read_head", + inline=True, + context={"head": bytes(head)}, + ) + if protocol == "HTTP/1.1": self.keep_alive = True elif protocol == "HTTP/1.0": @@ -250,6 +270,11 @@ class Http: transport=self.protocol.transport, app=self.protocol.app, ) + await self.dispatch( + "http.lifecycle.request", + inline=True, + context={"request": request}, + ) # Prepare for request body self.request_bytes_left = self.request_bytes = 0 @@ -280,7 +305,7 @@ class Http: async def http1_response_header( self, data: bytes, end_stream: bool - ) -> None: + ) -> None: # no cov res = self.response # Compatibility with simple response body @@ -452,8 +477,8 @@ class Http: "request": "nil", } if req is not None: - if req.ip: - extra["host"] = f"{req.ip}:{req.port}" + if req.remote_addr or req.ip: + extra["host"] = f"{req.remote_addr or req.ip}:{req.port}" extra["request"] = f"{req.method} {req.url}" access_logger.info("", extra=extra) @@ -469,7 +494,7 @@ class Http: if data: yield data - async def read(self) -> Optional[bytes]: + async def read(self) -> Optional[bytes]: # no cov """ Read some bytes of request body. """ @@ -543,6 +568,12 @@ class Http: self.request_bytes_left -= size + await self.dispatch( + "http.lifecycle.read_body", + inline=True, + context={"body": data}, + ) + return data # Response methods diff --git a/sanic/mixins/listeners.py b/sanic/mixins/listeners.py index c12326c4..ebf9b131 100644 --- a/sanic/mixins/listeners.py +++ b/sanic/mixins/listeners.py @@ -1,18 +1,19 @@ from enum import Enum, auto from functools import partial -from typing import Any, Callable, Coroutine, List, Optional, Union +from typing import List, Optional, Union from sanic.models.futures import FutureListener +from sanic.models.handler_types import ListenerType class ListenerEvent(str, Enum): def _generate_next_value_(name: str, *args) -> str: # type: ignore return name.lower() - BEFORE_SERVER_START = auto() - AFTER_SERVER_START = auto() - BEFORE_SERVER_STOP = auto() - AFTER_SERVER_STOP = auto() + BEFORE_SERVER_START = "server.init.before" + AFTER_SERVER_START = "server.init.after" + BEFORE_SERVER_STOP = "server.shutdown.before" + AFTER_SERVER_STOP = "server.shutdown.after" MAIN_PROCESS_START = auto() MAIN_PROCESS_STOP = auto() @@ -26,9 +27,7 @@ class ListenerMixin: def listener( self, - listener_or_event: Union[ - Callable[..., Coroutine[Any, Any, None]], str - ], + listener_or_event: Union[ListenerType, str], event_or_none: Optional[str] = None, apply: bool = True, ): @@ -63,20 +62,20 @@ class ListenerMixin: else: return partial(register_listener, event=listener_or_event) - def main_process_start(self, listener): + def main_process_start(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "main_process_start") - def main_process_stop(self, listener): + def main_process_stop(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "main_process_stop") - def before_server_start(self, listener): + def before_server_start(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "before_server_start") - def after_server_start(self, listener): + def after_server_start(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "after_server_start") - def before_server_stop(self, listener): + def before_server_stop(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "before_server_stop") - def after_server_stop(self, listener): + def after_server_stop(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "after_server_stop") diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 5af1610d..8467a2e3 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -1,17 +1,20 @@ +from ast import NodeVisitor, Return, parse from functools import partial, wraps -from inspect import signature +from inspect import getsource, signature from mimetypes import guess_type from os import path from pathlib import PurePath from re import sub +from textwrap import dedent from time import gmtime, strftime -from typing import Iterable, List, Optional, Set, Union +from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import unquote from sanic_routing.route import Route # type: ignore from sanic.compat import stat_async from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS +from sanic.errorpages import RESPONSE_MAPPING from sanic.exceptions import ( ContentRangeError, FileNotFound, @@ -21,10 +24,16 @@ from sanic.exceptions import ( from sanic.handlers import ContentRangeHandler from sanic.log import error_logger from sanic.models.futures import FutureRoute, FutureStatic +from sanic.models.handler_types import RouteHandler from sanic.response import HTTPResponse, file, file_stream from sanic.views import CompositionView +RouteWrapper = Callable[ + [RouteHandler], Union[RouteHandler, Tuple[Route, RouteHandler]] +] + + class RouteMixin: name: str @@ -55,7 +64,8 @@ class RouteMixin: unquote: bool = False, static: bool = False, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Decorate a function to be registered as a route @@ -97,6 +107,7 @@ class RouteMixin: nonlocal websocket nonlocal static nonlocal version_prefix + nonlocal error_format if isinstance(handler, tuple): # if a handler fn is already wrapped in a route, the handler @@ -115,10 +126,16 @@ class RouteMixin: "Expected either string or Iterable of host strings, " "not %s" % host ) - - if isinstance(subprotocols, (list, tuple, set)): + if isinstance(subprotocols, list): + # Ordered subprotocols, maintain order + subprotocols = tuple(subprotocols) + elif isinstance(subprotocols, set): + # subprotocol is unordered, keep it unordered subprotocols = frozenset(subprotocols) + if not error_format or error_format == "auto": + error_format = self._determine_error_format(handler) + route = FutureRoute( handler, uri, @@ -134,6 +151,7 @@ class RouteMixin: unquote, static, version_prefix, + error_format, ) self._future_routes.add(route) @@ -168,7 +186,7 @@ class RouteMixin: def add_route( self, - handler, + handler: RouteHandler, uri: str, methods: Iterable[str] = frozenset({"GET"}), host: Optional[str] = None, @@ -177,7 +195,8 @@ class RouteMixin: name: Optional[str] = None, stream: bool = False, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteHandler: """A helper method to register class instance or functions as a handler to the application url routes. @@ -200,7 +219,8 @@ class RouteMixin: methods = set() for method in HTTP_METHODS: - _handler = getattr(handler.view_class, method.lower(), None) + view_class = getattr(handler, "view_class") + _handler = getattr(view_class, method.lower(), None) if _handler: methods.add(method) if hasattr(_handler, "is_stream"): @@ -226,6 +246,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, )(handler) return handler @@ -239,7 +260,8 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **GET** *HTTP* method @@ -262,6 +284,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def post( @@ -273,7 +296,8 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **POST** *HTTP* method @@ -296,6 +320,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def put( @@ -307,7 +332,8 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **PUT** *HTTP* method @@ -330,6 +356,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def head( @@ -341,7 +368,8 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **HEAD** *HTTP* method @@ -372,6 +400,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def options( @@ -383,7 +412,8 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **OPTIONS** *HTTP* method @@ -414,6 +444,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def patch( @@ -425,7 +456,8 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **PATCH** *HTTP* method @@ -458,6 +490,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def delete( @@ -469,7 +502,8 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **DELETE** *HTTP* method @@ -492,6 +526,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def websocket( @@ -504,6 +539,7 @@ class RouteMixin: name: Optional[str] = None, apply: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ): """ Decorate a function to be registered as a websocket route @@ -530,6 +566,7 @@ class RouteMixin: subprotocols=subprotocols, websocket=True, version_prefix=version_prefix, + error_format=error_format, ) def add_websocket_route( @@ -542,6 +579,7 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", + error_format: Optional[str] = None, ): """ A helper method to register a function as a websocket route. @@ -570,6 +608,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, )(handler) def static( @@ -585,6 +624,7 @@ class RouteMixin: strict_slashes=None, content_type=None, apply=True, + resource_type=None, ): """ Register a root to serve files from. The input can either be a @@ -634,6 +674,7 @@ class RouteMixin: host, strict_slashes, content_type, + resource_type, ) self._future_statics.add(static) @@ -777,10 +818,11 @@ class RouteMixin: ) except Exception: error_logger.exception( - f"Exception in static request handler:\ - path={file_or_directory}, " + f"Exception in static request handler: " + f"path={file_or_directory}, " f"relative_url={__file_uri__}" ) + raise def _register_static( self, @@ -828,8 +870,27 @@ class RouteMixin: name = static.name # If we're not trying to match a file directly, # serve from the folder - if not path.isfile(file_or_directory): + if not static.resource_type: + if not path.isfile(file_or_directory): + uri += "/<__file_uri__:path>" + elif static.resource_type == "dir": + if path.isfile(file_or_directory): + raise TypeError( + "Resource type improperly identified as directory. " + f"'{file_or_directory}'" + ) uri += "/<__file_uri__:path>" + elif static.resource_type == "file" and not path.isfile( + file_or_directory + ): + raise TypeError( + "Resource type improperly identified as file. " + f"'{file_or_directory}'" + ) + elif static.resource_type != "file": + raise ValueError( + "The resource_type should be set to 'file' or 'dir'" + ) # special prefix for static files # if not static.name.startswith("_static_"): @@ -846,7 +907,7 @@ class RouteMixin: ) ) - route, _ = self.route( + route, _ = self.route( # type: ignore uri=uri, methods=["GET", "HEAD"], name=name, @@ -856,3 +917,43 @@ class RouteMixin: )(_handler) return route + + def _determine_error_format(self, handler) -> str: + if not isinstance(handler, CompositionView): + try: + src = dedent(getsource(handler)) + tree = parse(src) + http_response_types = self._get_response_types(tree) + + if len(http_response_types) == 1: + return next(iter(http_response_types)) + except (OSError, TypeError): + ... + + return "auto" + + def _get_response_types(self, node): + types = set() + + class HttpResponseVisitor(NodeVisitor): + def visit_Return(self, node: Return) -> Any: + nonlocal types + + try: + checks = [node.value.func.id] # type: ignore + if node.value.keywords: # type: ignore + checks += [ + k.value + for k in node.value.keywords # type: ignore + if k.arg == "content_type" + ] + + for check in checks: + if check in RESPONSE_MAPPING: + types.add(RESPONSE_MAPPING[check]) + except AttributeError: + ... + + HttpResponseVisitor().visit(node) + + return types diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index e849e562..2be9fee2 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -23,7 +23,7 @@ class SignalMixin: *, apply: bool = True, condition: Dict[str, Any] = None, - ) -> Callable[[SignalHandler], FutureSignal]: + ) -> Callable[[SignalHandler], SignalHandler]: """ For creating a signal handler, used similar to a route handler: @@ -54,7 +54,7 @@ class SignalMixin: if apply: self._apply_signal(future_signal) - return future_signal + return handler return decorator diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 595b0553..1b707ebc 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -3,7 +3,7 @@ import asyncio from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from sanic.exceptions import InvalidUsage -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection ASGIScope = MutableMapping[str, Any] diff --git a/sanic/models/futures.py b/sanic/models/futures.py index 2350bedb..fe7d77eb 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -24,6 +24,7 @@ class FutureRoute(NamedTuple): unquote: bool static: bool version_prefix: str + error_format: Optional[str] class FutureListener(NamedTuple): @@ -52,6 +53,7 @@ class FutureStatic(NamedTuple): host: Optional[str] strict_slashes: Optional[bool] content_type: Optional[bool] + resource_type: Optional[str] class FutureSignal(NamedTuple): diff --git a/sanic/models/handler_types.py b/sanic/models/handler_types.py index 704def7a..0144c964 100644 --- a/sanic/models/handler_types.py +++ b/sanic/models/handler_types.py @@ -21,5 +21,5 @@ MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType] ListenerType = Callable[ [Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]] ] -RouteHandler = Callable[..., Coroutine[Any, Any, HTTPResponse]] +RouteHandler = Callable[..., Coroutine[Any, Any, Optional[HTTPResponse]]] SignalHandler = Callable[..., Coroutine[Any, Any, None]] diff --git a/sanic/models/server_types.py b/sanic/models/server_types.py new file mode 100644 index 00000000..f0ced247 --- /dev/null +++ b/sanic/models/server_types.py @@ -0,0 +1,52 @@ +from types import SimpleNamespace + +from sanic.models.protocol_types import TransportProtocol + + +class Signal: + stopped = False + + +class ConnInfo: + """ + Local and remote addresses and SSL status info. + """ + + __slots__ = ( + "client_port", + "client", + "client_ip", + "ctx", + "peername", + "server_port", + "server", + "sockname", + "ssl", + ) + + def __init__(self, transport: TransportProtocol, unix=None): + self.ctx = SimpleNamespace() + self.peername = None + self.server = self.client = "" + self.server_port = self.client_port = 0 + self.client_ip = "" + self.sockname = addr = transport.get_extra_info("sockname") + self.ssl: bool = bool(transport.get_extra_info("sslcontext")) + + if isinstance(addr, str): # UNIX socket + self.server = unix or addr + return + + # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) + if isinstance(addr, tuple): + self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.server_port = addr[1] + # self.server gets non-standard port appended + if addr[1] != (443 if self.ssl else 80): + self.server = f"{self.server}:{addr[1]}" + self.peername = addr = transport.get_extra_info("peername") + + if isinstance(addr, tuple): + self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.client_ip = addr[0] + self.client_port = addr[1] diff --git a/sanic/request.py b/sanic/request.py index 177df637..c744e3c3 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -34,7 +34,9 @@ from sanic.compat import CancelledErrors, Header from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.exceptions import InvalidUsage from sanic.headers import ( + AcceptContainer, Options, + parse_accept, parse_content_header, parse_forwarded, parse_host, @@ -94,6 +96,7 @@ class Request: "head", "headers", "method", + "parsed_accept", "parsed_args", "parsed_not_grouped_args", "parsed_files", @@ -136,6 +139,7 @@ class Request: self.conn_info: Optional[ConnInfo] = None self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None + self.parsed_accept: Optional[AcceptContainer] = None self.parsed_json = None self.parsed_form = None self.parsed_files = None @@ -296,6 +300,13 @@ class Request: return self.parsed_json + @property + def accept(self) -> AcceptContainer: + if self.parsed_accept is None: + accept_header = self.headers.getone("accept", "") + self.parsed_accept = parse_accept(accept_header) + return self.parsed_accept + @property def token(self): """Attempt to return the auth header token. @@ -497,6 +508,10 @@ class Request: """ return self._match_info + @match_info.setter + def match_info(self, value): + self._match_info = value + # Transport properties (obtained from local interface only) @property diff --git a/sanic/router.py b/sanic/router.py index 0973a3fa..6995ed6d 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -1,5 +1,9 @@ +from __future__ import annotations + from functools import lru_cache +from inspect import signature from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from uuid import UUID from sanic_routing import BaseRouter # type: ignore from sanic_routing.exceptions import NoMethod # type: ignore @@ -9,6 +13,7 @@ from sanic_routing.exceptions import ( from sanic_routing.route import Route # type: ignore from sanic.constants import HTTP_METHODS +from sanic.errorpages import check_error_format from sanic.exceptions import MethodNotSupported, NotFound, SanicException from sanic.models.handler_types import RouteHandler @@ -74,6 +79,7 @@ class Router(BaseRouter): unquote: bool = False, static: bool = False, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> Union[Route, List[Route]]: """ Add a handler to the router @@ -106,6 +112,8 @@ class Router(BaseRouter): version = str(version).strip("/").lstrip("v") uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")]) + uri = self._normalize(uri, handler) + params = dict( path=uri, handler=handler, @@ -131,6 +139,11 @@ class Router(BaseRouter): route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static + route.ctx.error_format = ( + error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT + ) + + check_error_format(route.ctx.error_format) routes.append(route) @@ -187,3 +200,24 @@ class Router(BaseRouter): raise SanicException( f"Invalid route: {route}. Parameter names cannot use '__'." ) + + def _normalize(self, uri: str, handler: RouteHandler) -> str: + if "<" not in uri: + return uri + + sig = signature(handler) + mapping = { + param.name: param.annotation.__name__.lower() + for param in sig.parameters.values() + if param.annotation in (str, int, float, UUID) + } + + reconstruction = [] + for part in uri.split("/"): + if part.startswith("<") and ":" not in part: + name = part[1:-1] + annotation = mapping.get(name) + if annotation: + part = f"<{name}:{annotation}>" + reconstruction.append(part) + return "/".join(reconstruction) diff --git a/sanic/server.py b/sanic/server.py deleted file mode 100644 index 4ec83f9c..00000000 --- a/sanic/server.py +++ /dev/null @@ -1,793 +0,0 @@ -from __future__ import annotations - -from ssl import SSLContext -from types import SimpleNamespace -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - Optional, - Type, - Union, -) - -from sanic.models.handler_types import ListenerType - - -if TYPE_CHECKING: - from sanic.app import Sanic - -import asyncio -import multiprocessing -import os -import secrets -import socket -import stat - -from asyncio import CancelledError -from asyncio.transports import Transport -from functools import partial -from inspect import isawaitable -from ipaddress import ip_address -from signal import SIG_IGN, SIGINT, SIGTERM, Signals -from signal import signal as signal_func -from time import monotonic as current_time - -from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows -from sanic.config import Config -from sanic.exceptions import RequestTimeout, ServiceUnavailable -from sanic.http import Http, Stage -from sanic.log import error_logger, logger -from sanic.models.protocol_types import TransportProtocol -from sanic.request import Request - - -try: - import uvloop # type: ignore - - if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -except ImportError: - pass - - -class Signal: - stopped = False - - -class ConnInfo: - """ - Local and remote addresses and SSL status info. - """ - - __slots__ = ( - "client_port", - "client", - "client_ip", - "ctx", - "peername", - "server_port", - "server", - "sockname", - "ssl", - ) - - def __init__(self, transport: TransportProtocol, unix=None): - self.ctx = SimpleNamespace() - self.peername = None - self.server = self.client = "" - self.server_port = self.client_port = 0 - self.client_ip = "" - self.sockname = addr = transport.get_extra_info("sockname") - self.ssl: bool = bool(transport.get_extra_info("sslcontext")) - - if isinstance(addr, str): # UNIX socket - self.server = unix or addr - return - - # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) - if isinstance(addr, tuple): - self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" - self.server_port = addr[1] - # self.server gets non-standard port appended - if addr[1] != (443 if self.ssl else 80): - self.server = f"{self.server}:{addr[1]}" - self.peername = addr = transport.get_extra_info("peername") - - if isinstance(addr, tuple): - self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" - self.client_ip = addr[0] - self.client_port = addr[1] - - -class HttpProtocol(asyncio.Protocol): - """ - This class provides a basic HTTP implementation of the sanic framework. - """ - - __slots__ = ( - # app - "app", - # event loop, connection - "loop", - "transport", - "connections", - "signal", - "conn_info", - "ctx", - # request params - "request", - # request config - "request_handler", - "request_timeout", - "response_timeout", - "keep_alive_timeout", - "request_max_size", - "request_class", - "error_handler", - # enable or disable access log purpose - "access_log", - # connection management - "state", - "url", - "_handler_task", - "_can_write", - "_data_received", - "_time", - "_task", - "_http", - "_exception", - "recv_buffer", - "_unix", - ) - - def __init__( - self, - *, - loop, - app: Sanic, - signal=None, - connections=None, - state=None, - unix=None, - **kwargs, - ): - asyncio.set_event_loop(loop) - self.loop = loop - self.app: Sanic = app - self.url = None - self.transport: Optional[Transport] = None - self.conn_info: Optional[ConnInfo] = None - self.request: Optional[Request] = None - self.signal = signal or Signal() - self.access_log = self.app.config.ACCESS_LOG - self.connections = connections if connections is not None else set() - self.request_handler = self.app.handle_request - self.error_handler = self.app.error_handler - self.request_timeout = self.app.config.REQUEST_TIMEOUT - self.response_timeout = self.app.config.RESPONSE_TIMEOUT - self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT - self.request_max_size = self.app.config.REQUEST_MAX_SIZE - self.request_class = self.app.request_class or Request - self.state = state if state else {} - if "requests_count" not in self.state: - self.state["requests_count"] = 0 - self._data_received = asyncio.Event() - self._can_write = asyncio.Event() - self._can_write.set() - self._exception = None - self._unix = unix - - def _setup_connection(self): - self._http = Http(self) - self._time = current_time() - self.check_timeouts() - - async def connection_task(self): - """ - Run a HTTP connection. - - Timeouts and some additional error handling occur here, while most of - everything else happens in class Http or in code called from there. - """ - try: - self._setup_connection() - await self._http.http1() - except CancelledError: - pass - except Exception: - error_logger.exception("protocol.connection_task uncaught") - finally: - if self.app.debug and self._http: - ip = self.transport.get_extra_info("peername") - error_logger.error( - "Connection lost before response written" - f" @ {ip} {self._http.request}" - ) - self._http = None - self._task = None - try: - self.close() - except BaseException: - error_logger.exception("Closing failed") - - async def receive_more(self): - """ - Wait until more data is received into the Server protocol's buffer - """ - self.transport.resume_reading() - self._data_received.clear() - await self._data_received.wait() - - def check_timeouts(self): - """ - Runs itself periodically to enforce any expired timeouts. - """ - try: - if not self._task: - return - duration = current_time() - self._time - stage = self._http.stage - if stage is Stage.IDLE and duration > self.keep_alive_timeout: - logger.debug("KeepAlive Timeout. Closing connection.") - elif stage is Stage.REQUEST and duration > self.request_timeout: - logger.debug("Request Timeout. Closing connection.") - self._http.exception = RequestTimeout("Request Timeout") - elif stage is Stage.HANDLER and self._http.upgrade_websocket: - logger.debug("Handling websocket. Timeouts disabled.") - return - elif ( - stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) - and duration > self.response_timeout - ): - logger.debug("Response Timeout. Closing connection.") - self._http.exception = ServiceUnavailable("Response Timeout") - else: - interval = ( - min( - self.keep_alive_timeout, - self.request_timeout, - self.response_timeout, - ) - / 2 - ) - self.loop.call_later(max(0.1, interval), self.check_timeouts) - return - self._task.cancel() - except Exception: - error_logger.exception("protocol.check_timeouts") - - async def send(self, data): - """ - Writes data with backpressure control. - """ - await self._can_write.wait() - if self.transport.is_closing(): - raise CancelledError - self.transport.write(data) - self._time = current_time() - - def close_if_idle(self) -> bool: - """ - Close the connection if a request is not being sent or received - - :return: boolean - True if closed, false if staying open - """ - if self._http is None or self._http.stage is Stage.IDLE: - self.close() - return True - return False - - def close(self): - """ - Force close the connection. - """ - # Cause a call to connection_lost where further cleanup occurs - if self.transport: - self.transport.close() - self.transport = None - - # -------------------------------------------- # - # Only asyncio.Protocol callbacks below this - # -------------------------------------------- # - - def connection_made(self, transport): - try: - # TODO: Benchmark to find suitable write buffer limits - transport.set_write_buffer_limits(low=16384, high=65536) - self.connections.add(self) - self.transport = transport - self._task = self.loop.create_task(self.connection_task()) - self.recv_buffer = bytearray() - self.conn_info = ConnInfo(self.transport, unix=self._unix) - except Exception: - error_logger.exception("protocol.connect_made") - - def connection_lost(self, exc): - try: - self.connections.discard(self) - self.resume_writing() - if self._task: - self._task.cancel() - except Exception: - error_logger.exception("protocol.connection_lost") - - def pause_writing(self): - self._can_write.clear() - - def resume_writing(self): - self._can_write.set() - - def data_received(self, data: bytes): - try: - self._time = current_time() - if not data: - return self.close() - self.recv_buffer += data - - if ( - len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE - and self.transport - ): - self.transport.pause_reading() - - if self._data_received: - self._data_received.set() - except Exception: - error_logger.exception("protocol.data_received") - - -def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): - """ - Trigger event callbacks (functions or async) - - :param events: one or more sync or async functions to execute - :param loop: event loop - """ - if events: - for event in events: - result = event(loop) - if isawaitable(result): - loop.run_until_complete(result) - - -class AsyncioServer: - """ - Wraps an asyncio server with functionality that might be useful to - a user who needs to manage the server lifecycle manually. - """ - - __slots__ = ( - "loop", - "serve_coro", - "_after_start", - "_before_stop", - "_after_stop", - "server", - "connections", - ) - - def __init__( - self, - loop, - serve_coro, - connections, - after_start: Optional[Iterable[ListenerType]], - before_stop: Optional[Iterable[ListenerType]], - after_stop: Optional[Iterable[ListenerType]], - ): - # Note, Sanic already called "before_server_start" events - # before this helper was even created. So we don't need it here. - self.loop = loop - self.serve_coro = serve_coro - self._after_start = after_start - self._before_stop = before_stop - self._after_stop = after_stop - self.server = None - self.connections = connections - - def after_start(self): - """ - Trigger "after_server_start" events - """ - trigger_events(self._after_start, self.loop) - - def before_stop(self): - """ - Trigger "before_server_stop" events - """ - trigger_events(self._before_stop, self.loop) - - def after_stop(self): - """ - Trigger "after_server_stop" events - """ - trigger_events(self._after_stop, self.loop) - - def is_serving(self) -> bool: - if self.server: - return self.server.is_serving() - return False - - def wait_closed(self): - if self.server: - return self.server.wait_closed() - - def close(self): - if self.server: - self.server.close() - coro = self.wait_closed() - task = asyncio.ensure_future(coro, loop=self.loop) - return task - - def start_serving(self): - if self.server: - try: - return self.server.start_serving() - except AttributeError: - raise NotImplementedError( - "server.start_serving not available in this version " - "of asyncio or uvloop." - ) - - def serve_forever(self): - if self.server: - try: - return self.server.serve_forever() - except AttributeError: - raise NotImplementedError( - "server.serve_forever not available in this version " - "of asyncio or uvloop." - ) - - def __await__(self): - """ - Starts the asyncio server, returns AsyncServerCoro - """ - task = asyncio.ensure_future(self.serve_coro) - while not task.done(): - yield - self.server = task.result() - return self - - -def serve( - host, - port, - app, - before_start: Optional[Iterable[ListenerType]] = None, - after_start: Optional[Iterable[ListenerType]] = None, - before_stop: Optional[Iterable[ListenerType]] = None, - after_stop: Optional[Iterable[ListenerType]] = None, - ssl: Optional[SSLContext] = None, - sock: Optional[socket.socket] = None, - unix: Optional[str] = None, - reuse_port: bool = False, - loop=None, - protocol: Type[asyncio.Protocol] = HttpProtocol, - backlog: int = 100, - register_sys_signals: bool = True, - run_multiple: bool = False, - run_async: bool = False, - connections=None, - signal=Signal(), - state=None, - asyncio_server_kwargs=None, -): - """Start asynchronous HTTP Server on an individual process. - - :param host: Address to host on - :param port: Port to host on - :param before_start: function to be executed before the server starts - listening. Takes arguments `app` instance and `loop` - :param after_start: function to be executed after the server starts - listening. Takes arguments `app` instance and `loop` - :param before_stop: function to be executed when a stop signal is - received before it is respected. Takes arguments - `app` instance and `loop` - :param after_stop: function to be executed when a stop signal is - received after it is respected. Takes arguments - `app` instance and `loop` - :param ssl: SSLContext - :param sock: Socket for the server to accept connections from - :param unix: Unix socket to listen on instead of TCP port - :param reuse_port: `True` for multiple workers - :param loop: asyncio compatible event loop - :param run_async: bool: Do not create a new event loop for the server, - and return an AsyncServer object rather than running it - :param asyncio_server_kwargs: key-value args for asyncio/uvloop - create_server method - :return: Nothing - """ - if not run_async and not loop: - # create new event_loop after fork - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if app.debug: - loop.set_debug(app.debug) - - app.asgi = False - - connections = connections if connections is not None else set() - protocol_kwargs = _build_protocol_kwargs(protocol, app.config) - server = partial( - protocol, - loop=loop, - connections=connections, - signal=signal, - app=app, - state=state, - unix=unix, - **protocol_kwargs, - ) - asyncio_server_kwargs = ( - asyncio_server_kwargs if asyncio_server_kwargs else {} - ) - # UNIX sockets are always bound by us (to preserve semantics between modes) - if unix: - sock = bind_unix_socket(unix, backlog=backlog) - server_coroutine = loop.create_server( - server, - None if sock else host, - None if sock else port, - ssl=ssl, - reuse_port=reuse_port, - sock=sock, - backlog=backlog, - **asyncio_server_kwargs, - ) - - if run_async: - return AsyncioServer( - loop=loop, - serve_coro=server_coroutine, - connections=connections, - after_start=after_start, - before_stop=before_stop, - after_stop=after_stop, - ) - - trigger_events(before_start, loop) - - try: - http_server = loop.run_until_complete(server_coroutine) - except BaseException: - error_logger.exception("Unable to start server") - return - - trigger_events(after_start, loop) - - # Ignore SIGINT when run_multiple - if run_multiple: - signal_func(SIGINT, SIG_IGN) - - # Register signals for graceful termination - if register_sys_signals: - if OS_IS_WINDOWS: - ctrlc_workaround_for_windows(app) - else: - for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: - loop.add_signal_handler(_signal, app.stop) - pid = os.getpid() - try: - logger.info("Starting worker [%s]", pid) - loop.run_forever() - finally: - logger.info("Stopping worker [%s]", pid) - - # Run the on_stop function if provided - trigger_events(before_stop, loop) - - # Wait for event loop to finish and all connections to drain - http_server.close() - loop.run_until_complete(http_server.wait_closed()) - - # Complete all tasks on the loop - signal.stopped = True - for connection in connections: - connection.close_if_idle() - - # Gracefully shutdown timeout. - # We should provide graceful_shutdown_timeout, - # instead of letting connection hangs forever. - # Let's roughly calcucate time. - graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT - start_shutdown: float = 0 - while connections and (start_shutdown < graceful): - loop.run_until_complete(asyncio.sleep(0.1)) - start_shutdown = start_shutdown + 0.1 - - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - coros = [] - for conn in connections: - if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) - else: - conn.close() - - _shutdown = asyncio.gather(*coros) - loop.run_until_complete(_shutdown) - - trigger_events(after_stop, loop) - - remove_unix_socket(unix) - - -def _build_protocol_kwargs( - protocol: Type[asyncio.Protocol], config: Config -) -> Dict[str, Union[int, float]]: - if hasattr(protocol, "websocket_handshake"): - return { - "websocket_max_size": config.WEBSOCKET_MAX_SIZE, - "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, - "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, - "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, - "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, - "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, - } - return {} - - -def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: - """Create TCP server socket. - :param host: IPv4, IPv6 or hostname may be specified - :param port: TCP port number - :param backlog: Maximum number of connections to queue - :return: socket.socket object - """ - try: # IP address: family must be specified for IPv6 at least - ip = ip_address(host) - host = str(ip) - sock = socket.socket( - socket.AF_INET6 if ip.version == 6 else socket.AF_INET - ) - except ValueError: # Hostname, may become AF_INET or AF_INET6 - sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((host, port)) - sock.listen(backlog) - return sock - - -def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: - """Create unix socket. - :param path: filesystem path - :param backlog: Maximum number of connections to queue - :return: socket.socket object - """ - """Open or atomically replace existing socket with zero downtime.""" - # Sanitise and pre-verify socket path - path = os.path.abspath(path) - folder = os.path.dirname(path) - if not os.path.isdir(folder): - raise FileNotFoundError(f"Socket folder does not exist: {folder}") - try: - if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): - raise FileExistsError(f"Existing file is not a socket: {path}") - except FileNotFoundError: - pass - # Create new socket with a random temporary name - tmp_path = f"{path}.{secrets.token_urlsafe()}" - sock = socket.socket(socket.AF_UNIX) - try: - # Critical section begins (filename races) - sock.bind(tmp_path) - try: - os.chmod(tmp_path, mode) - # Start listening before rename to avoid connection failures - sock.listen(backlog) - os.rename(tmp_path, path) - except: # noqa: E722 - try: - os.unlink(tmp_path) - finally: - raise - except: # noqa: E722 - try: - sock.close() - finally: - raise - return sock - - -def remove_unix_socket(path: Optional[str]) -> None: - """Remove dead unix socket during server exit.""" - if not path: - return - try: - if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): - # Is it actually dead (doesn't belong to a new server instance)? - with socket.socket(socket.AF_UNIX) as testsock: - try: - testsock.connect(path) - except ConnectionRefusedError: - os.unlink(path) - except FileNotFoundError: - pass - - -def serve_single(server_settings): - main_start = server_settings.pop("main_start", None) - main_stop = server_settings.pop("main_stop", None) - - if not server_settings.get("run_async"): - # create new event_loop after fork - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - server_settings["loop"] = loop - - trigger_events(main_start, server_settings["loop"]) - serve(**server_settings) - trigger_events(main_stop, server_settings["loop"]) - - server_settings["loop"].close() - - -def serve_multiple(server_settings, workers): - """Start multiple server processes simultaneously. Stop on interrupt - and terminate signals, and drain connections when complete. - - :param server_settings: kw arguments to be passed to the serve function - :param workers: number of workers to launch - :param stop_event: if provided, is used as a stop signal - :return: - """ - server_settings["reuse_port"] = True - server_settings["run_multiple"] = True - - main_start = server_settings.pop("main_start", None) - main_stop = server_settings.pop("main_stop", None) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - trigger_events(main_start, loop) - - # Create a listening socket or use the one in settings - sock = server_settings.get("sock") - unix = server_settings["unix"] - backlog = server_settings["backlog"] - if unix: - sock = bind_unix_socket(unix, backlog=backlog) - server_settings["unix"] = unix - if sock is None: - sock = bind_socket( - server_settings["host"], server_settings["port"], backlog=backlog - ) - sock.set_inheritable(True) - server_settings["sock"] = sock - server_settings["host"] = None - server_settings["port"] = None - - processes = [] - - def sig_handler(signal, frame): - logger.info("Received signal %s. Shutting down.", Signals(signal).name) - for process in processes: - os.kill(process.pid, SIGTERM) - - signal_func(SIGINT, lambda s, f: sig_handler(s, f)) - signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) - mp = multiprocessing.get_context("fork") - - for _ in range(workers): - process = mp.Process(target=serve, kwargs=server_settings) - process.daemon = True - process.start() - processes.append(process) - - for process in processes: - process.join() - - # the above processes will block this until they're stopped - for process in processes: - process.terminate() - - trigger_events(main_stop, loop) - - sock.close() - loop.close() - remove_unix_socket(unix) diff --git a/sanic/server/__init__.py b/sanic/server/__init__.py new file mode 100644 index 00000000..8e26dcd0 --- /dev/null +++ b/sanic/server/__init__.py @@ -0,0 +1,26 @@ +import asyncio + +from sanic.models.server_types import ConnInfo, Signal +from sanic.server.async_server import AsyncioServer +from sanic.server.protocols.http_protocol import HttpProtocol +from sanic.server.runners import serve, serve_multiple, serve_single + + +try: + import uvloop # type: ignore + + if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +__all__ = ( + "AsyncioServer", + "ConnInfo", + "HttpProtocol", + "Signal", + "serve", + "serve_multiple", + "serve_single", +) diff --git a/sanic/server/async_server.py b/sanic/server/async_server.py new file mode 100644 index 00000000..33b8b4c0 --- /dev/null +++ b/sanic/server/async_server.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import asyncio + +from sanic.exceptions import SanicException + + +class AsyncioServer: + """ + Wraps an asyncio server with functionality that might be useful to + a user who needs to manage the server lifecycle manually. + """ + + __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init") + + def __init__( + self, + app, + loop, + serve_coro, + connections, + ): + # Note, Sanic already called "before_server_start" events + # before this helper was even created. So we don't need it here. + self.app = app + self.connections = connections + self.loop = loop + self.serve_coro = serve_coro + self.server = None + self.init = False + + def startup(self): + """ + Trigger "before_server_start" events + """ + self.init = True + return self.app._startup() + + def before_start(self): + """ + Trigger "before_server_start" events + """ + return self._server_event("init", "before") + + def after_start(self): + """ + Trigger "after_server_start" events + """ + return self._server_event("init", "after") + + def before_stop(self): + """ + Trigger "before_server_stop" events + """ + return self._server_event("shutdown", "before") + + def after_stop(self): + """ + Trigger "after_server_stop" events + """ + return self._server_event("shutdown", "after") + + def is_serving(self) -> bool: + if self.server: + return self.server.is_serving() + return False + + def wait_closed(self): + if self.server: + return self.server.wait_closed() + + def close(self): + if self.server: + self.server.close() + coro = self.wait_closed() + task = asyncio.ensure_future(coro, loop=self.loop) + return task + + def start_serving(self): + if self.server: + try: + return self.server.start_serving() + except AttributeError: + raise NotImplementedError( + "server.start_serving not available in this version " + "of asyncio or uvloop." + ) + + def serve_forever(self): + if self.server: + try: + return self.server.serve_forever() + except AttributeError: + raise NotImplementedError( + "server.serve_forever not available in this version " + "of asyncio or uvloop." + ) + + def _server_event(self, concern: str, action: str): + if not self.init: + raise SanicException( + "Cannot dispatch server event without " + "first running server.startup()" + ) + return self.app._server_event(concern, action, loop=self.loop) + + def __await__(self): + """ + Starts the asyncio server, returns AsyncServerCoro + """ + task = asyncio.ensure_future(self.serve_coro) + while not task.done(): + yield + self.server = task.result() + return self diff --git a/sanic/server/events.py b/sanic/server/events.py new file mode 100644 index 00000000..3b71281d --- /dev/null +++ b/sanic/server/events.py @@ -0,0 +1,16 @@ +from inspect import isawaitable +from typing import Any, Callable, Iterable, Optional + + +def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): + """ + Trigger event callbacks (functions or async) + + :param events: one or more sync or async functions to execute + :param loop: event loop + """ + if events: + for event in events: + result = event(loop) + if isawaitable(result): + loop.run_until_complete(result) diff --git a/sanic/server/protocols/__init__.py b/sanic/server/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sanic/server/protocols/base_protocol.py b/sanic/server/protocols/base_protocol.py new file mode 100644 index 00000000..63d4bfb5 --- /dev/null +++ b/sanic/server/protocols/base_protocol.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + + +if TYPE_CHECKING: + from sanic.app import Sanic + +import asyncio + +from asyncio import CancelledError +from asyncio.transports import Transport +from time import monotonic as current_time + +from sanic.log import error_logger +from sanic.models.server_types import ConnInfo, Signal + + +class SanicProtocol(asyncio.Protocol): + __slots__ = ( + "app", + # event loop, connection + "loop", + "transport", + "connections", + "conn_info", + "signal", + "_can_write", + "_time", + "_task", + "_unix", + "_data_received", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + unix=None, + **kwargs, + ): + asyncio.set_event_loop(loop) + self.loop = loop + self.app: Sanic = app + self.signal = signal or Signal() + self.transport: Optional[Transport] = None + self.connections = connections if connections is not None else set() + self.conn_info: Optional[ConnInfo] = None + self._can_write = asyncio.Event() + self._can_write.set() + self._unix = unix + self._time = 0.0 # type: float + self._task = None # type: Optional[asyncio.Task] + self._data_received = asyncio.Event() + + @property + def ctx(self): + if self.conn_info is not None: + return self.conn_info.ctx + else: + return None + + async def send(self, data): + """ + Generic data write implementation with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + self.transport.write(data) + self._time = current_time() + + async def receive_more(self): + """ + Wait until more data is received into the Server protocol's buffer + """ + self.transport.resume_reading() + self._data_received.clear() + await self._data_received.wait() + + def close(self, timeout: Optional[float] = None): + """ + Attempt close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.close() + if timeout is None: + timeout = self.app.config.GRACEFUL_SHUTDOWN_TIMEOUT + self.loop.call_later(timeout, self.abort) + + def abort(self): + """ + Force close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.abort() + self.transport = None + + # asyncio.Protocol API Callbacks # + # ------------------------------ # + def connection_made(self, transport): + """ + Generic connection-made, with no connection_task, and no recv_buffer. + Override this for protocol-specific connection implementations. + """ + try: + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def connection_lost(self, exc): + try: + self.connections.discard(self) + self.resume_writing() + if self._task: + self._task.cancel() + except BaseException: + error_logger.exception("protocol.connection_lost") + + def pause_writing(self): + self._can_write.clear() + + def resume_writing(self): + self._can_write.set() + + def data_received(self, data: bytes): + try: + self._time = current_time() + if not data: + return self.close() + + if self._data_received: + self._data_received.set() + except BaseException: + error_logger.exception("protocol.data_received") diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py new file mode 100644 index 00000000..409f5e4b --- /dev/null +++ b/sanic/server/protocols/http_protocol.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from sanic.touchup.meta import TouchUpMeta + + +if TYPE_CHECKING: + from sanic.app import Sanic + +from asyncio import CancelledError +from time import monotonic as current_time + +from sanic.exceptions import RequestTimeout, ServiceUnavailable +from sanic.http import Http, Stage +from sanic.log import error_logger, logger +from sanic.models.server_types import ConnInfo +from sanic.request import Request +from sanic.server.protocols.base_protocol import SanicProtocol + + +class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): + """ + This class provides implements the HTTP 1.1 protocol on top of our + Sanic Server transport + """ + + __touchup__ = ( + "send", + "connection_task", + ) + __slots__ = ( + # request params + "request", + # request config + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_class", + "error_handler", + # enable or disable access log purpose + "access_log", + # connection management + "state", + "url", + "_handler_task", + "_http", + "_exception", + "recv_buffer", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + state=None, + unix=None, + **kwargs, + ): + super().__init__( + loop=loop, + app=app, + signal=signal, + connections=connections, + unix=unix, + ) + self.url = None + self.request: Optional[Request] = None + self.access_log = self.app.config.ACCESS_LOG + self.request_handler = self.app.handle_request + self.error_handler = self.app.error_handler + self.request_timeout = self.app.config.REQUEST_TIMEOUT + self.response_timeout = self.app.config.RESPONSE_TIMEOUT + self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT + self.request_max_size = self.app.config.REQUEST_MAX_SIZE + self.request_class = self.app.request_class or Request + self.state = state if state else {} + if "requests_count" not in self.state: + self.state["requests_count"] = 0 + self._exception = None + + def _setup_connection(self): + self._http = Http(self) + self._time = current_time() + self.check_timeouts() + + async def connection_task(self): # no cov + """ + Run a HTTP connection. + + Timeouts and some additional error handling occur here, while most of + everything else happens in class Http or in code called from there. + """ + try: + self._setup_connection() + await self.app.dispatch( + "http.lifecycle.begin", + inline=True, + context={"conn_info": self.conn_info}, + ) + await self._http.http1() + except CancelledError: + pass + except Exception: + error_logger.exception("protocol.connection_task uncaught") + finally: + if ( + self.app.debug + and self._http + and self.transport + and not self._http.upgrade_websocket + ): + ip = self.transport.get_extra_info("peername") + error_logger.error( + "Connection lost before response written" + f" @ {ip} {self._http.request}" + ) + self._http = None + self._task = None + try: + self.close() + except BaseException: + error_logger.exception("Closing failed") + finally: + await self.app.dispatch( + "http.lifecycle.complete", + inline=True, + context={"conn_info": self.conn_info}, + ) + # Important to keep this Ellipsis here for the TouchUp module + ... + + def check_timeouts(self): + """ + Runs itself periodically to enforce any expired timeouts. + """ + try: + if not self._task: + return + duration = current_time() - self._time + stage = self._http.stage + if stage is Stage.IDLE and duration > self.keep_alive_timeout: + logger.debug("KeepAlive Timeout. Closing connection.") + elif stage is Stage.REQUEST and duration > self.request_timeout: + logger.debug("Request Timeout. Closing connection.") + self._http.exception = RequestTimeout("Request Timeout") + elif stage is Stage.HANDLER and self._http.upgrade_websocket: + logger.debug("Handling websocket. Timeouts disabled.") + return + elif ( + stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) + and duration > self.response_timeout + ): + logger.debug("Response Timeout. Closing connection.") + self._http.exception = ServiceUnavailable("Response Timeout") + else: + interval = ( + min( + self.keep_alive_timeout, + self.request_timeout, + self.response_timeout, + ) + / 2 + ) + self.loop.call_later(max(0.1, interval), self.check_timeouts) + return + self._task.cancel() + except Exception: + error_logger.exception("protocol.check_timeouts") + + async def send(self, data): # no cov + """ + Writes HTTP data with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + await self.app.dispatch( + "http.lifecycle.send", + inline=True, + context={"data": data}, + ) + self.transport.write(data) + self._time = current_time() + + def close_if_idle(self) -> bool: + """ + Close the connection if a request is not being sent or received + + :return: boolean - True if closed, false if staying open + """ + if self._http is None or self._http.stage is Stage.IDLE: + self.close() + return True + return False + + # -------------------------------------------- # + # Only asyncio.Protocol callbacks below this + # -------------------------------------------- # + + def connection_made(self, transport): + """ + HTTP-protocol-specific new connection handler + """ + try: + # TODO: Benchmark to find suitable write buffer limits + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self._task = self.loop.create_task(self.connection_task()) + self.recv_buffer = bytearray() + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def data_received(self, data: bytes): + + try: + self._time = current_time() + if not data: + return self.close() + self.recv_buffer += data + + if ( + len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE + and self.transport + ): + self.transport.pause_reading() + + if self._data_received: + self._data_received.set() + except Exception: + error_logger.exception("protocol.data_received") diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py new file mode 100644 index 00000000..457f1cd0 --- /dev/null +++ b/sanic/server/protocols/websocket_protocol.py @@ -0,0 +1,164 @@ +from typing import TYPE_CHECKING, Optional, Sequence + +from websockets.connection import CLOSED, CLOSING, OPEN +from websockets.server import ServerConnection + +from sanic.exceptions import ServerError +from sanic.log import error_logger +from sanic.server import HttpProtocol + +from ..websockets.impl import WebsocketImplProtocol + + +if TYPE_CHECKING: + from websockets import http11 + + +class WebSocketProtocol(HttpProtocol): + + websocket: Optional[WebsocketImplProtocol] + websocket_timeout: float + websocket_max_size = Optional[int] + websocket_ping_interval = Optional[float] + websocket_ping_timeout = Optional[float] + + def __init__( + self, + *args, + websocket_timeout: float = 10.0, + websocket_max_size: Optional[int] = None, + websocket_max_queue: Optional[int] = None, # max_queue is deprecated + websocket_read_limit: Optional[int] = None, # read_limit is deprecated + websocket_write_limit: Optional[int] = None, # write_limit deprecated + websocket_ping_interval: Optional[float] = 20.0, + websocket_ping_timeout: Optional[float] = 20.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.websocket = None + self.websocket_timeout = websocket_timeout + self.websocket_max_size = websocket_max_size + if websocket_max_queue is not None and websocket_max_queue > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses queueing, so websocket_max_queue" + " is no longer required." + ) + ) + if websocket_read_limit is not None and websocket_read_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses read buffers, so " + "websocket_read_limit is not required." + ) + ) + if websocket_write_limit is not None and websocket_write_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses write buffers, so " + "websocket_write_limit is not required." + ) + ) + self.websocket_ping_interval = websocket_ping_interval + self.websocket_ping_timeout = websocket_ping_timeout + + def connection_lost(self, exc): + if self.websocket is not None: + self.websocket.connection_lost(exc) + super().connection_lost(exc) + + def data_received(self, data): + if self.websocket is not None: + self.websocket.data_received(data) + else: + # Pass it to HttpProtocol handler first + # That will (hopefully) upgrade it to a websocket. + super().data_received(data) + + def eof_received(self) -> Optional[bool]: + if self.websocket is not None: + return self.websocket.eof_received() + else: + return False + + def close(self, timeout: Optional[float] = None): + # Called by HttpProtocol at the end of connection_task + # If we've upgraded to websocket, we do our own closing + if self.websocket is not None: + # Note, we don't want to use websocket.close() + # That is used for user's application code to send a + # websocket close packet. This is different. + self.websocket.end_connection(1001) + else: + super().close() + + def close_if_idle(self): + # Called by Sanic Server when shutting down + # If we've upgraded to websocket, shut it down + if self.websocket is not None: + if self.websocket.connection.state in (CLOSING, CLOSED): + return True + elif self.websocket.loop is not None: + self.websocket.loop.create_task(self.websocket.close(1001)) + else: + self.websocket.end_connection(1001) + else: + return super().close_if_idle() + + async def websocket_handshake( + self, request, subprotocols=Optional[Sequence[str]] + ): + # let the websockets package do the handshake with the client + try: + if subprotocols is not None: + # subprotocols can be a set or frozenset, + # but ServerConnection needs a list + subprotocols = list(subprotocols) + ws_conn = ServerConnection( + max_size=self.websocket_max_size, + subprotocols=subprotocols, + state=OPEN, + logger=error_logger, + ) + resp: "http11.Response" = ws_conn.accept(request) + except Exception: + msg = ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise ServerError(msg, status_code=500) + if 100 <= resp.status_code <= 299: + rbody = "".join( + [ + "HTTP/1.1 ", + str(resp.status_code), + " ", + resp.reason_phrase, + "\r\n", + ] + ) + rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + if resp.body is not None: + rbody += f"\r\n{resp.body}\r\n\r\n" + else: + rbody += "\r\n" + await super().send(rbody.encode()) + else: + raise ServerError(resp.body, resp.status_code) + self.websocket = WebsocketImplProtocol( + ws_conn, + ping_interval=self.websocket_ping_interval, + ping_timeout=self.websocket_ping_timeout, + close_timeout=self.websocket_timeout, + ) + loop = ( + request.transport.loop + if hasattr(request, "transport") + and hasattr(request.transport, "loop") + else None + ) + await self.websocket.connection_made(self, loop=loop) + return self.websocket diff --git a/sanic/server/runners.py b/sanic/server/runners.py new file mode 100644 index 00000000..f0bebb03 --- /dev/null +++ b/sanic/server/runners.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from ssl import SSLContext +from typing import TYPE_CHECKING, Dict, Optional, Type, Union + +from sanic.config import Config +from sanic.server.events import trigger_events + + +if TYPE_CHECKING: + from sanic.app import Sanic + +import asyncio +import multiprocessing +import os +import socket + +from functools import partial +from signal import SIG_IGN, SIGINT, SIGTERM, Signals +from signal import signal as signal_func + +from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows +from sanic.log import error_logger, logger +from sanic.models.server_types import Signal +from sanic.server.async_server import AsyncioServer +from sanic.server.protocols.http_protocol import HttpProtocol +from sanic.server.socket import ( + bind_socket, + bind_unix_socket, + remove_unix_socket, +) + + +def serve( + host, + port, + app: Sanic, + ssl: Optional[SSLContext] = None, + sock: Optional[socket.socket] = None, + unix: Optional[str] = None, + reuse_port: bool = False, + loop=None, + protocol: Type[asyncio.Protocol] = HttpProtocol, + backlog: int = 100, + register_sys_signals: bool = True, + run_multiple: bool = False, + run_async: bool = False, + connections=None, + signal=Signal(), + state=None, + asyncio_server_kwargs=None, +): + """Start asynchronous HTTP Server on an individual process. + + :param host: Address to host on + :param port: Port to host on + :param before_start: function to be executed before the server starts + listening. Takes arguments `app` instance and `loop` + :param after_start: function to be executed after the server starts + listening. Takes arguments `app` instance and `loop` + :param before_stop: function to be executed when a stop signal is + received before it is respected. Takes arguments + `app` instance and `loop` + :param after_stop: function to be executed when a stop signal is + received after it is respected. Takes arguments + `app` instance and `loop` + :param ssl: SSLContext + :param sock: Socket for the server to accept connections from + :param unix: Unix socket to listen on instead of TCP port + :param reuse_port: `True` for multiple workers + :param loop: asyncio compatible event loop + :param run_async: bool: Do not create a new event loop for the server, + and return an AsyncServer object rather than running it + :param asyncio_server_kwargs: key-value args for asyncio/uvloop + create_server method + :return: Nothing + """ + if not run_async and not loop: + # create new event_loop after fork + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if app.debug: + loop.set_debug(app.debug) + + app.asgi = False + + connections = connections if connections is not None else set() + protocol_kwargs = _build_protocol_kwargs(protocol, app.config) + server = partial( + protocol, + loop=loop, + connections=connections, + signal=signal, + app=app, + state=state, + unix=unix, + **protocol_kwargs, + ) + asyncio_server_kwargs = ( + asyncio_server_kwargs if asyncio_server_kwargs else {} + ) + # UNIX sockets are always bound by us (to preserve semantics between modes) + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_coroutine = loop.create_server( + server, + None if sock else host, + None if sock else port, + ssl=ssl, + reuse_port=reuse_port, + sock=sock, + backlog=backlog, + **asyncio_server_kwargs, + ) + + if run_async: + return AsyncioServer( + app=app, + loop=loop, + serve_coro=server_coroutine, + connections=connections, + ) + + loop.run_until_complete(app._startup()) + loop.run_until_complete(app._server_event("init", "before")) + + try: + http_server = loop.run_until_complete(server_coroutine) + except BaseException: + error_logger.exception("Unable to start server") + return + + # Ignore SIGINT when run_multiple + if run_multiple: + signal_func(SIGINT, SIG_IGN) + + # Register signals for graceful termination + if register_sys_signals: + if OS_IS_WINDOWS: + ctrlc_workaround_for_windows(app) + else: + for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: + loop.add_signal_handler(_signal, app.stop) + + loop.run_until_complete(app._server_event("init", "after")) + pid = os.getpid() + try: + logger.info("Starting worker [%s]", pid) + loop.run_forever() + finally: + logger.info("Stopping worker [%s]", pid) + + # Run the on_stop function if provided + loop.run_until_complete(app._server_event("shutdown", "before")) + + # Wait for event loop to finish and all connections to drain + http_server.close() + loop.run_until_complete(http_server.wait_closed()) + + # Complete all tasks on the loop + signal.stopped = True + for connection in connections: + connection.close_if_idle() + + # Gracefully shutdown timeout. + # We should provide graceful_shutdown_timeout, + # instead of letting connection hangs forever. + # Let's roughly calcucate time. + graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT + start_shutdown: float = 0 + while connections and (start_shutdown < graceful): + loop.run_until_complete(asyncio.sleep(0.1)) + start_shutdown = start_shutdown + 0.1 + + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + for conn in connections: + if hasattr(conn, "websocket") and conn.websocket: + conn.websocket.fail_connection(code=1001) + else: + conn.abort() + loop.run_until_complete(app._server_event("shutdown", "after")) + + remove_unix_socket(unix) + + +def serve_single(server_settings): + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + + if not server_settings.get("run_async"): + # create new event_loop after fork + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + server_settings["loop"] = loop + + trigger_events(main_start, server_settings["loop"]) + serve(**server_settings) + trigger_events(main_stop, server_settings["loop"]) + + server_settings["loop"].close() + + +def serve_multiple(server_settings, workers): + """Start multiple server processes simultaneously. Stop on interrupt + and terminate signals, and drain connections when complete. + + :param server_settings: kw arguments to be passed to the serve function + :param workers: number of workers to launch + :param stop_event: if provided, is used as a stop signal + :return: + """ + server_settings["reuse_port"] = True + server_settings["run_multiple"] = True + + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + trigger_events(main_start, loop) + + # Create a listening socket or use the one in settings + sock = server_settings.get("sock") + unix = server_settings["unix"] + backlog = server_settings["backlog"] + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_settings["unix"] = unix + if sock is None: + sock = bind_socket( + server_settings["host"], server_settings["port"], backlog=backlog + ) + sock.set_inheritable(True) + server_settings["sock"] = sock + server_settings["host"] = None + server_settings["port"] = None + + processes = [] + + def sig_handler(signal, frame): + logger.info("Received signal %s. Shutting down.", Signals(signal).name) + for process in processes: + os.kill(process.pid, SIGTERM) + + signal_func(SIGINT, lambda s, f: sig_handler(s, f)) + signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) + mp = multiprocessing.get_context("fork") + + for _ in range(workers): + process = mp.Process(target=serve, kwargs=server_settings) + process.daemon = True + process.start() + processes.append(process) + + for process in processes: + process.join() + + # the above processes will block this until they're stopped + for process in processes: + process.terminate() + + trigger_events(main_stop, loop) + + sock.close() + loop.close() + remove_unix_socket(unix) + + +def _build_protocol_kwargs( + protocol: Type[asyncio.Protocol], config: Config +) -> Dict[str, Union[int, float]]: + if hasattr(protocol, "websocket_handshake"): + return { + "websocket_max_size": config.WEBSOCKET_MAX_SIZE, + "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, + "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, + } + return {} diff --git a/sanic/server/socket.py b/sanic/server/socket.py new file mode 100644 index 00000000..3d908306 --- /dev/null +++ b/sanic/server/socket.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import os +import secrets +import socket +import stat + +from ipaddress import ip_address +from typing import Optional + + +def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: + """Create TCP server socket. + :param host: IPv4, IPv6 or hostname may be specified + :param port: TCP port number + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + try: # IP address: family must be specified for IPv6 at least + ip = ip_address(host) + host = str(ip) + sock = socket.socket( + socket.AF_INET6 if ip.version == 6 else socket.AF_INET + ) + except ValueError: # Hostname, may become AF_INET or AF_INET6 + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, port)) + sock.listen(backlog) + return sock + + +def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: + """Create unix socket. + :param path: filesystem path + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + """Open or atomically replace existing socket with zero downtime.""" + # Sanitise and pre-verify socket path + path = os.path.abspath(path) + folder = os.path.dirname(path) + if not os.path.isdir(folder): + raise FileNotFoundError(f"Socket folder does not exist: {folder}") + try: + if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + raise FileExistsError(f"Existing file is not a socket: {path}") + except FileNotFoundError: + pass + # Create new socket with a random temporary name + tmp_path = f"{path}.{secrets.token_urlsafe()}" + sock = socket.socket(socket.AF_UNIX) + try: + # Critical section begins (filename races) + sock.bind(tmp_path) + try: + os.chmod(tmp_path, mode) + # Start listening before rename to avoid connection failures + sock.listen(backlog) + os.rename(tmp_path, path) + except: # noqa: E722 + try: + os.unlink(tmp_path) + finally: + raise + except: # noqa: E722 + try: + sock.close() + finally: + raise + return sock + + +def remove_unix_socket(path: Optional[str]) -> None: + """Remove dead unix socket during server exit.""" + if not path: + return + try: + if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + # Is it actually dead (doesn't belong to a new server instance)? + with socket.socket(socket.AF_UNIX) as testsock: + try: + testsock.connect(path) + except ConnectionRefusedError: + os.unlink(path) + except FileNotFoundError: + pass diff --git a/sanic/server/websockets/__init__.py b/sanic/server/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py new file mode 100644 index 00000000..c53a65a5 --- /dev/null +++ b/sanic/server/websockets/connection.py @@ -0,0 +1,82 @@ +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + MutableMapping, + Optional, + Union, +) + + +ASIMessage = MutableMapping[str, Any] + + +class WebSocketConnection: + """ + This is for ASGI Connections. + It provides an interface similar to WebsocketProtocol, but + sends/receives over an ASGI connection. + """ + + # TODO + # - Implement ping/pong + + def __init__( + self, + send: Callable[[ASIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: Optional[List[str]] = None, + ) -> None: + self._send = send + self._receive = receive + self._subprotocols = subprotocols or [] + + async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: + message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} + + if isinstance(data, bytes): + message.update({"bytes": data}) + else: + message.update({"text": str(data)}) + + await self._send(message) + + async def recv(self, *args, **kwargs) -> Optional[str]: + message = await self._receive() + + if message["type"] == "websocket.receive": + return message["text"] + elif message["type"] == "websocket.disconnect": + pass + + return None + + receive = recv + + async def accept(self, subprotocols: Optional[List[str]] = None) -> None: + subprotocol = None + if subprotocols: + for subp in subprotocols: + if subp in self.subprotocols: + subprotocol = subp + break + + await self._send( + { + "type": "websocket.accept", + "subprotocol": subprotocol, + } + ) + + async def close(self, code: int = 1000, reason: str = "") -> None: + pass + + @property + def subprotocols(self): + return self._subprotocols + + @subprotocols.setter + def subprotocols(self, subprotocols: Optional[List[str]] = None): + self._subprotocols = subprotocols or [] diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py new file mode 100644 index 00000000..fef27db1 --- /dev/null +++ b/sanic/server/websockets/frame.py @@ -0,0 +1,294 @@ +import asyncio +import codecs + +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from websockets.frames import Frame, Opcode +from websockets.typing import Data + +from sanic.exceptions import ServerError + + +if TYPE_CHECKING: + from .impl import WebsocketImplProtocol + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class WebsocketFrameAssembler: + """ + Assemble a message from frames. + Code borrowed from aaugustin/websockets project: + https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py + """ + + __slots__ = ( + "protocol", + "read_mutex", + "write_mutex", + "message_complete", + "message_fetched", + "get_in_progress", + "decoder", + "completed_queue", + "chunks", + "chunks_queue", + "paused", + "get_id", + "put_id", + ) + if TYPE_CHECKING: + protocol: "WebsocketImplProtocol" + read_mutex: asyncio.Lock + write_mutex: asyncio.Lock + message_complete: asyncio.Event + message_fetched: asyncio.Event + completed_queue: asyncio.Queue + get_in_progress: bool + decoder: Optional[codecs.IncrementalDecoder] + # For streaming chunks rather than messages: + chunks: List[Data] + chunks_queue: Optional[asyncio.Queue[Optional[Data]]] + paused: bool + + def __init__(self, protocol) -> None: + + self.protocol = protocol + + self.read_mutex = asyncio.Lock() + self.write_mutex = asyncio.Lock() + + self.completed_queue = asyncio.Queue( + maxsize=1 + ) # type: asyncio.Queue[Data] + + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = asyncio.Event() + # get() sets this event to let put() + self.message_fetched = asyncio.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder = None + + # Buffer data from frames belonging to the same message. + self.chunks = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a Queue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue = None + + # Flag to indicate we've paused the protocol + self.paused = False + + async def get(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Read the next message. + :meth:`get` returns a single :class:`str` or :class:`bytes`. + If the :message was fragmented, :meth:`get` waits until the last frame + is received, then it reassembles the message. + If ``timeout`` is set and elapses before a complete message is + received, :meth:`get` returns ``None``. + """ + async with self.read_mutex: + if timeout is not None and timeout <= 0: + if not self.message_complete.is_set(): + return None + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one task can get here. + if timeout is None: + completed = await self.message_complete.wait() + elif timeout <= 0: + completed = self.message_complete.is_set() + else: + try: + await asyncio.wait_for( + self.message_complete.wait(), timeout=timeout + ) + except asyncio.TimeoutError: + ... + finally: + completed = self.message_complete.is_set() + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + return None + if not self.message_complete.is_set(): + return None + + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is here + # as a failsafe + raise ServerError( + "Websocket get() found a message when " + "state was already fetched." + ) + self.message_fetched.set() + self.chunks = [] + # this should already be None, but set it here for safety + self.chunks_queue = None + return message + + async def get_iter(self) -> AsyncIterator[Data]: + """ + Stream the next message. + Iterating the return value of :meth:`get_iter` yields a :class:`str` + or :class:`bytes` for each frame in the message. + """ + async with self.read_mutex: + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + await self.chunks_queue.put(None) + + # Locking with get_in_progress ensures only one task can get here + for c in chunks: + yield c + while True: + chunk = await self.chunks_queue.get() + if chunk is None: + break + yield chunk + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + if not self.message_complete.is_set(): + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "Websocket frame assembler chunks queue ended before " + "message was complete." + ) + self.message_complete.clear() + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is + # here as a failsafe + raise ServerError( + "Websocket get_iter() found a message when state was " + "already fetched." + ) + + self.message_fetched.set() + # this should already be empty, but set it here for safety + self.chunks = [] + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + When ``frame`` is the final frame in a message, :meth:`put` waits + until the message is fetched, either by calling :meth:`get` or by + iterating the return value of :meth:`get_iter`. + :meth:`put` assumes that the stream of frames respects the protocol. + If it doesn't, the behavior is undefined. + """ + + async with self.write_mutex: + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + if self.chunks_queue is None: + self.chunks.append(data) + else: + await self.chunks_queue.put(data) + + if not frame.fin: + return + if not self.get_in_progress: + # nobody is waiting for this frame, so try to pause subsequent + # frames at the protocol level + self.paused = self.protocol.pause_frames() + # Message is complete. Wait until it's fetched to return. + + if self.chunks_queue is not None: + await self.chunks_queue.put(None) + if self.message_complete.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when a message was " + "already in its chamber." + ) + self.message_complete.set() # Signal to get() it can serve the + if self.message_fetched.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when the previous " + "message was not yet fetched." + ) + + # Allow get() to run and eventually set the event. + await self.message_fetched.wait() + self.message_fetched.clear() + self.decoder = None diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py new file mode 100644 index 00000000..ed0d7fed --- /dev/null +++ b/sanic/server/websockets/impl.py @@ -0,0 +1,834 @@ +import asyncio +import random +import struct + +from typing import ( + AsyncIterator, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) + +from websockets.connection import CLOSED, CLOSING, OPEN, Event +from websockets.exceptions import ConnectionClosed, ConnectionClosedError +from websockets.frames import Frame, Opcode +from websockets.server import ServerConnection +from websockets.typing import Data + +from sanic.log import error_logger, logger +from sanic.server.protocols.base_protocol import SanicProtocol + +from ...exceptions import ServerError, WebsocketClosed +from .frame import WebsocketFrameAssembler + + +class WebsocketImplProtocol: + connection: ServerConnection + io_proto: Optional[SanicProtocol] + loop: Optional[asyncio.AbstractEventLoop] + max_queue: int + close_timeout: float + ping_interval: Optional[float] + ping_timeout: Optional[float] + assembler: WebsocketFrameAssembler + # Dict[bytes, asyncio.Future[None]] + pings: Dict[bytes, asyncio.Future] + conn_mutex: asyncio.Lock + recv_lock: asyncio.Lock + recv_cancel: Optional[asyncio.Future] + process_event_mutex: asyncio.Lock + can_pause: bool + # Optional[asyncio.Future[None]] + data_finished_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + pause_frame_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[asyncio.Future] + keepalive_ping_task: Optional[asyncio.Task] + auto_closer_task: Optional[asyncio.Task] + + def __init__( + self, + connection, + max_queue=None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: float = 10, + loop=None, + ): + self.connection = connection + self.io_proto = None + self.loop = None + self.max_queue = max_queue + self.close_timeout = close_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.assembler = WebsocketFrameAssembler(self) + self.pings = {} + self.conn_mutex = asyncio.Lock() + self.recv_lock = asyncio.Lock() + self.recv_cancel = None + self.process_event_mutex = asyncio.Lock() + self.data_finished_fut = None + self.can_pause = True + self.pause_frame_fut = None + self.keepalive_ping_task = None + self.auto_closer_task = None + self.connection_lost_waiter = None + + @property + def subprotocol(self): + return self.connection.subprotocol + + def pause_frames(self): + if not self.can_pause: + return False + if self.pause_frame_fut: + logger.debug("Websocket connection already paused.") + return False + if (not self.loop) or (not self.io_proto): + return False + if self.io_proto.transport: + self.io_proto.transport.pause_reading() + self.pause_frame_fut = self.loop.create_future() + logger.debug("Websocket connection paused.") + return True + + def resume_frames(self): + if not self.pause_frame_fut: + logger.debug("Websocket connection not paused.") + return False + if (not self.loop) or (not self.io_proto): + logger.debug( + "Websocket attempting to resume reading frames, " + "but connection is gone." + ) + return False + if self.io_proto.transport: + self.io_proto.transport.resume_reading() + self.pause_frame_fut.set_result(None) + self.pause_frame_fut = None + logger.debug("Websocket connection unpaused.") + return True + + async def connection_made( + self, + io_proto: SanicProtocol, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if not loop: + try: + loop = getattr(io_proto, "loop") + except AttributeError: + loop = asyncio.get_event_loop() + if not loop: + # This catch is for mypy type checker + # to assert loop is not None here. + raise ServerError("Connection received with no asyncio loop.") + if self.auto_closer_task: + raise ServerError( + "Cannot call connection_made more than once " + "on a websocket connection." + ) + self.loop = loop + self.io_proto = io_proto + self.connection_lost_waiter = self.loop.create_future() + self.data_finished_fut = asyncio.shield(self.loop.create_future()) + + if self.ping_interval: + self.keepalive_ping_task = asyncio.create_task( + self.keepalive_ping() + ) + self.auto_closer_task = asyncio.create_task( + self.auto_close_connection() + ) + + async def wait_for_connection_lost(self, timeout=None) -> bool: + """ + Wait until the TCP connection is closed or ``timeout`` elapses. + If timeout is None, wait forever. + Recommend you should pass in self.close_timeout as timeout + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if not self.connection_lost_waiter: + return False + if self.connection_lost_waiter.done(): + return True + else: + try: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), timeout + ) + return True + except asyncio.TimeoutError: + # Re-check self.connection_lost_waiter.done() synchronously + # because connection_lost() could run between the moment the + # timeout occurs and the moment this coroutine resumes running + return self.connection_lost_waiter.done() + + async def process_events(self, events: Sequence[Event]) -> None: + """ + Process a list of incoming events. + """ + # Wrapped in a mutex lock, to prevent other incoming events + # from processing at the same time + async with self.process_event_mutex: + for event in events: + if not isinstance(event, Frame): + # Event is not a frame. Ignore it. + continue + if event.opcode == Opcode.PONG: + await self.process_pong(event) + elif event.opcode == Opcode.CLOSE: + if self.recv_cancel: + self.recv_cancel.cancel() + else: + await self.assembler.put(event) + + async def process_pong(self, frame: Frame) -> None: + if frame.data in self.pings: + # Acknowledge all pings up to the one matching this pong. + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # noqa + raise ServerError("ping_id is not in self.pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + This coroutine exits when the connection terminates and one of the + following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + # ping() raises CancelledError if the connection is closed, + # when auto_close_connection() cancels keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + ping_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + await asyncio.wait_for(ping_waiter, self.ping_timeout) + except asyncio.TimeoutError: + error_logger.warning( + "Websocket timed out waiting for pong" + ) + self.fail_connection(1011) + break + except asyncio.CancelledError: + # It is expected for this task to be cancelled during during + # normal operation, when the connection is closed. + logger.debug("Websocket keepalive ping task was cancelled.") + except (ConnectionClosed, WebsocketClosed): + logger.debug("Websocket closed. Keepalive ping task exiting.") + except Exception as e: + error_logger.warning( + "Unexpected exception in websocket keepalive ping task." + ) + logger.debug(str(e)) + + def _force_disconnect(self) -> bool: + """ + Internal methdod used by end_connection and fail_connection + only when the graceful auto-closer cannot be used + """ + if self.auto_closer_task and not self.auto_closer_task.done(): + self.auto_closer_task.cancel() + if self.data_finished_fut and not self.data_finished_fut.done(): + self.data_finished_fut.cancel() + self.data_finished_fut = None + if self.keepalive_ping_task and not self.keepalive_ping_task.done(): + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + if self.loop and self.io_proto and self.io_proto.transport: + self.io_proto.transport.close() + self.loop.call_later( + self.close_timeout, self.io_proto.transport.abort + ) + # We were never open, or already closed + return True + + def fail_connection(self, code: int = 1006, reason: str = "") -> bool: + """ + Fail the WebSocket Connection + This requires: + 1. Stopping all processing of incoming data, which means cancelling + pausing the underlying io protocol. The close code will be 1006 + unless a close frame was received earlier. + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + 3. Closing the connection. :meth:`auto_close_connection` takes care + of this. + (The specification describes these steps in the opposite order.) + """ + if self.io_proto and self.io_proto.transport: + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # ut can be called when the transport is already paused or closed + self.io_proto.transport.pause_reading() + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not draining the write buffer is acceptable in this context. + + # clear the send buffer + _ = self.connection.data_to_send() + # If we're not already CLOSED or CLOSING, then send the close. + if self.connection.state is OPEN: + if code in (1000, 1001): + self.connection.send_close(code, reason) + else: + self.connection.fail(code, reason) + try: + data_to_send = self.connection.data_to_send() + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + ... + if code == 1006: + # Special case: 1006 consider the transport already closed + self.connection.state = CLOSED + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have a graceful auto-closer. Use it to close the connection. + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + return self._force_disconnect() + return False + + def end_connection(self, code=1000, reason=""): + # This is like slightly more graceful form of fail_connection + # Use this instead of close() when you need an immediate + # close and cannot await websocket.close() handshake. + + if code == 1006 or not self.io_proto or not self.io_proto.transport: + return self.fail_connection(code, reason) + + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # i.e. it can be called when the transport is already paused or closed. + self.io_proto.transport.pause_reading() + if self.connection.state == OPEN: + data_to_send = self.connection.data_to_send() + self.connection.send_close(code, reason) + data_to_send.extend(self.connection.data_to_send()) + try: + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + # But that doesn't matter at this point + ... + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have the ability to signal the auto-closer + # try to trigger it to auto-close the connection + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + # Auto-closer is not running, do force disconnect + return self._force_disconnect() + return False + + async def auto_close_connection(self) -> None: + """ + Close the WebSocket Connection + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + """ + try: + # Wait for the data transfer phase to complete. + if self.data_finished_fut: + try: + await self.data_finished_fut + logger.debug( + "Websocket task finished. Closing the connection." + ) + except asyncio.CancelledError: + # Cancelled error is called when data phase is cancelled + # if an error occurred or the client closed the connection + logger.debug( + "Websocket handler cancelled. Closing the connection." + ) + + # Cancel the keepalive ping task. + if self.keepalive_ping_task: + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + + # Half-close the TCP connection if possible (when there's no TLS). + if ( + self.io_proto + and self.io_proto.transport + and self.io_proto.transport.can_write_eof() + ): + logger.debug("Websocket half-closing TCP connection") + self.io_proto.transport.write_eof() + if self.connection_lost_waiter: + if await self.wait_for_connection_lost(timeout=0): + return + except asyncio.CancelledError: + ... + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is cancelled (for example). + if (not self.io_proto) or (not self.io_proto.transport): + # we were never open, or done. Can't do any finalization. + return + elif ( + self.connection_lost_waiter + and self.connection_lost_waiter.done() + ): + # connection confirmed closed already, proceed to abort waiter + ... + elif self.io_proto.transport.is_closing(): + # Connection is already closing (due to half-close above) + # proceed to abort waiter + ... + else: + self.io_proto.transport.close() + if not self.connection_lost_waiter: + # Our connection monitor task isn't running. + try: + await asyncio.sleep(self.close_timeout) + except asyncio.CancelledError: + ... + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + else: + if await self.wait_for_connection_lost( + timeout=self.close_timeout + ): + # Connection aborted before the timeout expired. + return + error_logger.warning( + "Timeout waiting for TCP connection to close. Aborting" + ) + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + They'll never receive a pong once the connection is closed. + """ + if self.connection.state is not CLOSED: + raise ServerError( + "Webscoket about_pings should only be called " + "after connection state is changed to CLOSED" + ) + + for ping in self.pings.values(): + ping.set_exception(ConnectionClosedError(None, None)) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + This is a websocket-protocol level close. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + :param code: WebSocket close code + :param reason: WebSocket close reason + """ + if code == 1006: + self.fail_connection(code, reason) + return + async with self.conn_mutex: + if self.connection.state is OPEN: + self.connection.send_close(code, reason) + data_to_send = self.connection.data_to_send() + await self.send_data(data_to_send) + + async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Receive the next message. + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + If ``timeout`` is ``None``, block until a message is received. Else, + if no message is received within ``timeout`` seconds, return ``None``. + Set ``timeout`` to ``0`` to check if a message was already received. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises asyncio.CancelledError: if the websocket closes while waiting + :raises ServerError: if two tasks call :meth:`recv` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv while another task is " + "already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + try: + self.recv_cancel = asyncio.Future() + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + else: + self.recv_cancel.cancel() + return done_task.result() + finally: + self.recv_cancel = None + self.recv_lock.release() + + async def recv_burst(self, max_recv=256) -> Sequence[Data]: + """ + Receive the messages which have arrived since last checking. + Return a :class:`list` containing :class:`str` for a text frame + and :class:`bytes` for a binary frame. + When the end of the message stream is reached, :meth:`recv_burst` + raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a + normal connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ServerError: if two tasks call :meth:`recv_burst` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv_burst while another task is already waiting " + "for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + messages = [] + try: + # Prevent pausing the transport when we're + # receiving a burst of messages + self.can_pause = False + self.recv_cancel = asyncio.Future() + while True: + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout=0)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv_burst was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + m = done_task.result() + if m is None: + # None left in the burst. This is good! + break + messages.append(m) + if len(messages) >= max_recv: + # Too much data in the pipe. Hit our burst limit. + break + # Allow an eventloop iteration for the + # next message to pass into the Assembler + await asyncio.sleep(0) + self.recv_cancel.cancel() + finally: + self.recv_cancel = None + self.can_pause = True + self.recv_lock.release() + return messages + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + Return an iterator of :class:`str` for a text frame and :class:`bytes` + for a binary frame. The iterator should be exhausted, or else the + connection will become unusable. + With the exception of the return value, :meth:`recv_streaming` behaves + like :meth:`recv`. + """ + if self.recv_lock.locked(): + raise ServerError( + "Cannot call recv_streaming while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + try: + cancelled = False + self.recv_cancel = asyncio.Future() + self.can_pause = False + async for m in self.assembler.get_iter(): + if self.recv_cancel.done(): + cancelled = True + break + yield m + if cancelled: + raise asyncio.CancelledError() + finally: + self.can_pause = True + self.recv_cancel = None + self.recv_lock.release() + + async def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects. In that case the message is fragmented. Each item + is treated as a message fragment and sent in its own frame. All items + must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + :raises TypeError: for unsupported inputs + """ + async with self.conn_mutex: + + if self.connection.state in (CLOSED, CLOSING): + raise WebsocketClosed( + "Cannot write to websocket interface after it is closed." + ) + if (not self.data_finished_fut) or self.data_finished_fut.done(): + raise ServerError( + "Cannot write to websocket interface after it is finished." + ) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + self.connection.send_text(message.encode("utf-8")) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, (bytes, bytearray, memoryview)): + self.connection.send_binary(message) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, Mapping): + # Catch a common mistake -- passing a dict to send(). + raise TypeError("data is a dict-like object") + + elif isinstance(message, Iterable): + # Fragmented message -- regular iterator. + raise NotImplementedError( + "Fragmented websocket messages are not supported." + ) + else: + raise TypeError("Websocket data must be bytes, str.") + + async def ping(self, data: Optional[Data] = None) -> asyncio.Future: + """ + Send a ping. + Return an :class:`~asyncio.Future` that will be resolved when the + corresponding pong is received. You can ignore it if you don't intend + to wait. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point:: + await pong_event = ws.ping() + await pong_event # only if you want to wait for the pong + By default, the ping contains four random bytes. This payload may be + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + raise WebsocketClosed( + "Cannot send a ping when the websocket interface " + "is closed." + ) + if (not self.io_proto) or (not self.io_proto.loop): + raise ServerError( + "Cannot send a ping when the websocket has no I/O " + "protocol attached." + ) + if data is not None: + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise ValueError( + "already waiting for a pong with the same data" + ) + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + self.pings[data] = self.io_proto.loop.create_future() + + self.connection.send_ping(data) + await self.send_data(self.connection.data_to_send()) + + return asyncio.shield(self.pings[data]) + + async def pong(self, data: Data = b"") -> None: + """ + Send a pong. + An unsolicited pong may serve as a unidirectional heartbeat. + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + # Cannot send pong after transport is shutting down + return + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + self.connection.send_pong(data) + await self.send_data(self.connection.data_to_send()) + + async def send_data(self, data_to_send): + for data in data_to_send: + if data: + await self.io_proto.send(data) + else: + # Send an EOF - We don't actually send it, + # just trigger to autoclose the connection + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + async def async_data_received(self, data_to_send, events_to_process): + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + # receiving data can generate data to send (eg, pong for a ping) + # send connection.data_to_send() + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + def data_received(self, data): + self.connection.receive_data(data) + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task( + self.async_data_received(data_to_send, events_to_process) + ) + + async def async_eof_received(self, data_to_send, events_to_process): + # receiving EOF can generate data to send + # send connection.data_to_send() + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + if self.recv_cancel: + self.recv_cancel.cancel() + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + # Cancel the running handler if its waiting + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + def eof_received(self) -> Optional[bool]: + self.connection.receive_eof() + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + asyncio.create_task( + self.async_eof_received(data_to_send, events_to_process) + ) + return False + + def connection_lost(self, exc): + """ + The WebSocket Connection is Closed. + """ + if not self.connection.state == CLOSED: + # signal to the websocket connection handler + # we've lost the connection + self.connection.fail(code=1006) + self.connection.state = CLOSED + + self.abort_pings() + if self.connection_lost_waiter: + self.connection_lost_waiter.set_result(None) diff --git a/sanic/signals.py b/sanic/signals.py index eec2a438..2c1a704c 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.utils import path_to_parts # type: ignore from sanic.exceptions import InvalidSignal +from sanic.log import error_logger, logger from sanic.models.handler_types import SignalHandler -RESERVED_NAMESPACES = ( - "server", - "http", -) +RESERVED_NAMESPACES = { + "server": ( + # "server.main.start", + # "server.main.stop", + "server.init.before", + "server.init.after", + "server.shutdown.before", + "server.shutdown.after", + ), + "http": ( + "http.lifecycle.begin", + "http.lifecycle.complete", + "http.lifecycle.exception", + "http.lifecycle.handle", + "http.lifecycle.read_body", + "http.lifecycle.read_head", + "http.lifecycle.request", + "http.lifecycle.response", + "http.routing.after", + "http.routing.before", + "http.lifecycle.send", + "http.middleware.after", + "http.middleware.before", + ), +} + + +def _blank(): + ... class Signal(Route): @@ -59,8 +85,13 @@ class SignalRouter(BaseRouter): terms.append(extra) raise NotFound(message % tuple(terms)) + # Regex routes evaluate and can extract params directly. They are set + # on param_basket["__params__"] params = param_basket["__params__"] if not params: + # If param_basket["__params__"] does not exist, we might have + # param_basket["__matches__"], which are indexed based matches + # on path segments. They should already be cast types. params = { param.name: param_basket["__matches__"][idx] for idx, param in group.params.items() @@ -73,8 +104,18 @@ class SignalRouter(BaseRouter): event: str, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, - ) -> None: - group, handlers, params = self.get(event, condition=condition) + fail_not_found: bool = True, + reverse: bool = False, + ) -> Any: + try: + group, handlers, params = self.get(event, condition=condition) + except NotFound as e: + if fail_not_found: + raise e + else: + if self.ctx.app.debug: + error_logger.warning(str(e)) + return None events = [signal.ctx.event for signal in group] for signal_event in events: @@ -82,12 +123,19 @@ class SignalRouter(BaseRouter): if context: params.update(context) + if not reverse: + handlers = handlers[::-1] try: for handler in handlers: if condition is None or condition == handler.__requirements__: maybe_coroutine = handler(**params) if isawaitable(maybe_coroutine): - await maybe_coroutine + retval = await maybe_coroutine + if retval: + return retval + elif maybe_coroutine: + return maybe_coroutine + return None finally: for signal_event in events: signal_event.clear() @@ -98,14 +146,23 @@ class SignalRouter(BaseRouter): *, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, - ) -> asyncio.Task: - task = self.ctx.loop.create_task( - self._dispatch( - event, - context=context, - condition=condition, - ) + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, + ) -> Union[asyncio.Task, Any]: + dispatch = self._dispatch( + event, + context=context, + condition=condition, + fail_not_found=fail_not_found and inline, + reverse=reverse, ) + logger.debug(f"Dispatching signal: {event}") + + if inline: + return await dispatch + + task = asyncio.get_running_loop().create_task(dispatch) await asyncio.sleep(0) return task @@ -131,7 +188,9 @@ class SignalRouter(BaseRouter): append=True, ) # type: ignore - def finalize(self, do_compile: bool = True): + def finalize(self, do_compile: bool = True, do_optimize: bool = False): + self.add(_blank, "sanic.__signal__.__init__") + try: self.ctx.loop = asyncio.get_running_loop() except RuntimeError: @@ -140,7 +199,7 @@ class SignalRouter(BaseRouter): for signal in self.routes: signal.ctx.event = asyncio.Event() - return super().finalize(do_compile=do_compile) + return super().finalize(do_compile=do_compile, do_optimize=do_optimize) def _build_event_parts(self, event: str) -> Tuple[str, str, str]: parts = path_to_parts(event, self.delimiter) @@ -151,7 +210,11 @@ class SignalRouter(BaseRouter): ): raise InvalidSignal("Invalid signal event: %s" % event) - if parts[0] in RESERVED_NAMESPACES: + if ( + parts[0] in RESERVED_NAMESPACES + and event not in RESERVED_NAMESPACES[parts[0]] + and not (parts[2].startswith("<") and parts[2].endswith(">")) + ): raise InvalidSignal( "Cannot declare reserved signal event: %s" % event ) diff --git a/sanic/touchup/__init__.py b/sanic/touchup/__init__.py new file mode 100644 index 00000000..6fe208ab --- /dev/null +++ b/sanic/touchup/__init__.py @@ -0,0 +1,8 @@ +from .meta import TouchUpMeta +from .service import TouchUp + + +__all__ = ( + "TouchUp", + "TouchUpMeta", +) diff --git a/sanic/touchup/meta.py b/sanic/touchup/meta.py new file mode 100644 index 00000000..9f60af38 --- /dev/null +++ b/sanic/touchup/meta.py @@ -0,0 +1,22 @@ +from sanic.exceptions import SanicException + +from .service import TouchUp + + +class TouchUpMeta(type): + def __new__(cls, name, bases, attrs, **kwargs): + gen_class = super().__new__(cls, name, bases, attrs, **kwargs) + + methods = attrs.get("__touchup__") + attrs["__touched__"] = False + if methods: + + for method in methods: + if method not in attrs: + raise SanicException( + "Cannot perform touchup on non-existent method: " + f"{name}.{method}" + ) + TouchUp.register(gen_class, method) + + return gen_class diff --git a/sanic/touchup/schemes/__init__.py b/sanic/touchup/schemes/__init__.py new file mode 100644 index 00000000..87057a5f --- /dev/null +++ b/sanic/touchup/schemes/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseScheme +from .ode import OptionalDispatchEvent # noqa + + +__all__ = ("BaseScheme",) diff --git a/sanic/touchup/schemes/base.py b/sanic/touchup/schemes/base.py new file mode 100644 index 00000000..d16619b2 --- /dev/null +++ b/sanic/touchup/schemes/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Set, Type + + +class BaseScheme(ABC): + ident: str + _registry: Set[Type] = set() + + def __init__(self, app) -> None: + self.app = app + + @abstractmethod + def run(self, method, module_globals) -> None: + ... + + def __init_subclass__(cls): + BaseScheme._registry.add(cls) + + def __call__(self, method, module_globals): + return self.run(method, module_globals) diff --git a/sanic/touchup/schemes/ode.py b/sanic/touchup/schemes/ode.py new file mode 100644 index 00000000..357f748c --- /dev/null +++ b/sanic/touchup/schemes/ode.py @@ -0,0 +1,67 @@ +from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse +from inspect import getsource +from textwrap import dedent +from typing import Any + +from sanic.log import logger + +from .base import BaseScheme + + +class OptionalDispatchEvent(BaseScheme): + ident = "ODE" + + def __init__(self, app) -> None: + super().__init__(app) + + self._registered_events = [ + signal.path for signal in app.signal_router.routes + ] + + def run(self, method, module_globals): + raw_source = getsource(method) + src = dedent(raw_source) + tree = parse(src) + node = RemoveDispatch(self._registered_events).visit(tree) + compiled_src = compile(node, method.__name__, "exec") + exec_locals: Dict[str, Any] = {} + exec(compiled_src, module_globals, exec_locals) # nosec + + return exec_locals[method.__name__] + + +class RemoveDispatch(NodeTransformer): + def __init__(self, registered_events) -> None: + self._registered_events = registered_events + + def visit_Expr(self, node: Expr) -> Any: + call = node.value + if isinstance(call, Await): + call = call.value + + func = getattr(call, "func", None) + args = getattr(call, "args", None) + if not func or not args: + return node + + if isinstance(func, Attribute) and func.attr == "dispatch": + event = args[0] + if hasattr(event, "s"): + event_name = getattr(event, "value", event.s) + if self._not_registered(event_name): + logger.debug(f"Disabling event: {event_name}") + return None + return node + + def _not_registered(self, event_name): + dynamic = [] + for event in self._registered_events: + if event.endswith(">"): + namespace_concern, _ = event.rsplit(".", 1) + dynamic.append(namespace_concern) + + namespace_concern, _ = event_name.rsplit(".", 1) + return ( + event_name not in self._registered_events + and namespace_concern not in dynamic + ) diff --git a/sanic/touchup/service.py b/sanic/touchup/service.py new file mode 100644 index 00000000..95792dca --- /dev/null +++ b/sanic/touchup/service.py @@ -0,0 +1,33 @@ +from inspect import getmembers, getmodule +from typing import Set, Tuple, Type + +from .schemes import BaseScheme + + +class TouchUp: + _registry: Set[Tuple[Type, str]] = set() + + @classmethod + def run(cls, app): + for target, method_name in cls._registry: + method = getattr(target, method_name) + + if app.test_mode: + placeholder = f"_{method_name}" + if hasattr(target, placeholder): + method = getattr(target, placeholder) + else: + setattr(target, placeholder, method) + + module = getmodule(target) + module_globals = dict(getmembers(module)) + + for scheme in BaseScheme._registry: + modified = scheme(app)(method, module_globals) + setattr(target, method_name, modified) + + target.__touched__ = True + + @classmethod + def register(cls, target, method_name): + cls._registry.add((target, method_name)) diff --git a/sanic/views.py b/sanic/views.py index 64a872a4..c983bef7 100644 --- a/sanic/views.py +++ b/sanic/views.py @@ -13,6 +13,7 @@ from warnings import warn from sanic.constants import HTTP_METHODS from sanic.exceptions import InvalidUsage +from sanic.models.handler_types import RouteHandler if TYPE_CHECKING: @@ -86,7 +87,7 @@ class HTTPMethodView: return handler(request, *args, **kwargs) @classmethod - def as_view(cls, *class_args, **class_kwargs): + def as_view(cls, *class_args: Any, **class_kwargs: Any) -> RouteHandler: """Return view function for use with the routing system, that dispatches request to appropriate handler method. """ @@ -100,7 +101,7 @@ class HTTPMethodView: for decorator in cls.decorators: view = decorator(view) - view.view_class = cls + view.view_class = cls # type: ignore view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ view.__name__ = cls.__name__ diff --git a/sanic/websocket.py b/sanic/websocket.py deleted file mode 100644 index b5600ed7..00000000 --- a/sanic/websocket.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - MutableMapping, - Optional, - Union, -) - -from httptools import HttpParserUpgrade # type: ignore -from websockets import ( # type: ignore - ConnectionClosed, - InvalidHandshake, - WebSocketCommonProtocol, -) - -# Despite the "legacy" namespace, the primary maintainer of websockets -# committed to maintaining backwards-compatibility until 2026 and will -# consider extending it if sanic continues depending on this module. -from websockets.legacy import handshake - -from sanic.exceptions import InvalidUsage -from sanic.server import HttpProtocol - - -__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"] - -ASIMessage = MutableMapping[str, Any] - - -class WebSocketProtocol(HttpProtocol): - def __init__( - self, - *args, - websocket_timeout=10, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, - websocket_ping_interval=20, - websocket_ping_timeout=20, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.websocket = None - # self.app = None - self.websocket_timeout = websocket_timeout - self.websocket_max_size = websocket_max_size - self.websocket_max_queue = websocket_max_queue - self.websocket_read_limit = websocket_read_limit - self.websocket_write_limit = websocket_write_limit - self.websocket_ping_interval = websocket_ping_interval - self.websocket_ping_timeout = websocket_ping_timeout - - # timeouts make no sense for websocket routes - def request_timeout_callback(self): - if self.websocket is None: - super().request_timeout_callback() - - def response_timeout_callback(self): - if self.websocket is None: - super().response_timeout_callback() - - def keep_alive_timeout_callback(self): - if self.websocket is None: - super().keep_alive_timeout_callback() - - def connection_lost(self, exc): - if self.websocket is not None: - self.websocket.connection_lost(exc) - super().connection_lost(exc) - - def data_received(self, data): - if self.websocket is not None: - # pass the data to the websocket protocol - self.websocket.data_received(data) - else: - try: - super().data_received(data) - except HttpParserUpgrade: - # this is okay, it just indicates we've got an upgrade request - pass - - def write_response(self, response): - if self.websocket is not None: - # websocket requests do not write a response - self.transport.close() - else: - super().write_response(response) - - async def websocket_handshake(self, request, subprotocols=None): - # let the websockets package do the handshake with the client - headers = {} - - try: - key = handshake.check_request(request.headers) - handshake.build_response(headers, key) - except InvalidHandshake: - raise InvalidUsage("Invalid websocket request") - - subprotocol = None - if subprotocols and "Sec-Websocket-Protocol" in request.headers: - # select a subprotocol - client_subprotocols = [ - p.strip() - for p in request.headers["Sec-Websocket-Protocol"].split(",") - ] - for p in client_subprotocols: - if p in subprotocols: - subprotocol = p - headers["Sec-Websocket-Protocol"] = subprotocol - break - - # write the 101 response back to the client - rv = b"HTTP/1.1 101 Switching Protocols\r\n" - for k, v in headers.items(): - rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" - rv += b"\r\n" - request.transport.write(rv) - - # hook up the websocket protocol - self.websocket = WebSocketCommonProtocol( - close_timeout=self.websocket_timeout, - max_size=self.websocket_max_size, - max_queue=self.websocket_max_queue, - read_limit=self.websocket_read_limit, - write_limit=self.websocket_write_limit, - ping_interval=self.websocket_ping_interval, - ping_timeout=self.websocket_ping_timeout, - ) - # we use WebSocketCommonProtocol because we don't want the handshake - # logic from WebSocketServerProtocol; however, we must tell it that - # we're running on the server side - self.websocket.is_client = False - self.websocket.side = "server" - self.websocket.subprotocol = subprotocol - self.websocket.connection_made(request.transport) - self.websocket.connection_open() - return self.websocket - - -class WebSocketConnection: - - # TODO - # - Implement ping/pong - - def __init__( - self, - send: Callable[[ASIMessage], Awaitable[None]], - receive: Callable[[], Awaitable[ASIMessage]], - subprotocols: Optional[List[str]] = None, - ) -> None: - self._send = send - self._receive = receive - self._subprotocols = subprotocols or [] - - async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: - message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} - - if isinstance(data, bytes): - message.update({"bytes": data}) - else: - message.update({"text": str(data)}) - - await self._send(message) - - async def recv(self, *args, **kwargs) -> Optional[str]: - message = await self._receive() - - if message["type"] == "websocket.receive": - return message["text"] - elif message["type"] == "websocket.disconnect": - pass - - return None - - receive = recv - - async def accept(self, subprotocols: Optional[List[str]] = None) -> None: - subprotocol = None - if subprotocols: - for subp in subprotocols: - if subp in self.subprotocols: - subprotocol = subp - break - - await self._send( - { - "type": "websocket.accept", - "subprotocol": subprotocol, - } - ) - - async def close(self) -> None: - pass - - @property - def subprotocols(self): - return self._subprotocols - - @subprotocols.setter - def subprotocols(self, subprotocols: Optional[List[str]] = None): - self._subprotocols = subprotocols or [] diff --git a/sanic/worker.py b/sanic/worker.py index 342900e6..a3bc29b8 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -8,8 +8,8 @@ import traceback from gunicorn.workers import base # type: ignore from sanic.log import logger -from sanic.server import HttpProtocol, Signal, serve, trigger_events -from sanic.websocket import WebSocketProtocol +from sanic.server import HttpProtocol, Signal, serve +from sanic.server.protocols.websocket_protocol import WebSocketProtocol try: @@ -68,10 +68,10 @@ class GunicornWorker(base.Worker): ) self._server_settings["signal"] = self.signal self._server_settings.pop("sock") - trigger_events( - self._server_settings.get("before_start", []), self.loop + self._await(self.app.callable._startup()) + self._await( + self.app.callable._server_event("init", "before", loop=self.loop) ) - self._server_settings["before_start"] = () main_start = self._server_settings.pop("main_start", None) main_stop = self._server_settings.pop("main_stop", None) @@ -82,24 +82,29 @@ class GunicornWorker(base.Worker): "with GunicornWorker" ) - self._runner = asyncio.ensure_future(self._run(), loop=self.loop) try: - self.loop.run_until_complete(self._runner) + self._await(self._run()) self.app.callable.is_running = True - trigger_events( - self._server_settings.get("after_start", []), self.loop + self._await( + self.app.callable._server_event( + "init", "after", loop=self.loop + ) ) self.loop.run_until_complete(self._check_alive()) - trigger_events( - self._server_settings.get("before_stop", []), self.loop + self._await( + self.app.callable._server_event( + "shutdown", "before", loop=self.loop + ) ) self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: try: - trigger_events( - self._server_settings.get("after_stop", []), self.loop + self._await( + self.app.callable._server_event( + "shutdown", "after", loop=self.loop + ) ) except BaseException: traceback.print_exc() @@ -137,14 +142,11 @@ class GunicornWorker(base.Worker): # Force close non-idle connection after waiting for # graceful_shutdown_timeout - coros = [] for conn in self.connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + conn.websocket.fail_connection(code=1001) else: - conn.close() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown + conn.abort() async def _run(self): for sock in self.sockets: @@ -238,3 +240,7 @@ class GunicornWorker(base.Worker): self.exit_code = 1 self.cfg.worker_abort(self) sys.exit(1) + + def _await(self, coro): + fut = asyncio.ensure_future(coro, loop=self.loop) + self.loop.run_until_complete(fut) diff --git a/setup.py b/setup.py index af347b3f..ecbf1e07 100644 --- a/setup.py +++ b/setup.py @@ -81,60 +81,63 @@ env_dependency = ( ) ujson = "ujson>=1.35" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency - +types_ujson = "types-ujson" + env_dependency requirements = [ "sanic-routing~=0.7", "httptools>=0.0.10", uvloop, ujson, "aiofiles>=0.6.0", - "websockets>=9.0", + "websockets>=10.0", "multidict>=5.0,<6.0", ] tests_require = [ - "sanic-testing>=0.7.0b1", + "sanic-testing>=0.7.0", "pytest==5.2.1", - "multidict>=5.0,<6.0", + "coverage==5.3", "gunicorn==20.0.4", "pytest-cov", "beautifulsoup4", - uvloop, - ujson, "pytest-sanic", "pytest-sugar", "pytest-benchmark", + "chardet==3.*", + "flake8", + "black", + "isort>=5.0.0", + "bandit", + "mypy>=0.901", + "docutils", + "pygments", + "uvicorn<0.15.0", + types_ujson, ] docs_require = [ "sphinx>=2.1.2", - "sphinx_rtd_theme", - "recommonmark>=0.5.0", + "sphinx_rtd_theme>=0.4.3", "docutils", "pygments", + "m2r2", ] dev_require = tests_require + [ - "aiofiles", "tox", - "black", - "flake8", - "bandit", "towncrier", ] -all_require = dev_require + docs_require +all_require = list(set(dev_require + docs_require)) if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): print("Installing without uJSON") requirements.remove(ujson) - tests_require.remove(ujson) + tests_require.remove(types_ujson) # 'nt' means windows OS if strtobool(os.environ.get("SANIC_NO_UVLOOP", "no")): print("Installing without uvLoop") requirements.remove(uvloop) - tests_require.remove(uvloop) extras_require = { "test": tests_require, diff --git a/tests/conftest.py b/tests/conftest.py index 65b218cf..175e967e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import asyncio +import logging import random import re import string @@ -9,10 +11,12 @@ from typing import Tuple import pytest from sanic_routing.exceptions import RouteExists +from sanic_testing.testing import PORT from sanic import Sanic from sanic.constants import HTTP_METHODS from sanic.router import Router +from sanic.touchup.service import TouchUp slugify = re.compile(r"[^a-zA-Z0-9_\-]") @@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]: collect_ignore = ["test_worker.py"] -@pytest.fixture -def caplog(caplog): - yield caplog - - async def _handler(request): """ Dummy placeholder method used for route resolver when creating a new @@ -41,33 +40,32 @@ async def _handler(request): TYPE_TO_GENERATOR_MAP = { - "string": lambda: "".join( + "str": lambda: "".join( [random.choice(string.ascii_lowercase) for _ in range(4)] ), "int": lambda: random.choice(range(1000000)), - "number": lambda: random.random(), + "float": lambda: random.random(), "alpha": lambda: "".join( [random.choice(string.ascii_lowercase) for _ in range(4)] ), "uuid": lambda: str(uuid.uuid1()), } +CACHE = {} + class RouteStringGenerator: ROUTE_COUNT_PER_DEPTH = 100 HTTP_METHODS = HTTP_METHODS - ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"] + ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"] def generate_random_direct_route(self, max_route_depth=4): routes = [] for depth in range(1, max_route_depth + 1): for _ in range(self.ROUTE_COUNT_PER_DEPTH): route = "/".join( - [ - TYPE_TO_GENERATOR_MAP.get("string")() - for _ in range(depth) - ] + [TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)] ) route = route.replace(".", "", -1) route_detail = (random.choice(self.HTTP_METHODS), route) @@ -83,7 +81,7 @@ class RouteStringGenerator: new_route_part = "/".join( [ "<{}:{}>".format( - TYPE_TO_GENERATOR_MAP.get("string")(), + TYPE_TO_GENERATOR_MAP.get("str")(), random.choice(self.ROUTE_PARAM_TYPES), ) for _ in range(max_route_depth - current_length) @@ -98,7 +96,7 @@ class RouteStringGenerator: def generate_url_for_template(template): url = template for pattern, param_type in re.findall( - re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"), + re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"), template, ): value = TYPE_TO_GENERATOR_MAP.get(param_type)() @@ -111,6 +109,7 @@ def sanic_router(app): # noinspection PyProtectedMember def _setup(route_details: tuple) -> Tuple[Router, tuple]: router = Router() + router.ctx.app = app added_router = [] for method, route in route_details: try: @@ -141,5 +140,33 @@ def url_param_generator(): @pytest.fixture(scope="function") def app(request): + if not CACHE: + for target, method_name in TouchUp._registry: + CACHE[method_name] = getattr(target, method_name) app = Sanic(slugify.sub("-", request.node.name)) - return app + yield app + for target, method_name in TouchUp._registry: + setattr(target, method_name, CACHE[method_name]) + + +@pytest.fixture(scope="function") +def run_startup(caplog): + def run(app): + nonlocal caplog + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + with caplog.at_level(logging.DEBUG): + server = app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + loop._stopping = False + + _server = loop.run_until_complete(server) + + _server.close() + loop.run_until_complete(_server.wait_closed()) + app.stop() + + return caplog.record_tuples + + return run diff --git a/tests/test_app.py b/tests/test_app.py index 9598d54f..f222fba1 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable): @patch("sanic.app.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 - app.config.WEBSOCKET_MAX_QUEUE = 45 - app.config.WEBSOCKET_READ_LIMIT = 46 - app.config.WEBSOCKET_WRITE_LIMIT = 47 app.config.WEBSOCKET_PING_TIMEOUT = 48 app.config.WEBSOCKET_PING_INTERVAL = 50 @@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app): websocket_protocol_call_args = websocket_protocol_mock.call_args ws_kwargs = websocket_protocol_call_args[1] assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE - assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE - assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT - assert ( - ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT - ) assert ( ws_kwargs["websocket_ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT @@ -396,7 +388,7 @@ def test_app_set_attribute_warning(app): assert len(record) == 1 assert record[0].message.args[0] == ( "Setting variables on Sanic instances is deprecated " - "and will be removed in version 21.9. You should change your " + "and will be removed in version 21.12. You should change your " "Sanic instance to use instance.ctx.foo instead." ) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index c707c12a..3d464a4f 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -10,7 +10,7 @@ from sanic.asgi import MockTransport from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.request import Request from sanic.response import json, text -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection @pytest.fixture @@ -360,6 +360,7 @@ async def test_request_handle_exception(app): _, response = await app.asgi_client.get("/error-prone") assert response.status_code == 503 + @pytest.mark.asyncio async def test_request_exception_suppressed_by_middleware(app): @app.get("/error-prone") @@ -374,4 +375,4 @@ async def test_request_exception_suppressed_by_middleware(app): assert response.status_code == 403 _, response = await app.asgi_client.get("/error-prone") - assert response.status_code == 403 \ No newline at end of file + assert response.status_code == 403 diff --git a/tests/test_bad_request.py b/tests/test_bad_request.py index 140fbe8a..7a87d919 100644 --- a/tests/test_bad_request.py +++ b/tests/test_bad_request.py @@ -20,4 +20,4 @@ def test_bad_request_response(app): app.run(host="127.0.0.1", port=42101, debug=False) assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" - assert b"Bad Request" in lines[-1] + assert b"Bad Request" in lines[-2] diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py new file mode 100644 index 00000000..033e2e20 --- /dev/null +++ b/tests/test_blueprint_copy.py @@ -0,0 +1,70 @@ +from copy import deepcopy + +from sanic import Blueprint, Sanic, blueprints, response +from sanic.response import text + + +def test_bp_copy(app: Sanic): + bp1 = Blueprint("test_bp1", version=1) + bp1.ctx.test = 1 + assert hasattr(bp1.ctx, "test") + + @bp1.route("/page") + def handle_request(request): + return text("Hello world!") + + bp2 = bp1.copy(name="test_bp2", version=2) + assert id(bp1) != id(bp2) + assert bp1._apps == bp2._apps == set() + assert not hasattr(bp2.ctx, "test") + assert len(bp2._future_exceptions) == len(bp1._future_exceptions) + assert len(bp2._future_listeners) == len(bp1._future_listeners) + assert len(bp2._future_middleware) == len(bp1._future_middleware) + assert len(bp2._future_routes) == len(bp1._future_routes) + assert len(bp2._future_signals) == len(bp1._future_signals) + + app.blueprint(bp1) + app.blueprint(bp2) + + bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True) + assert id(bp1) != id(bp3) + assert bp1._apps == bp3._apps and bp3._apps + assert not hasattr(bp3.ctx, "test") + + bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True) + assert id(bp1) != id(bp4) + assert bp4.ctx.test == 1 + + bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False) + assert id(bp1) != id(bp5) + assert not bp5._apps + assert bp1._apps != set() + + app.blueprint(bp5) + + bp6 = bp1.copy( + name="test_bp6", + version=6, + with_registration=True, + version_prefix="/version", + ) + assert bp6._apps + assert bp6.version_prefix == "/version" + + _, response = app.test_client.get("/v1/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v2/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v3/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v4/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v5/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/version6/page") + assert "Hello world!" in response.text diff --git a/tests/test_blueprint_group.py b/tests/test_blueprint_group.py index 77ddf44c..09729c15 100644 --- a/tests/test_blueprint_group.py +++ b/tests/test_blueprint_group.py @@ -3,6 +3,12 @@ from pytest import raises from sanic.app import Sanic from sanic.blueprint_group import BlueprintGroup from sanic.blueprints import Blueprint +from sanic.exceptions import ( + Forbidden, + InvalidUsage, + SanicException, + ServerError, +) from sanic.request import Request from sanic.response import HTTPResponse, text @@ -96,16 +102,28 @@ def test_bp_group(app: Sanic): def blueprint_1_default_route(request): return text("BP1_OK") + @blueprint_1.route("/invalid") + def blueprint_1_error(request: Request): + raise InvalidUsage("Invalid") + @blueprint_2.route("/") def blueprint_2_default_route(request): return text("BP2_OK") + @blueprint_2.route("/error") + def blueprint_2_error(request: Request): + raise ServerError("Error") + blueprint_group_1 = Blueprint.group( blueprint_1, blueprint_2, url_prefix="/bp" ) blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3") + @blueprint_group_1.exception(InvalidUsage) + def handle_group_exception(request, exception): + return text("BP1_ERR_OK") + @blueprint_group_1.middleware("request") def blueprint_group_1_middleware(request): global MIDDLEWARE_INVOKE_COUNTER @@ -116,19 +134,47 @@ def test_bp_group(app: Sanic): global MIDDLEWARE_INVOKE_COUNTER MIDDLEWARE_INVOKE_COUNTER["request"] += 1 + @blueprint_group_1.on_request + def blueprint_group_1_convenience_1(request): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["request"] += 1 + + @blueprint_group_1.on_request() + def blueprint_group_1_convenience_2(request): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["request"] += 1 + @blueprint_3.route("/") def blueprint_3_default_route(request): return text("BP3_OK") + @blueprint_3.route("/forbidden") + def blueprint_3_forbidden(request: Request): + raise Forbidden("Forbidden") + blueprint_group_2 = Blueprint.group( blueprint_group_1, blueprint_3, url_prefix="/api" ) + @blueprint_group_2.exception(SanicException) + def handle_non_handled_exception(request, exception): + return text("BP2_ERR_OK") + @blueprint_group_2.middleware("response") def blueprint_group_2_middleware(request, response): global MIDDLEWARE_INVOKE_COUNTER MIDDLEWARE_INVOKE_COUNTER["response"] += 1 + @blueprint_group_2.on_response + def blueprint_group_2_middleware_convenience_1(request, response): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["response"] += 1 + + @blueprint_group_2.on_response() + def blueprint_group_2_middleware_convenience_2(request, response): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["response"] += 1 + app.blueprint(blueprint_group_2) @app.route("/") @@ -141,14 +187,23 @@ def test_bp_group(app: Sanic): _, response = app.test_client.get("/api/bp/bp1") assert response.text == "BP1_OK" + _, response = app.test_client.get("/api/bp/bp1/invalid") + assert response.text == "BP1_ERR_OK" + _, response = app.test_client.get("/api/bp/bp2") assert response.text == "BP2_OK" + _, response = app.test_client.get("/api/bp/bp2/error") + assert response.text == "BP2_ERR_OK" + _, response = app.test_client.get("/api/bp3") assert response.text == "BP3_OK" - assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3 - assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4 + _, response = app.test_client.get("/api/bp3/forbidden") + assert response.text == "BP2_ERR_OK" + + assert MIDDLEWARE_INVOKE_COUNTER["response"] == 18 + assert MIDDLEWARE_INVOKE_COUNTER["request"] == 16 def test_bp_group_list_operations(app: Sanic): diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index fec7b50a..b6a23151 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -83,7 +83,6 @@ def test_versioned_routes_get(app, method): return text("OK") else: - print(func) raise Exception(f"{func} is not callable") app.blueprint(bp) @@ -477,6 +476,58 @@ def test_bp_exception_handler(app): assert response.status == 200 +def test_bp_exception_handler_applied(app): + class Error(Exception): + pass + + handled = Blueprint("handled") + nothandled = Blueprint("nothandled") + + @handled.exception(Error) + def handle_error(req, e): + return text("handled {}".format(e)) + + @handled.route("/ok") + def ok(request): + raise Error("uh oh") + + @nothandled.route("/notok") + def notok(request): + raise Error("uh oh") + + app.blueprint(handled) + app.blueprint(nothandled) + + _, response = app.test_client.get("/ok") + assert response.status == 200 + assert response.text == "handled uh oh" + + _, response = app.test_client.get("/notok") + assert response.status == 500 + + +def test_bp_exception_handler_not_applied(app): + class Error(Exception): + pass + + handled = Blueprint("handled") + nothandled = Blueprint("nothandled") + + @handled.exception(Error) + def handle_error(req, e): + return text("handled {}".format(e)) + + @nothandled.route("/notok") + def notok(request): + raise Error("uh oh") + + app.blueprint(handled) + app.blueprint(nothandled) + + _, response = app.test_client.get("/notok") + assert response.status == 500 + + def test_bp_listeners(app): app.route("/")(lambda x: x) blueprint = Blueprint("test_middleware") @@ -1034,6 +1085,6 @@ def test_bp_set_attribute_warning(): assert len(record) == 1 assert record[0].message.args[0] == ( "Setting variables on Blueprint instances is deprecated " - "and will be removed in version 21.9. You should change your " + "and will be removed in version 21.12. You should change your " "Blueprint instance to use instance.ctx.foo instead." ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5f69dd95..908a91a3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -89,7 +89,7 @@ def test_debug(cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO @@ -103,7 +103,7 @@ def test_auto_reload(cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert info["debug"] is False @@ -118,7 +118,7 @@ def test_access_logs(cmd, expected): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert info["access_log"] is expected diff --git a/tests/test_config.py b/tests/test_config.py index ce790800..42a7e3ec 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,7 @@ from sanic.exceptions import PyFileError @contextmanager def temp_path(): - """ a simple cross platform replacement for NamedTemporaryFile """ + """a simple cross platform replacement for NamedTemporaryFile""" with TemporaryDirectory() as td: yield Path(td, "file") diff --git a/tests/test_constants.py b/tests/test_constants.py index 7ce6e4d7..2f1eb3d0 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -1,6 +1,4 @@ -from crypt import methods - -from sanic import text +from sanic import Sanic, text from sanic.constants import HTTP_METHODS, HTTPMethod @@ -14,7 +12,7 @@ def test_string_compat(): assert HTTPMethod.GET.upper() == "GET" -def test_use_in_routes(app): +def test_use_in_routes(app: Sanic): @app.route("/", methods=[HTTPMethod.GET, HTTPMethod.POST]) def handler(_): return text("It works") diff --git a/tests/test_create_task.py b/tests/test_create_task.py index e128263b..99f724b5 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -1,6 +1,5 @@ import asyncio -from queue import Queue from threading import Event from sanic.response import text @@ -13,8 +12,6 @@ def test_create_task(app): await asyncio.sleep(0.05) e.set() - app.add_task(coro) - @app.route("/early") def not_set(request): return text(str(e.is_set())) @@ -24,24 +21,30 @@ def test_create_task(app): await asyncio.sleep(0.1) return text(str(e.is_set())) + app.add_task(coro) + request, response = app.test_client.get("/early") assert response.body == b"False" + app.signal_router.reset() + app.add_task(coro) request, response = app.test_client.get("/late") assert response.body == b"True" def test_create_task_with_app_arg(app): - q = Queue() + @app.after_server_start + async def setup_q(app, _): + app.ctx.q = asyncio.Queue() @app.route("/") - def not_set(request): - return "hello" + async def not_set(request): + return text(await request.app.ctx.q.get()) async def coro(app): - q.put(app.name) + await app.ctx.q.put(app.name) app.add_task(coro) - request, response = app.test_client.get("/") - assert q.get() == "test_create_task_with_app_arg" + _, response = app.test_client.get("/") + assert response.text == "test_create_task_with_app_arg" diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 495c764f..5af4ca5f 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,10 +1,10 @@ import pytest from sanic import Sanic -from sanic.errorpages import exception_response -from sanic.exceptions import NotFound +from sanic.errorpages import HTMLRenderer, exception_response +from sanic.exceptions import NotFound, SanicException from sanic.request import Request -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, html, json, text @pytest.fixture @@ -20,7 +20,7 @@ def app(): @pytest.fixture def fake_request(app): - return Request(b"/foobar", {}, "1.1", "GET", None, app) + return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app) @pytest.mark.parametrize( @@ -47,7 +47,13 @@ def test_should_return_html_valid_setting( try: raise exception("bad stuff") except Exception as e: - response = exception_response(fake_request, e, True) + response = exception_response( + fake_request, + e, + True, + base=HTMLRenderer, + fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT, + ) assert isinstance(response, HTTPResponse) assert response.status == status @@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app): app.config.FALLBACK_ERROR_FORMAT = "auto" _, response = app.test_client.get( - "/error", headers={"content-type": "application/json"} + "/error", headers={"content-type": "application/json", "accept": "*/*"} ) assert response.status == 500 assert response.content_type == "application/json" _, response = app.test_client.get( - "/error", headers={"content-type": "text/plain"} + "/error", headers={"content-type": "foo/bar", "accept": "*/*"} + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + +def test_route_error_format_set_on_auto(app): + @app.get("/text") + def text_response(request): + return text(request.route.ctx.error_format) + + @app.get("/json") + def json_response(request): + return json({"format": request.route.ctx.error_format}) + + @app.get("/html") + def html_response(request): + return html(request.route.ctx.error_format) + + _, response = app.test_client.get("/text") + assert response.text == "text" + + _, response = app.test_client.get("/json") + assert response.json["format"] == "json" + + _, response = app.test_client.get("/html") + assert response.text == "html" + + +def test_route_error_response_from_auto_route(app): + @app.get("/text") + def text_response(request): + raise Exception("oops") + return text("Never gonna see this") + + @app.get("/json") + def json_response(request): + raise Exception("oops") + return json({"message": "Never gonna see this"}) + + @app.get("/html") + def html_response(request): + raise Exception("oops") + return html("

Never gonna see this

") + + _, response = app.test_client.get("/text") + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get("/json") + assert response.content_type == "application/json" + + _, response = app.test_client.get("/html") + assert response.content_type == "text/html; charset=utf-8" + + +def test_route_error_response_from_explicit_format(app): + @app.get("/text", error_format="json") + def text_response(request): + raise Exception("oops") + return text("Never gonna see this") + + @app.get("/json", error_format="text") + def json_response(request): + raise Exception("oops") + return json({"message": "Never gonna see this"}) + + _, response = app.test_client.get("/text") + assert response.content_type == "application/json" + + _, response = app.test_client.get("/json") + assert response.content_type == "text/plain; charset=utf-8" + + +def test_unknown_fallback_format(app): + with pytest.raises(SanicException, match="Unknown format: bad"): + app.config.FALLBACK_ERROR_FORMAT = "bad" + + +def test_route_error_format_unknown(app): + with pytest.raises(SanicException, match="Unknown format: bad"): + + @app.get("/text", error_format="bad") + def handler(request): + ... + + +def test_fallback_with_content_type_mismatch_accept(app): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "application/json", "accept": "text/plain"}, ) assert response.status == 500 assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "text/plain", "accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + app.router.reset() + + @app.route("/alt1") + @app.route("/alt2", error_format="text") + @app.route("/alt3", error_format="html") + def handler(_): + raise Exception("problem here") + # Yes, we know this return value is unreachable. This is on purpose. + return json({}) + + app.router.finalize() + + _, response = app.test_client.get( + "/alt1", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + _, response = app.test_client.get( + "/alt1", + headers={"accept": "foo/bar,*/*"}, + ) + assert response.status == 500 + assert response.content_type == "application/json" + + _, response = app.test_client.get( + "/alt2", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + _, response = app.test_client.get( + "/alt2", + headers={"accept": "foo/bar,*/*"}, + ) + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/alt3", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + +@pytest.mark.parametrize( + "accept,content_type,expected", + ( + (None, None, "text/plain; charset=utf-8"), + ("foo/bar", None, "text/html; charset=utf-8"), + ("application/json", None, "application/json"), + ("application/json,text/plain", None, "application/json"), + ("text/plain,application/json", None, "application/json"), + ("text/plain,foo/bar", None, "text/plain; charset=utf-8"), + # Following test is valid after v22.3 + # ("text/plain,text/html", None, "text/plain; charset=utf-8"), + ("*/*", "foo/bar", "text/html; charset=utf-8"), + ("*/*", "application/json", "application/json"), + ), +) +def test_combinations_for_auto(fake_request, accept, content_type, expected): + if accept: + fake_request.headers["accept"] = accept + else: + del fake_request.headers["accept"] + + if content_type: + fake_request.headers["content-type"] = content_type + + try: + raise Exception("bad stuff") + except Exception as e: + response = exception_response( + fake_request, + e, + True, + base=HTMLRenderer, + fallback="auto", + ) + + assert response.content_type == expected diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8487c70b..29797e1e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,3 +1,4 @@ +import logging import warnings import pytest @@ -15,6 +16,7 @@ from sanic.exceptions import ( abort, ) from sanic.response import text +from websockets.version import version as websockets_version class SanicExceptionTestException(Exception): @@ -232,3 +234,41 @@ def test_sanic_exception(exception_app): request, response = exception_app.test_client.get("/old_abort") assert response.status == 500 assert len(w) == 1 and "deprecated" in w[0].message.args[0] + + +def test_custom_exception_default_message(exception_app): + class TeaError(SanicException): + message = "Tempest in a teapot" + status_code = 418 + + exception_app.router.reset() + + @exception_app.get("/tempest") + def tempest(_): + raise TeaError + + _, response = exception_app.test_client.get("/tempest", debug=True) + assert response.status == 418 + assert b"Tempest in a teapot" in response.body + + +def test_exception_in_ws_logged(caplog): + app = Sanic(__file__) + + @app.websocket("/feed") + async def feed(request, ws): + raise Exception("...") + + with caplog.at_level(logging.INFO): + app.test_client.websocket("/feed") + # Websockets v10.0 and above output an additional + # INFO message when a ws connection is accepted + ws_version_parts = websockets_version.split(".") + ws_major = int(ws_version_parts[0]) + record_index = 2 if ws_major >= 10 else 1 + assert caplog.record_tuples[record_index][0] == "sanic.error" + assert caplog.record_tuples[record_index][1] == logging.ERROR + assert ( + "Exception occurred while handling uri:" + in caplog.record_tuples[record_index][2] + ) diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index e6fd42eb..dbf9fcbb 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -1,5 +1,7 @@ import asyncio +import pytest + from bs4 import BeautifulSoup from sanic import Sanic @@ -8,9 +10,6 @@ from sanic.handlers import ErrorHandler from sanic.response import stream, text -exception_handler_app = Sanic("test_exception_handler") - - async def sample_streaming_fn(response): await response.write("foo,") await asyncio.sleep(0.001) @@ -21,113 +20,107 @@ class ErrorWithRequestCtx(ServerError): pass -@exception_handler_app.route("/1") -def handler_1(request): - raise InvalidUsage("OK") +@pytest.fixture +def exception_handler_app(): + exception_handler_app = Sanic("test_exception_handler") + + @exception_handler_app.route("/1", error_format="html") + def handler_1(request): + raise InvalidUsage("OK") + + @exception_handler_app.route("/2", error_format="html") + def handler_2(request): + raise ServerError("OK") + + @exception_handler_app.route("/3", error_format="html") + def handler_3(request): + raise NotFound("OK") + + @exception_handler_app.route("/4", error_format="html") + def handler_4(request): + foo = bar # noqa -- F821 + return text(foo) + + @exception_handler_app.route("/5", error_format="html") + def handler_5(request): + class CustomServerError(ServerError): + pass + + raise CustomServerError("Custom server error") + + @exception_handler_app.route("/6/", error_format="html") + def handler_6(request, arg): + try: + foo = 1 / arg + except Exception as e: + raise e from ValueError(f"{arg}") + return text(foo) + + @exception_handler_app.route("/7", error_format="html") + def handler_7(request): + raise Forbidden("go away!") + + @exception_handler_app.route("/8", error_format="html") + def handler_8(request): + + raise ErrorWithRequestCtx("OK") + + @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) + def handler_exception_with_ctx(request, exception): + return text(request.ctx.middleware_ran) + + @exception_handler_app.exception(ServerError) + def handler_exception(request, exception): + return text("OK") + + @exception_handler_app.exception(Forbidden) + async def async_handler_exception(request, exception): + return stream( + sample_streaming_fn, + content_type="text/csv", + ) + + @exception_handler_app.middleware + async def some_request_middleware(request): + request.ctx.middleware_ran = "Done." + + return exception_handler_app -@exception_handler_app.route("/2") -def handler_2(request): - raise ServerError("OK") - - -@exception_handler_app.route("/3") -def handler_3(request): - raise NotFound("OK") - - -@exception_handler_app.route("/4") -def handler_4(request): - foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception - return text(foo) - - -@exception_handler_app.route("/5") -def handler_5(request): - class CustomServerError(ServerError): - pass - - raise CustomServerError("Custom server error") - - -@exception_handler_app.route("/6/") -def handler_6(request, arg): - try: - foo = 1 / arg - except Exception as e: - raise e from ValueError(f"{arg}") - return text(foo) - - -@exception_handler_app.route("/7") -def handler_7(request): - raise Forbidden("go away!") - - -@exception_handler_app.route("/8") -def handler_8(request): - - raise ErrorWithRequestCtx("OK") - - -@exception_handler_app.exception(ErrorWithRequestCtx, NotFound) -def handler_exception_with_ctx(request, exception): - return text(request.ctx.middleware_ran) - - -@exception_handler_app.exception(ServerError) -def handler_exception(request, exception): - return text("OK") - - -@exception_handler_app.exception(Forbidden) -async def async_handler_exception(request, exception): - return stream( - sample_streaming_fn, - content_type="text/csv", - ) - - -@exception_handler_app.middleware -async def some_request_middleware(request): - request.ctx.middleware_ran = "Done." - - -def test_invalid_usage_exception_handler(): +def test_invalid_usage_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/1") assert response.status == 400 -def test_server_error_exception_handler(): +def test_server_error_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/2") assert response.status == 200 assert response.text == "OK" -def test_not_found_exception_handler(): +def test_not_found_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/3") assert response.status == 200 -def test_text_exception__handler(): +def test_text_exception__handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/random") assert response.status == 200 assert response.text == "Done." -def test_async_exception_handler(): +def test_async_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/7") assert response.status == 200 assert response.text == "foo,bar" -def test_html_traceback_output_in_debug_mode(): +def test_html_traceback_output_in_debug_mode(exception_handler_app): request, response = exception_handler_app.test_client.get("/4", debug=True) assert response.status == 500 soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - assert "response = handler(request, **kwargs)" in html assert "handler_4" in html assert "foo = bar" in html @@ -137,12 +130,12 @@ def test_html_traceback_output_in_debug_mode(): ) == summary_text -def test_inherited_exception_handler(): +def test_inherited_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/5") assert response.status == 200 -def test_chained_exception_handler(): +def test_chained_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get( "/6/0", debug=True ) @@ -151,11 +144,9 @@ def test_chained_exception_handler(): soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - assert "response = handler(request, **kwargs)" in html assert "handler_6" in html assert "foo = 1 / arg" in html assert "ValueError" in html - assert "The above exception was the direct cause" in html summary_text = " ".join(soup.select(".summary")[0].text.split()) assert ( @@ -163,7 +154,7 @@ def test_chained_exception_handler(): ) == summary_text -def test_exception_handler_lookup(): +def test_exception_handler_lookup(exception_handler_app): class CustomError(Exception): pass @@ -186,26 +177,32 @@ def test_exception_handler_lookup(): class ModuleNotFoundError(ImportError): pass - handler = ErrorHandler() + handler = ErrorHandler("auto") handler.add(ImportError, import_error_handler) handler.add(CustomError, custom_error_handler) handler.add(ServerError, server_error_handler) - assert handler.lookup(ImportError()) == import_error_handler - assert handler.lookup(ModuleNotFoundError()) == import_error_handler - assert handler.lookup(CustomError()) == custom_error_handler - assert handler.lookup(ServerError("Error")) == server_error_handler - assert handler.lookup(CustomServerError("Error")) == server_error_handler + assert handler.lookup(ImportError(), None) == import_error_handler + assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler + assert handler.lookup(CustomError(), None) == custom_error_handler + assert handler.lookup(ServerError("Error"), None) == server_error_handler + assert ( + handler.lookup(CustomServerError("Error"), None) + == server_error_handler + ) # once again to ensure there is no caching bug - assert handler.lookup(ImportError()) == import_error_handler - assert handler.lookup(ModuleNotFoundError()) == import_error_handler - assert handler.lookup(CustomError()) == custom_error_handler - assert handler.lookup(ServerError("Error")) == server_error_handler - assert handler.lookup(CustomServerError("Error")) == server_error_handler + assert handler.lookup(ImportError(), None) == import_error_handler + assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler + assert handler.lookup(CustomError(), None) == custom_error_handler + assert handler.lookup(ServerError("Error"), None) == server_error_handler + assert ( + handler.lookup(CustomServerError("Error"), None) + == server_error_handler + ) -def test_exception_handler_processed_request_middleware(): +def test_exception_handler_processed_request_middleware(exception_handler_app): request, response = exception_handler_app.test_client.get("/8") assert response.status == 200 assert response.text == "Done." diff --git a/tests/test_graceful_shutdown.py b/tests/test_graceful_shutdown.py new file mode 100644 index 00000000..8380ed50 --- /dev/null +++ b/tests/test_graceful_shutdown.py @@ -0,0 +1,46 @@ +import asyncio +import logging +import time + +from collections import Counter +from multiprocessing import Process + +import httpx + + +PORT = 42101 + + +def test_no_exceptions_when_cancel_pending_request(app, caplog): + app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1 + + @app.get("/") + async def handler(request): + await asyncio.sleep(5) + + @app.after_server_start + def shutdown(app, _): + time.sleep(0.2) + app.stop() + + def ping(): + time.sleep(0.1) + response = httpx.get("http://127.0.0.1:8000") + print(response.status_code) + + p = Process(target=ping) + p.start() + + with caplog.at_level(logging.INFO): + app.run() + + p.kill() + + counter = Counter([r[1] for r in caplog.record_tuples]) + + assert counter[logging.INFO] == 5 + assert logging.ERROR not in counter + assert ( + caplog.record_tuples[3][2] + == "Request: GET http://127.0.0.1:8000/ stopped. Transport is closed." + ) diff --git a/tests/test_handler_annotations.py b/tests/test_handler_annotations.py new file mode 100644 index 00000000..14d1d7b7 --- /dev/null +++ b/tests/test_handler_annotations.py @@ -0,0 +1,39 @@ +from uuid import UUID + +import pytest + +from sanic import json + + +@pytest.mark.parametrize( + "idx,path,expectation", + ( + (0, "/abc", "str"), + (1, "/123", "int"), + (2, "/123.5", "float"), + (3, "/8af729fe-2b94-4a95-a168-c07068568429", "UUID"), + ), +) +def test_annotated_handlers(app, idx, path, expectation): + def build_response(num, foo): + return json({"num": num, "type": type(foo).__name__}) + + @app.get("/") + def handler0(_, foo: str): + return build_response(0, foo) + + @app.get("/") + def handler1(_, foo: int): + return build_response(1, foo) + + @app.get("/") + def handler2(_, foo: float): + return build_response(2, foo) + + @app.get("/") + def handler3(_, foo: UUID): + return build_response(3, foo) + + _, response = app.test_client.get(path) + assert response.json["num"] == idx + assert response.json["type"] == expectation diff --git a/tests/test_headers.py b/tests/test_headers.py index 546a9ef7..115bed86 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -3,8 +3,9 @@ from unittest.mock import Mock import pytest from sanic import headers, text -from sanic.exceptions import PayloadTooLarge +from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.http import Http +from sanic.request import Request @pytest.fixture @@ -182,3 +183,187 @@ def test_request_line(app): ) assert request.request_line == b"GET / HTTP/1.1" + + +@pytest.mark.parametrize( + "raw", + ( + "show/first, show/second", + "show/*, show/first", + "*/*, show/first", + "*/*, show/*", + "other/*; q=0.1, show/*; q=0.2", + "show/first; q=0.5, show/second; q=0.5", + "show/first; foo=bar, show/second; foo=bar", + "show/second, show/first; foo=bar", + "show/second; q=0.5, show/first; foo=bar; q=0.5", + "show/second; q=0.5, show/first; q=1.0", + "show/first, show/second; q=1.0", + ), +) +def test_parse_accept_ordered_okay(raw): + ordered = headers.parse_accept(raw) + expected_subtype = ( + "*" if all(q.subtype.is_wildcard for q in ordered) else "first" + ) + assert ordered[0].type_ == "show" + assert ordered[0].subtype == expected_subtype + + +@pytest.mark.parametrize( + "raw", + ( + "missing", + "missing/", + "/missing", + ), +) +def test_bad_accept(raw): + with pytest.raises(InvalidHeader): + headers.parse_accept(raw) + + +def test_empty_accept(): + assert headers.parse_accept("") == [] + + +def test_wildcard_accept_set_ok(): + accept = headers.parse_accept("*/*")[0] + assert accept.type_.is_wildcard + assert accept.subtype.is_wildcard + + accept = headers.parse_accept("foo/bar")[0] + assert not accept.type_.is_wildcard + assert not accept.subtype.is_wildcard + + +def test_accept_parsed_against_str(): + accept = headers.Accept.parse("foo/bar") + assert accept > "foo/bar; q=0.1" + + +def test_media_type_equality(): + assert headers.MediaType("foo") == headers.MediaType("foo") == "foo" + assert headers.MediaType("foo") == headers.MediaType("*") == "*" + assert headers.MediaType("foo") != headers.MediaType("bar") + assert headers.MediaType("foo") != "bar" + + +def test_media_type_matching(): + assert headers.MediaType("foo").match(headers.MediaType("foo")) + assert headers.MediaType("foo").match("foo") + + assert not headers.MediaType("foo").match(headers.MediaType("*")) + assert not headers.MediaType("foo").match("*") + + assert not headers.MediaType("foo").match(headers.MediaType("bar")) + assert not headers.MediaType("foo").match("bar") + + +@pytest.mark.parametrize( + "value,other,outcome,allow_type,allow_subtype", + ( + # ALLOW BOTH + ("foo/bar", "foo/bar", True, True, True), + ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/bar", "foo/*", True, True, True), + ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), + ("foo/bar", "*/*", True, True, True), + ("foo/bar", headers.Accept.parse("*/*"), True, True, True), + ("foo/*", "foo/bar", True, True, True), + ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/*", "foo/*", True, True, True), + ("foo/*", headers.Accept.parse("foo/*"), True, True, True), + ("foo/*", "*/*", True, True, True), + ("foo/*", headers.Accept.parse("*/*"), True, True, True), + ("*/*", "foo/bar", True, True, True), + ("*/*", headers.Accept.parse("foo/bar"), True, True, True), + ("*/*", "foo/*", True, True, True), + ("*/*", headers.Accept.parse("foo/*"), True, True, True), + ("*/*", "*/*", True, True, True), + ("*/*", headers.Accept.parse("*/*"), True, True, True), + # ALLOW TYPE + ("foo/bar", "foo/bar", True, True, False), + ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), + ("foo/bar", "foo/*", False, True, False), + ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), + ("foo/bar", "*/*", False, True, False), + ("foo/bar", headers.Accept.parse("*/*"), False, True, False), + ("foo/*", "foo/bar", False, True, False), + ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), + ("foo/*", "foo/*", False, True, False), + ("foo/*", headers.Accept.parse("foo/*"), False, True, False), + ("foo/*", "*/*", False, True, False), + ("foo/*", headers.Accept.parse("*/*"), False, True, False), + ("*/*", "foo/bar", False, True, False), + ("*/*", headers.Accept.parse("foo/bar"), False, True, False), + ("*/*", "foo/*", False, True, False), + ("*/*", headers.Accept.parse("foo/*"), False, True, False), + ("*/*", "*/*", False, True, False), + ("*/*", headers.Accept.parse("*/*"), False, True, False), + # ALLOW SUBTYPE + ("foo/bar", "foo/bar", True, False, True), + ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/bar", "foo/*", True, False, True), + ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), + ("foo/bar", "*/*", False, False, True), + ("foo/bar", headers.Accept.parse("*/*"), False, False, True), + ("foo/*", "foo/bar", True, False, True), + ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/*", "foo/*", True, False, True), + ("foo/*", headers.Accept.parse("foo/*"), True, False, True), + ("foo/*", "*/*", False, False, True), + ("foo/*", headers.Accept.parse("*/*"), False, False, True), + ("*/*", "foo/bar", False, False, True), + ("*/*", headers.Accept.parse("foo/bar"), False, False, True), + ("*/*", "foo/*", False, False, True), + ("*/*", headers.Accept.parse("foo/*"), False, False, True), + ("*/*", "*/*", False, False, True), + ("*/*", headers.Accept.parse("*/*"), False, False, True), + ), +) +def test_accept_matching(value, other, outcome, allow_type, allow_subtype): + assert ( + headers.Accept.parse(value).match( + other, + allow_type_wildcard=allow_type, + allow_subtype_wildcard=allow_subtype, + ) + is outcome + ) + + +@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) +def test_value_in_accept(value): + acceptable = headers.parse_accept(value) + assert "foo/bar" in acceptable + assert "foo/*" in acceptable + assert "*/*" in acceptable + + +@pytest.mark.parametrize("value", ("foo/bar", "foo/*")) +def test_value_not_in_accept(value): + acceptable = headers.parse_accept(value) + assert "no/match" not in acceptable + assert "no/*" not in acceptable + + +@pytest.mark.parametrize( + "header,expected", + ( + ( + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501 + [ + "text/html", + "application/xhtml+xml", + "image/avif", + "image/webp", + "application/xml;q=0.9", + "*/*;q=0.8", + ], + ), + ), +) +def test_browser_headers(header, expected): + request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) + assert request.accept == expected diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 00000000..653857a1 --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,137 @@ +import asyncio +import json as stdjson + +from collections import namedtuple +from textwrap import dedent +from typing import AnyStr + +import pytest + +from sanic_testing.reusable import ReusableClient + +from sanic import json, text +from sanic.app import Sanic + + +PORT = 1234 + + +class RawClient: + CRLF = b"\r\n" + + def __init__(self, host: str, port: int): + self.reader = None + self.writer = None + self.host = host + self.port = port + + async def connect(self): + self.reader, self.writer = await asyncio.open_connection( + self.host, self.port + ) + + async def close(self): + self.writer.close() + await self.writer.wait_closed() + + async def send(self, message: AnyStr): + if isinstance(message, str): + msg = self._clean(message).encode("utf-8") + else: + msg = message + await self._send(msg) + + async def _send(self, message: bytes): + if not self.writer: + raise Exception("No open write stream") + self.writer.write(message) + + async def recv(self, nbytes: int = -1) -> bytes: + if not self.reader: + raise Exception("No open read stream") + return await self.reader.read(nbytes) + + def _clean(self, message: str) -> str: + return ( + dedent(message) + .lstrip("\n") + .replace("\n", self.CRLF.decode("utf-8")) + ) + + +@pytest.fixture +def test_app(app: Sanic): + app.config.KEEP_ALIVE_TIMEOUT = 1 + + @app.get("/") + async def base_handler(request): + return text("111122223333444455556666777788889999") + + @app.post("/upload", stream=True) + async def upload_handler(request): + data = [part.decode("utf-8") async for part in request.stream] + return json(data) + + return app + + +@pytest.fixture +def runner(test_app): + client = ReusableClient(test_app, port=PORT) + client.run() + yield client + client.stop() + + +@pytest.fixture +def client(runner): + client = namedtuple("Client", ("raw", "send", "recv")) + + raw = RawClient(runner.host, runner.port) + runner._run(raw.connect()) + + def send(msg): + nonlocal runner + nonlocal raw + runner._run(raw.send(msg)) + + def recv(**kwargs): + nonlocal runner + nonlocal raw + method = raw.recv_until if "until" in kwargs else raw.recv + return runner._run(method(**kwargs)) + + yield client(raw, send, recv) + + runner._run(raw.close()) + + +def test_full_message(client): + client.send( + """ + GET / HTTP/1.1 + host: localhost:7777 + + """ + ) + response = client.recv() + assert len(response) == 140 + assert b"200 OK" in response + + +def test_transfer_chunked(client): + client.send( + """ + POST /upload HTTP/1.1 + transfer-encoding: chunked + + """ + ) + client.send(b"3\r\nfoo\r\n") + client.send(b"3\r\nbar\r\n") + client.send(b"0\r\n\r\n") + response = client.recv() + _, body = response.rsplit(b"\r\n\r\n", 1) + data = stdjson.loads(body) + + assert data == ["foo", "bar"] diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index e777de2e..e30761ed 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -2,16 +2,13 @@ import asyncio import platform from asyncio import sleep as aio_sleep -from json import JSONDecodeError from os import environ -import httpcore -import httpx import pytest -from sanic_testing.testing import HOST, SanicTestClient +from sanic_testing.reusable import ReusableClient -from sanic import Sanic, server +from sanic import Sanic from sanic.compat import OS_IS_WINDOWS from sanic.response import text @@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port -class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool): - last_reused_connection = None - - async def _get_connection_from_pool(self, *args, **kwargs): - conn = await super()._get_connection_from_pool(*args, **kwargs) - self.__class__.last_reused_connection = conn - return conn - - -class ResusableSanicSession(httpx.AsyncClient): - def __init__(self, *args, **kwargs) -> None: - transport = ReusableSanicConnectionPool() - super().__init__(transport=transport, *args, **kwargs) - - -class ReuseableSanicTestClient(SanicTestClient): - def __init__(self, app, loop=None): - super().__init__(app) - if loop is None: - loop = asyncio.get_event_loop() - self._loop = loop - self._server = None - self._tcp_connector = None - self._session = None - - def get_new_session(self): - return ResusableSanicSession() - - # Copied from SanicTestClient, but with some changes to reuse the - # same loop for the same app. - def _sanic_endpoint_test( - self, - method="get", - uri="/", - gather_request=True, - debug=False, - server_kwargs=None, - *request_args, - **request_kwargs, - ): - loop = self._loop - results = [None, None] - exceptions = [] - server_kwargs = server_kwargs or {"return_asyncio_server": True} - if gather_request: - - def _collect_request(request): - if results[0] is None: - results[0] = request - - self.app.request_middleware.appendleft(_collect_request) - - if uri.startswith( - ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") - ): - url = uri - else: - uri = uri if uri.startswith("/") else f"/{uri}" - scheme = "http" - url = f"{scheme}://{HOST}:{PORT}{uri}" - - @self.app.listener("after_server_start") - async def _collect_response(loop): - try: - response = await self._local_request( - method, url, *request_args, **request_kwargs - ) - results[-1] = response - except Exception as e2: - exceptions.append(e2) - - if self._server is not None: - _server = self._server - else: - _server_co = self.app.create_server( - host=HOST, debug=debug, port=PORT, **server_kwargs - ) - - server.trigger_events( - self.app.listeners["before_server_start"], loop - ) - - try: - loop._stopping = False - _server = loop.run_until_complete(_server_co) - except Exception as e1: - raise e1 - self._server = _server - server.trigger_events(self.app.listeners["after_server_start"], loop) - self.app.listeners["after_server_start"].pop() - - if exceptions: - raise ValueError(f"Exception during request: {exceptions}") - - if gather_request: - self.app.request_middleware.pop() - try: - request, response = results - return request, response - except Exception: - raise ValueError( - f"Request and response object expected, got ({results})" - ) - else: - try: - return results[-1] - except Exception: - raise ValueError(f"Request object expected, got ({results})") - - def kill_server(self): - try: - if self._server: - self._server.close() - self._loop.run_until_complete(self._server.wait_closed()) - self._server = None - - if self._session: - self._loop.run_until_complete(self._session.aclose()) - self._session = None - - except Exception as e3: - raise e3 - - # Copied from SanicTestClient, but with some changes to reuse the - # same TCPConnection and the sane ClientSession more than once. - # Note, you cannot use the same session if you are in a _different_ - # loop, so the changes above are required too. - async def _local_request(self, method, url, *args, **kwargs): - raw_cookies = kwargs.pop("raw_cookies", None) - request_keepalive = kwargs.pop( - "request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"] - ) - if not self._session: - self._session = self.get_new_session() - try: - response = await getattr(self._session, method.lower())( - url, timeout=request_keepalive, *args, **kwargs - ) - except NameError: - raise Exception(response.status_code) - - try: - response.json = response.json() - except (JSONDecodeError, UnicodeDecodeError): - response.json = None - - response.body = await response.aread() - response.status = response.status_code - response.content_type = response.headers.get("content-type") - - if raw_cookies: - response.raw_cookies = {} - for cookie in response.cookies: - response.raw_cookies[cookie.name] = cookie - - return response - - keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") @@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse(): """If the server keep-alive timeout and client keep-alive timeout are both longer than the delay, the client _and_ server will successfully reuse the existing connection.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT) + with client: headers = {"Connection": "keep-alive"} request, response = client.get("/1", headers=headers) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 + loop.run_until_complete(aio_sleep(1)) + request, response = client.get("/1") assert response.status == 200 assert response.text == "OK" - assert ReusableSanicConnectionPool.last_reused_connection - finally: - client.kill_server() + assert request.protocol.state["requests_count"] == 2 @pytest.mark.skipif( @@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse(): def test_keep_alive_client_timeout(): """If the server keep-alive timeout is longer than the client keep-alive timeout, client will try to create a new connection here.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient( + keep_alive_app_client_timeout, loop=loop, port=PORT + ) + with client: headers = {"Connection": "keep-alive"} - _, response = client.get("/1", headers=headers, request_keepalive=1) + request, response = client.get("/1", headers=headers, timeout=1) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 loop.run_until_complete(aio_sleep(2)) - _, response = client.get("/1", request_keepalive=1) - - assert ReusableSanicConnectionPool.last_reused_connection is None - finally: - client.kill_server() + request, response = client.get("/1", timeout=1) + assert request.protocol.state["requests_count"] == 1 @pytest.mark.skipif( @@ -277,22 +117,23 @@ def test_keep_alive_server_timeout(): keep-alive timeout, the client will either a 'Connection reset' error _or_ a new connection. Depending on how the event-loop handles the broken server connection.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient( + keep_alive_app_server_timeout, loop=loop, port=PORT + ) + with client: headers = {"Connection": "keep-alive"} - _, response = client.get("/1", headers=headers, request_keepalive=60) + request, response = client.get("/1", headers=headers, timeout=60) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 loop.run_until_complete(aio_sleep(3)) - _, response = client.get("/1", request_keepalive=60) + request, response = client.get("/1", timeout=60) - assert ReusableSanicConnectionPool.last_reused_connection is None - finally: - client.kill_server() + assert request.protocol.state["requests_count"] == 1 @pytest.mark.skipif( @@ -300,10 +141,10 @@ def test_keep_alive_server_timeout(): reason="Not testable with current client", ) def test_keep_alive_connection_context(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_context, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT) + with client: headers = {"Connection": "keep-alive"} request1, _ = client.post("/ctx", headers=headers) @@ -315,5 +156,4 @@ def test_keep_alive_connection_context(): assert ( request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" ) - finally: - client.kill_server() + assert request2.protocol.state["requests_count"] == 2 diff --git a/tests/test_logging.py b/tests/test_logging.py index 5f531670..639bb2ee 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -5,6 +5,7 @@ import uuid from importlib import reload from io import StringIO +from unittest.mock import Mock import pytest @@ -51,7 +52,7 @@ def test_log(app): def test_logging_defaults(): # reset_logging() - app = Sanic("test_logging") + Sanic("test_logging") for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: assert ( @@ -87,7 +88,7 @@ def test_logging_pass_customer_logconfig(): "format" ] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s" - app = Sanic("test_logging", log_config=modified_config) + Sanic("test_logging", log_config=modified_config) for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: assert fmt._fmt == modified_config["formatters"]["generic"]["format"] @@ -111,11 +112,13 @@ def test_logging_pass_customer_logconfig(): ), ) def test_log_connection_lost(app, debug, monkeypatch): - """ Should not log Connection lost exception on non debug """ + """Should not log Connection lost exception on non debug""" stream = StringIO() error = logging.getLogger("sanic.error") error.addHandler(logging.StreamHandler(stream)) - monkeypatch.setattr(sanic.server, "error_logger", error) + monkeypatch.setattr( + sanic.server.protocols.http_protocol, "error_logger", error + ) @app.route("/conn_lost") async def conn_lost(request): @@ -208,6 +211,56 @@ def test_logging_modified_root_logger_config(): modified_config = LOGGING_CONFIG_DEFAULTS modified_config["loggers"]["sanic.root"]["level"] = "DEBUG" - app = Sanic("test_logging", log_config=modified_config) + Sanic("test_logging", log_config=modified_config) assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG + + +def test_access_log_client_ip_remote_addr(monkeypatch): + access = Mock() + monkeypatch.setattr(sanic.http, "access_logger", access) + + app = Sanic("test_logging") + app.config.PROXIES_COUNT = 2 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"} + + request, response = app.test_client.get("/", headers=headers) + + assert request.remote_addr == "1.1.1.1" + access.info.assert_called_with( + "", + extra={ + "status": 200, + "byte": len(response.content), + "host": f"{request.remote_addr}:{request.port}", + "request": f"GET {request.scheme}://{request.host}/", + }, + ) + + +def test_access_log_client_ip_reqip(monkeypatch): + access = Mock() + monkeypatch.setattr(sanic.http, "access_logger", access) + + app = Sanic("test_logging") + + @app.route("/") + async def handler(request): + return text(request.ip) + + request, response = app.test_client.get("/") + + access.info.assert_called_with( + "", + extra={ + "status": 200, + "byte": len(response.content), + "host": f"{request.ip}:{request.port}", + "request": f"GET {request.scheme}://{request.host}/", + }, + ) diff --git a/tests/test_logo.py b/tests/test_logo.py index 3fff32db..e59975c3 100644 --- a/tests/test_logo.py +++ b/tests/test_logo.py @@ -6,85 +6,37 @@ from sanic_testing.testing import PORT from sanic.config import BASE_LOGO -def test_logo_base(app, caplog): - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False +def test_logo_base(app, run_startup): + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == BASE_LOGO + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == BASE_LOGO -def test_logo_false(app, caplog): +def test_logo_false(app, caplog, run_startup): app.config.LOGO = False - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - banner, port = caplog.record_tuples[0][2].rsplit(":", 1) - assert caplog.record_tuples[0][1] == logging.INFO + banner, port = logs[0][2].rsplit(":", 1) + assert logs[0][1] == logging.INFO assert banner == "Goin' Fast @ http://127.0.0.1" assert int(port) > 0 -def test_logo_true(app, caplog): +def test_logo_true(app, run_startup): app.config.LOGO = True - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == BASE_LOGO + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == BASE_LOGO -def test_logo_custom(app, caplog): +def test_logo_custom(app, run_startup): app.config.LOGO = "My Custom Logo" - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == "My Custom Logo" + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == "My Custom Logo" diff --git a/tests/test_middleware.py b/tests/test_middleware.py index cc7edae2..c19386e7 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -5,7 +5,7 @@ from itertools import count from sanic.exceptions import NotFound from sanic.request import Request -from sanic.response import HTTPResponse, text +from sanic.response import HTTPResponse, json, text # ------------------------------------------------------------ # @@ -37,14 +37,19 @@ def test_middleware_request_as_convenience(app): async def handler1(request): results.append(request) - @app.route("/") + @app.on_request() async def handler2(request): + results.append(request) + + @app.route("/") + async def handler3(request): return text("OK") request, response = app.test_client.get("/") assert response.text == "OK" assert type(results[0]) is Request + assert type(results[1]) is Request def test_middleware_response(app): @@ -79,7 +84,12 @@ def test_middleware_response_as_convenience(app): results.append(request) @app.on_response - async def process_response(request, response): + async def process_response_1(request, response): + results.append(request) + results.append(response) + + @app.on_response() + async def process_response_2(request, response): results.append(request) results.append(response) @@ -93,6 +103,8 @@ def test_middleware_response_as_convenience(app): assert type(results[0]) is Request assert type(results[1]) is Request assert isinstance(results[2], HTTPResponse) + assert type(results[3]) is Request + assert isinstance(results[4], HTTPResponse) def test_middleware_response_as_convenience_called(app): @@ -271,3 +283,17 @@ def test_request_middleware_executes_once(app): request, response = app.test_client.get("/") assert next(i) == 3 + + +def test_middleware_added_response(app): + @app.on_response + def display(_, response): + response["foo"] = "bar" + return json(response) + + @app.get("/") + async def handler(request): + return {} + + _, response = app.test_client.get("/") + assert response.json["foo"] == "bar" diff --git a/tests/test_request.py b/tests/test_request.py index e4b21f66..ca2c1e4a 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -140,3 +140,39 @@ def test_ipv6_address_is_not_wrapped(app): assert resp.json["client"] == "[::1]" assert resp.json["client_ip"] == "::1" assert request.ip == "::1" + + +def test_request_accept(): + app = Sanic("req-generator") + + @app.get("/") + async def get(request): + return response.empty() + + request, _ = app.test_client.get( + "/", + headers={ + "Accept": "text/*, text/plain, text/plain;format=flowed, */*" + }, + ) + assert request.accept == [ + "text/plain;format=flowed", + "text/plain", + "text/*", + "*/*", + ] + + request, _ = app.test_client.get( + "/", + headers={ + "Accept": ( + "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c" + ) + }, + ) + assert request.accept == [ + "text/html", + "text/x-c", + "text/x-dvi; q=0.8", + "text/plain; q=0.5", + ] diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 89cb46df..48e23f1d 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -2,6 +2,7 @@ import asyncio import httpcore import httpx +import pytest from sanic_testing.testing import SanicTestClient @@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient): return DelayableSanicSession(request_delay=self._request_delay) -request_timeout_default_app = Sanic("test_request_timeout_default") -request_no_timeout_app = Sanic("test_request_no_timeout") -request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6 -request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6 +@pytest.fixture +def request_no_timeout_app(): + app = Sanic("test_request_no_timeout") + app.config.REQUEST_TIMEOUT = 0.6 + + @app.route("/1") + async def handler2(request): + return text("OK") + + return app -@request_timeout_default_app.route("/1") -async def handler1(request): - return text("OK") +@pytest.fixture +def request_timeout_default_app(): + app = Sanic("test_request_timeout_default") + app.config.REQUEST_TIMEOUT = 0.6 + + @app.route("/1") + async def handler1(request): + return text("OK") + + @app.websocket("/ws1") + async def ws_handler1(request, ws): + await ws.send("OK") + + return app -@request_no_timeout_app.route("/1") -async def handler2(request): - return text("OK") - - -@request_timeout_default_app.websocket("/ws1") -async def ws_handler1(request, ws): - await ws.send("OK") - - -def test_default_server_error_request_timeout(): +def test_default_server_error_request_timeout(request_timeout_default_app): client = DelayableSanicTestClient(request_timeout_default_app, 2) - request, response = client.get("/1") + _, response = client.get("/1") assert response.status == 408 assert "Request Timeout" in response.text -def test_default_server_error_request_dont_timeout(): +def test_default_server_error_request_dont_timeout(request_no_timeout_app): client = DelayableSanicTestClient(request_no_timeout_app, 0.2) - request, response = client.get("/1") + _, response = client.get("/1") assert response.status == 200 assert response.text == "OK" -def test_default_server_error_websocket_request_timeout(): +def test_default_server_error_websocket_request_timeout( + request_timeout_default_app, +): headers = { "Upgrade": "websocket", @@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout(): } client = DelayableSanicTestClient(request_timeout_default_app, 2) - request, response = client.get("/ws1", headers=headers) + _, response = client.get("/ws1", headers=headers) assert response.status == 408 assert "Request Timeout" in response.text diff --git a/tests/test_routes.py b/tests/test_routes.py index 0f4980f6..520ab5be 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app): @pytest.mark.asyncio @pytest.mark.parametrize("url", ["/ws", "ws"]) async def test_websocket_route_asgi(app, url): - ev = asyncio.Event() + @app.after_server_start + async def setup_ev(app, _): + app.ctx.ev = asyncio.Event() @app.websocket(url) async def handler(request, ws): - ev.set() + request.app.ctx.ev.set() - request, response = await app.asgi_client.websocket(url) - assert ev.is_set() + @app.get("/ev") + async def check(request): + return json({"set": request.app.ctx.ev.is_set()}) + + _, response = await app.asgi_client.websocket(url) + _, response = await app.asgi_client.get("/") + assert response.json["set"] -def test_websocket_route_with_subprotocols(app): +@pytest.mark.parametrize( + "subprotocols,expected", + ( + (["one"], "one"), + (["three", "one"], "one"), + (["tree"], None), + (None, None), + ), +) +def test_websocket_route_with_subprotocols(app, subprotocols, expected): results = [] - @app.websocket("/ws", subprotocols=["foo", "bar"]) + @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) async def handler(request, ws): - results.append(ws.subprotocol) + nonlocal results + results = ws.subprotocol assert ws.subprotocol is not None - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"]) - assert response.opened is True - assert results == ["bar"] - _, response = SanicTestClient(app).websocket( - "/ws", subprotocols=["bar", "foo"] + "/ws", subprotocols=subprotocols ) assert response.opened is True - assert results == ["bar", "bar"] - - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"]) - assert response.opened is True - assert results == ["bar", "bar", None] - - _, response = SanicTestClient(app).websocket("/ws") - assert response.opened is True - assert results == ["bar", "bar", None, None] + assert results == expected @pytest.mark.parametrize("strict_slashes", [True, False, None]) diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 2e48f408..7ce1859c 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -8,7 +8,7 @@ import pytest from sanic_testing.testing import HOST, PORT -from sanic.exceptions import InvalidUsage +from sanic.exceptions import InvalidUsage, SanicException AVAILABLE_LISTENERS = [ @@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app): async def init_db(app, loop): app.db = MySanicDb() - await app.create_server(debug=True, return_asyncio_server=True, port=PORT) + srv = await app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + await srv.startup() + await srv.before_start() assert hasattr(app, "db") assert isinstance(app.db, MySanicDb) @@ -157,14 +161,15 @@ def test_create_server_trigger_events(app): serv_coro = app.create_server(return_asyncio_server=True, sock=sock) serv_task = asyncio.ensure_future(serv_coro, loop=loop) server = loop.run_until_complete(serv_task) - server.after_start() + loop.run_until_complete(server.startup()) + loop.run_until_complete(server.after_start()) try: loop.run_forever() - except KeyboardInterrupt as e: + except KeyboardInterrupt: loop.stop() finally: # Run the on_stop function if provided - server.before_stop() + loop.run_until_complete(server.before_stop()) # Wait for server to close close_task = server.close() @@ -174,5 +179,19 @@ def test_create_server_trigger_events(app): signal.stopped = True for connection in server.connections: connection.close_if_idle() - server.after_stop() + loop.run_until_complete(server.after_stop()) assert flag1 and flag2 and flag3 + + +@pytest.mark.asyncio +async def test_missing_startup_raises_exception(app): + @app.listener("before_server_start") + async def init_db(app, loop): + ... + + srv = await app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + + with pytest.raises(SanicException): + await srv.before_start() diff --git a/tests/test_signal_handlers.py b/tests/test_signal_handlers.py index 857b5283..f7657ad6 100644 --- a/tests/test_signal_handlers.py +++ b/tests/test_signal_handlers.py @@ -95,7 +95,7 @@ def test_windows_workaround(): os.kill(os.getpid(), signal.SIGINT) await asyncio.sleep(0.2) assert app.is_stopping - assert app.stay_active_task.result() == None + assert app.stay_active_task.result() is None # Second Ctrl+C should raise with pytest.raises(KeyboardInterrupt): os.kill(os.getpid(), signal.SIGINT) diff --git a/tests/test_signals.py b/tests/test_signals.py index 5d116f90..9b8a9495 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app): app.signal_router.finalize() + assert len(app.signal_router.routes) == 3 await app.dispatch("foo.bar.baz") assert counter == 2 @@ -331,7 +332,8 @@ def test_event_on_bp_not_registered(): "event,expected", ( ("foo.bar.baz", True), - ("server.init.before", False), + ("server.init.before", True), + ("server.init.somethingelse", False), ("http.request.start", False), ("sanic.notice.anything", True), ), diff --git a/tests/test_static.py b/tests/test_static.py index 00e5611d..7d62d2d3 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -461,6 +461,22 @@ def test_nested_dir(app, static_file_directory): assert response.text == "foo\n" +def test_handle_is_a_directory_error(app, static_file_directory): + error_text = "Is a directory. Access denied" + app.static("/static", static_file_directory) + + @app.exception(Exception) + async def handleStaticDirError(request, exception): + if isinstance(exception, IsADirectoryError): + return text(error_text, status=403) + raise exception + + request, response = app.test_client.get("/static/") + + assert response.status == 403 + assert response.text == error_text + + def test_stack_trace_on_not_found(app, static_file_directory, caplog): app.static("/static", static_file_directory) @@ -507,3 +523,56 @@ def test_multiple_statics(app, static_file_directory): assert response.body == get_file_content( static_file_directory, "python.png" ) + + +def test_resource_type_default(app, static_file_directory): + app.static("/static", static_file_directory) + app.static("/file", get_file_path(static_file_directory, "test.file")) + + _, response = app.test_client.get("/static") + assert response.status == 404 + + _, response = app.test_client.get("/file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + +def test_resource_type_file(app, static_file_directory): + app.static( + "/file", + get_file_path(static_file_directory, "test.file"), + resource_type="file", + ) + + _, response = app.test_client.get("/file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + with pytest.raises(TypeError): + app.static("/static", static_file_directory, resource_type="file") + + +def test_resource_type_dir(app, static_file_directory): + app.static("/static", static_file_directory, resource_type="dir") + + _, response = app.test_client.get("/static/test.file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + with pytest.raises(TypeError): + app.static( + "/file", + get_file_path(static_file_directory, "test.file"), + resource_type="dir", + ) + + +def test_resource_type_unknown(app, static_file_directory, caplog): + with pytest.raises(ValueError): + app.static("/static", static_file_directory, resource_type="unknown") diff --git a/tests/test_touchup.py b/tests/test_touchup.py new file mode 100644 index 00000000..3079aa1b --- /dev/null +++ b/tests/test_touchup.py @@ -0,0 +1,21 @@ +import logging + +from sanic.signals import RESERVED_NAMESPACES +from sanic.touchup import TouchUp + + +def test_touchup_methods(app): + assert len(TouchUp._registry) == 9 + + +async def test_ode_removes_dispatch_events(app, caplog): + with caplog.at_level(logging.DEBUG, logger="sanic.root"): + await app._startup() + logs = caplog.record_tuples + + for signal in RESERVED_NAMESPACES["http"]: + assert ( + "sanic.root", + logging.DEBUG, + f"Disabling event: {signal}", + ) in logs diff --git a/tests/test_url_for.py b/tests/test_url_for.py index d623cc4a..6ec6a93f 100644 --- a/tests/test_url_for.py +++ b/tests/test_url_for.py @@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app): ) -def test_websocket_bp_route_name(app): +@pytest.mark.parametrize( + "name,expected", + ( + ("test_route", "/bp/route"), + ("test_route2", "/bp/route2"), + ("foobar_3", "/bp/route3"), + ), +) +def test_websocket_bp_route_name(app, name, expected): """Tests that blueprint websocket route is named.""" event = asyncio.Event() bp = Blueprint("test_bp", url_prefix="/bp") @@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app): uri = app.url_for("test_bp.main") assert uri == "/bp/main" - uri = app.url_for("test_bp.test_route") - assert uri == "/bp/route" + uri = app.url_for(f"test_bp.{name}") + assert uri == expected request, response = SanicTestClient(app).websocket(uri) assert response.opened is True assert event.is_set() - event.clear() - uri = app.url_for("test_bp.test_route2") - assert uri == "/bp/route2" - request, response = SanicTestClient(app).websocket(uri) - assert response.opened is True - assert event.is_set() - - uri = app.url_for("test_bp.foobar_3") - assert uri == "/bp/route3" - # TODO: add test with a route with multiple hosts # TODO: add test with a route with _host in url_for diff --git a/tests/test_worker.py b/tests/test_worker.py index 252bdb36..3850b8a6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -175,7 +175,7 @@ def test_worker_close(worker): worker.wsgi = mock.Mock() conn = mock.Mock() conn.websocket = mock.Mock() - conn.websocket.close_connection = mock.Mock(wraps=_a_noop) + conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) worker.connections = set([conn]) worker.log = mock.Mock() worker.loop = loop @@ -190,5 +190,5 @@ def test_worker_close(worker): loop.run_until_complete(_close) assert worker.signal.stopped - assert conn.websocket.close_connection.called + assert conn.websocket.fail_connection.called assert len(worker.servers) == 0 diff --git a/tox.ini b/tox.ini index 590dc25a..5612f6de 100644 --- a/tox.ini +++ b/tox.ini @@ -2,53 +2,28 @@ envlist = py37, py38, py39, pyNightly, pypy37, {py37,py38,py39,pyNightly,pypy37}-no-ext, lint, check, security, docs, type-checking [testenv] -usedevelop = True +usedevelop = true setenv = {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UJSON=1 {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 -deps = - sanic-testing>=0.6.0 - coverage==5.3 - pytest==5.2.1 - pytest-cov - pytest-sanic - pytest-sugar - pytest-benchmark - chardet==3.* - beautifulsoup4 - gunicorn==20.0.4 - uvicorn - websockets>=9.0 +extras = test commands = pytest {posargs:tests --cov sanic} - coverage combine --append - coverage report -m + coverage report -m -i coverage html -i [testenv:lint] -deps = - flake8 - black - isort>=5.0.0 - bandit - commands = flake8 sanic black --config ./.black.toml --check --verbose sanic/ isort --check-only sanic --profile=black [testenv:type-checking] -deps = - mypy>=0.901 - types-ujson - commands = mypy sanic [testenv:check] -deps = - docutils - pygments commands = python setup.py check -r -s @@ -60,8 +35,6 @@ markers = asyncio [testenv:security] -deps = - bandit commands = bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py @@ -69,30 +42,10 @@ commands = [testenv:docs] platform = linux|linux2|darwin whitelist_externals = make -deps = - sphinx>=2.1.2 - sphinx_rtd_theme>=0.4.3 - recommonmark>=0.5.0 - docutils - pygments - gunicorn==20.0.4 +extras = docs commands = make docs-test [testenv:coverage] -usedevelop = True -deps = - sanic-testing>=0.6.0 - coverage==5.3 - pytest==5.2.1 - pytest-cov - pytest-sanic - pytest-sugar - pytest-benchmark - chardet==3.* - beautifulsoup4 - gunicorn==20.0.4 - uvicorn - websockets>=9.0 commands = pytest tests --cov=./sanic --cov-report=xml