Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e12edbc38 | ||
|
|
50a606adee | ||
|
|
f995612073 | ||
|
|
bc08383acd | ||
|
|
b83a1a184c | ||
|
|
59dd6814f8 | ||
|
|
f7abf3db1b | ||
|
|
cf1d2148ac | ||
|
|
b5f2bd9b0e | ||
|
|
ba2670e99c | ||
|
|
6ffc4d9756 | ||
|
|
595d2c76ac | ||
|
|
d9796e9b1e | ||
|
|
404c5f9f9e | ||
|
|
a937e08ef0 | ||
|
|
ef4f058a6c | ||
|
|
69c5dde9bf | ||
|
|
945885d501 | ||
|
|
9d0b54c90d | ||
|
|
2e5c288fea | ||
|
|
f32ef20b74 | ||
|
|
e2eefaac55 | ||
|
|
e1cfbf0fd9 | ||
|
|
08c5689441 | ||
|
|
8dbda247d6 | ||
|
|
71a631237d | ||
|
|
e22ff3828b | ||
|
|
b1b12e004e | ||
|
|
5308fec354 | ||
|
|
0ba57d4701 | ||
|
|
54ca6a6178 | ||
|
|
7dd4a78cf2 | ||
|
|
52ff49512a | ||
|
|
5a48b94089 | ||
|
|
ba1c73d947 | ||
|
|
4732b6bdfa | ||
|
|
a6e78b70ab | ||
|
|
bb1174afc5 | ||
|
|
df8abe9cfd | ||
|
|
c3bca97ee1 | ||
|
|
c3b6fa1bba | ||
|
|
94d496afe1 | ||
|
|
7b7a572f9b | ||
|
|
1b8cb742f9 | ||
|
|
3492d180a8 | ||
|
|
021da38373 | ||
|
|
ac784759d5 | ||
|
|
36eda2cd62 | ||
|
|
08a4b3013f | ||
|
|
1dd0332e8b | ||
|
|
a90877ac31 | ||
|
|
8b7ea27a48 |
@@ -10,3 +10,15 @@ exclude_patterns:
|
||||
- "examples/"
|
||||
- "hack/"
|
||||
- "scripts/"
|
||||
- "tests/"
|
||||
checks:
|
||||
argument-count:
|
||||
enabled: false
|
||||
file-lines:
|
||||
config:
|
||||
threshold: 1000
|
||||
method-count:
|
||||
config:
|
||||
threshold: 40
|
||||
complex-logic:
|
||||
enabled: false
|
||||
|
||||
62
.github/workflows/pr-windows.yml
vendored
62
.github/workflows/pr-windows.yml
vendored
@@ -1,34 +1,34 @@
|
||||
# name: Run Unit Tests on Windows
|
||||
# on:
|
||||
# pull_request:
|
||||
# branches:
|
||||
# - main
|
||||
name: Run Unit Tests on Windows
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
# jobs:
|
||||
# testsOnWindows:
|
||||
# name: ut-${{ matrix.config.tox-env }}
|
||||
# runs-on: windows-latest
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# config:
|
||||
# - { python-version: 3.7, tox-env: py37-no-ext }
|
||||
# - { python-version: 3.8, tox-env: py38-no-ext }
|
||||
# - { python-version: 3.9, tox-env: py39-no-ext }
|
||||
# - { python-version: pypy-3.7, tox-env: pypy37-no-ext }
|
||||
jobs:
|
||||
testsOnWindows:
|
||||
name: ut-${{ matrix.config.tox-env }}
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- { python-version: 3.7, tox-env: py37-no-ext }
|
||||
- { python-version: 3.8, tox-env: py38-no-ext }
|
||||
- { python-version: 3.9, tox-env: py39-no-ext }
|
||||
- { python-version: pypy-3.7, tox-env: pypy37-no-ext }
|
||||
|
||||
# steps:
|
||||
# - name: Checkout Repository
|
||||
# uses: actions/checkout@v2
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
# - name: Run Unit Tests
|
||||
# uses: ahopkins/custom-actions@pip-extra-args
|
||||
# with:
|
||||
# python-version: ${{ matrix.config.python-version }}
|
||||
# test-infra-tool: tox
|
||||
# test-infra-version: latest
|
||||
# action: tests
|
||||
# test-additional-args: "-e=${{ matrix.config.tox-env }}"
|
||||
# experimental-ignore-error: "true"
|
||||
# command-timeout: "600000"
|
||||
# pip-extra-args: "--user"
|
||||
- name: Run Unit Tests
|
||||
uses: ahopkins/custom-actions@pip-extra-args
|
||||
with:
|
||||
python-version: ${{ matrix.config.python-version }}
|
||||
test-infra-tool: tox
|
||||
test-infra-version: latest
|
||||
action: tests
|
||||
test-additional-args: "-e=${{ matrix.config.tox-env }}"
|
||||
experimental-ignore-error: "true"
|
||||
command-timeout: "600000"
|
||||
pip-extra-args: "--user"
|
||||
|
||||
@@ -1,3 +1,22 @@
|
||||
.. note::
|
||||
|
||||
From v21.9, CHANGELOG files are maintained in ``./docs/sanic/releases``
|
||||
|
||||
Version 21.6.1
|
||||
--------------
|
||||
|
||||
Bugfixes
|
||||
********
|
||||
|
||||
* `#2178 <https://github.com/sanic-org/sanic/pull/2178>`_
|
||||
Update sanic-routing to allow for better splitting of complex URI templates
|
||||
* `#2183 <https://github.com/sanic-org/sanic/pull/2183>`_
|
||||
Proper handling of chunked request bodies to resolve phantom 503 in logs
|
||||
* `#2181 <https://github.com/sanic-org/sanic/pull/2181>`_
|
||||
Resolve regression in exception logging
|
||||
* `#2201 <https://github.com/sanic-org/sanic/pull/2201>`_
|
||||
Cleanup request info in pipelined requests
|
||||
|
||||
Version 21.6.0
|
||||
--------------
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ a virtual environment already set up, then run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip3 install -e . ".[dev]"
|
||||
pip install -e ".[dev]"
|
||||
|
||||
Dependency Changes
|
||||
------------------
|
||||
|
||||
12
README.rst
12
README.rst
@@ -77,17 +77,7 @@ The goal of the project is to provide a simple way to get up and running a highl
|
||||
Sponsor
|
||||
-------
|
||||
|
||||
|Try CodeStream|
|
||||
|
||||
.. |Try CodeStream| image:: https://alt-images.codestream.com/codestream_logo_sanicorg.png
|
||||
:target: https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner
|
||||
:alt: Try CodeStream
|
||||
|
||||
Manage pull requests and conduct code reviews in your IDE with full source-tree context. Comment on any line, not just the diffs. Use jump-to-definition, your favorite keybindings, and code intelligence with more of your workflow.
|
||||
|
||||
`Learn More <https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner>`_
|
||||
|
||||
Thank you to our sponsor. Check out `open collective <https://opencollective.com/sanic-org>`_ to learn more about helping to fund Sanic.
|
||||
Check out `open collective <https://opencollective.com/sanic-org>`_ to learn more about helping to fund Sanic.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
20
docs/conf.py
20
docs/conf.py
@@ -10,10 +10,8 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add support for auto-doc
|
||||
import recommonmark
|
||||
|
||||
from recommonmark.transform import AutoStructify
|
||||
# Add support for auto-doc
|
||||
|
||||
|
||||
# Ensure that sanic is present in the path, to allow sphinx-apidoc to
|
||||
@@ -26,7 +24,7 @@ import sanic
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
extensions = ["sphinx.ext.autodoc", "recommonmark"]
|
||||
extensions = ["sphinx.ext.autodoc", "m2r2"]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
|
||||
@@ -162,20 +160,6 @@ autodoc_default_options = {
|
||||
"member-order": "groupwise",
|
||||
}
|
||||
|
||||
|
||||
# app setup hook
|
||||
def setup(app):
|
||||
app.add_config_value(
|
||||
"recommonmark_config",
|
||||
{
|
||||
"enable_eval_rst": True,
|
||||
"enable_auto_doc_ref": False,
|
||||
},
|
||||
True,
|
||||
)
|
||||
app.add_transform(AutoStructify)
|
||||
|
||||
|
||||
html_theme_options = {
|
||||
"style_external_links": False,
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
📜 Changelog
|
||||
============
|
||||
|
||||
.. mdinclude:: ./releases/21.9.md
|
||||
|
||||
.. include:: ../../CHANGELOG.rst
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
♥️ Contributing
|
||||
==============
|
||||
===============
|
||||
|
||||
.. include:: ../../CONTRIBUTING.rst
|
||||
|
||||
40
docs/sanic/releases/21.9.md
Normal file
40
docs/sanic/releases/21.9.md
Normal file
@@ -0,0 +1,40 @@
|
||||
## Version 21.9
|
||||
|
||||
### Features
|
||||
- [#2158](https://github.com/sanic-org/sanic/pull/2158), [#2248](https://github.com/sanic-org/sanic/pull/2248) Complete overhaul of I/O to websockets
|
||||
- [#2160](https://github.com/sanic-org/sanic/pull/2160) Add new 17 signals into server and request lifecycles
|
||||
- [#2162](https://github.com/sanic-org/sanic/pull/2162) Smarter `auto` fallback formatting upon exception
|
||||
- [#2184](https://github.com/sanic-org/sanic/pull/2184) Introduce implementation for copying a Blueprint
|
||||
- [#2200](https://github.com/sanic-org/sanic/pull/2200) Accept header parsing
|
||||
- [#2207](https://github.com/sanic-org/sanic/pull/2207) Log remote address if available
|
||||
- [#2209](https://github.com/sanic-org/sanic/pull/2209) Add convenience methods to BP groups
|
||||
- [#2216](https://github.com/sanic-org/sanic/pull/2216) Add default messages to SanicExceptions
|
||||
- [#2225](https://github.com/sanic-org/sanic/pull/2225) Type annotation convenience for annotated handlers with path parameters
|
||||
- [#2236](https://github.com/sanic-org/sanic/pull/2236) Allow Falsey (but not-None) responses from route handlers
|
||||
- [#2238](https://github.com/sanic-org/sanic/pull/2238) Add `exception` decorator to Blueprint Groups
|
||||
- [#2244](https://github.com/sanic-org/sanic/pull/2244) Explicit static directive for serving file or dir (ex: `static(..., resource_type="file")`)
|
||||
- [#2245](https://github.com/sanic-org/sanic/pull/2245) Close HTTP loop when connection task cancelled
|
||||
|
||||
### Bugfixes
|
||||
- [#2188](https://github.com/sanic-org/sanic/pull/2188) Fix the handling of the end of a chunked request
|
||||
- [#2195](https://github.com/sanic-org/sanic/pull/2195) Resolve unexpected error handling on static requests
|
||||
- [#2208](https://github.com/sanic-org/sanic/pull/2208) Make blueprint-based exceptions attach and trigger in a more intuitive manner
|
||||
- [#2211](https://github.com/sanic-org/sanic/pull/2211) Fixed for handling exceptions of asgi app call
|
||||
- [#2213](https://github.com/sanic-org/sanic/pull/2213) Fix bug where ws exceptions not being logged
|
||||
- [#2231](https://github.com/sanic-org/sanic/pull/2231) Cleaner closing of tasks by using `abort()` in strategic places to avoid dangling sockets
|
||||
- [#2247](https://github.com/sanic-org/sanic/pull/2247) Fix logging of auto-reload status in debug mode
|
||||
- [#2246](https://github.com/sanic-org/sanic/pull/2246) Account for BP with exception handler but no routes
|
||||
|
||||
### Developer infrastructure
|
||||
- [#2194](https://github.com/sanic-org/sanic/pull/2194) HTTP unit tests with raw client
|
||||
- [#2199](https://github.com/sanic-org/sanic/pull/2199) Switch to codeclimate
|
||||
- [#2214](https://github.com/sanic-org/sanic/pull/2214) Try Reopening Windows Tests
|
||||
- [#2229](https://github.com/sanic-org/sanic/pull/2229) Refactor `HttpProtocol` into a base class
|
||||
- [#2230](https://github.com/sanic-org/sanic/pull/2230) Refactor `server.py` into multi-file module
|
||||
|
||||
### Miscellaneous
|
||||
- [#2173](https://github.com/sanic-org/sanic/pull/2173) Remove Duplicated Dependencies and PEP 517 Support
|
||||
- [#2193](https://github.com/sanic-org/sanic/pull/2193), [#2196](https://github.com/sanic-org/sanic/pull/2196), [#2217](https://github.com/sanic-org/sanic/pull/2217) Type annotation changes
|
||||
|
||||
|
||||
|
||||
@@ -1,29 +1,44 @@
|
||||
from sanic import Sanic
|
||||
from sanic import response
|
||||
from signal import signal, SIGINT
|
||||
import asyncio
|
||||
|
||||
from signal import SIGINT, signal
|
||||
|
||||
import uvloop
|
||||
|
||||
from sanic import Sanic, response
|
||||
from sanic.server import AsyncioServer
|
||||
|
||||
|
||||
app = Sanic(__name__)
|
||||
|
||||
@app.listener('after_server_start')
|
||||
|
||||
@app.listener("after_server_start")
|
||||
async def after_start_test(app, loop):
|
||||
print("Async Server Started!")
|
||||
|
||||
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
return response.json({"answer": "42"})
|
||||
|
||||
|
||||
asyncio.set_event_loop(uvloop.new_event_loop())
|
||||
serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True)
|
||||
serv_coro = app.create_server(
|
||||
host="0.0.0.0", port=8000, return_asyncio_server=True
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
|
||||
signal(SIGINT, lambda s, f: loop.stop())
|
||||
server = loop.run_until_complete(serv_task)
|
||||
server: AsyncioServer = loop.run_until_complete(serv_task) # type: ignore
|
||||
server.startup()
|
||||
|
||||
# When using app.run(), this actually triggers before the serv_coro.
|
||||
# But, in this example, we are using the convenience method, even if it is
|
||||
# out of order.
|
||||
server.before_start()
|
||||
server.after_start()
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt as e:
|
||||
except KeyboardInterrupt:
|
||||
loop.stop()
|
||||
finally:
|
||||
server.before_stop()
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from sanic import Sanic
|
||||
from sanic.response import file
|
||||
from sanic.response import redirect
|
||||
|
||||
app = Sanic(__name__)
|
||||
|
||||
|
||||
@app.route('/')
|
||||
async def index(request):
|
||||
return await file('websocket.html')
|
||||
app.static('index.html', "websocket.html")
|
||||
|
||||
@app.route('/')
|
||||
def index(request):
|
||||
return redirect("index.html")
|
||||
|
||||
@app.websocket('/feed')
|
||||
async def feed(request, ws):
|
||||
|
||||
3
pyproject.toml
Normal file
3
pyproject.toml
Normal file
@@ -0,0 +1,3 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
@@ -1 +1 @@
|
||||
__version__ = "21.6.0"
|
||||
__version__ = "21.9.1"
|
||||
|
||||
314
sanic/app.py
314
sanic/app.py
@@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import re
|
||||
|
||||
from asyncio import (
|
||||
AbstractEventLoop,
|
||||
CancelledError,
|
||||
Protocol,
|
||||
ensure_future,
|
||||
@@ -21,6 +24,7 @@ from traceback import format_exc
|
||||
from types import SimpleNamespace
|
||||
from typing import (
|
||||
Any,
|
||||
AnyStr,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
@@ -30,6 +34,7 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
@@ -69,20 +74,29 @@ from sanic.router import Router
|
||||
from sanic.server import AsyncioServer, HttpProtocol
|
||||
from sanic.server import Signal as ServerSignal
|
||||
from sanic.server import serve, serve_multiple, serve_single
|
||||
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
|
||||
from sanic.server.websockets.impl import ConnectionClosed
|
||||
from sanic.signals import Signal, SignalRouter
|
||||
from sanic.websocket import ConnectionClosed, WebSocketProtocol
|
||||
from sanic.touchup import TouchUp, TouchUpMeta
|
||||
|
||||
|
||||
class Sanic(BaseSanic):
|
||||
class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
"""
|
||||
The main application instance
|
||||
"""
|
||||
|
||||
__touchup__ = (
|
||||
"handle_request",
|
||||
"handle_exception",
|
||||
"_run_response_middleware",
|
||||
"_run_request_middleware",
|
||||
)
|
||||
__fake_slots__ = (
|
||||
"_asgi_app",
|
||||
"_app_registry",
|
||||
"_asgi_client",
|
||||
"_blueprint_order",
|
||||
"_delayed_tasks",
|
||||
"_future_routes",
|
||||
"_future_statics",
|
||||
"_future_middleware",
|
||||
@@ -137,7 +151,7 @@ class Sanic(BaseSanic):
|
||||
log_config: Optional[Dict[str, Any]] = None,
|
||||
configure_logging: bool = True,
|
||||
register: Optional[bool] = None,
|
||||
dumps: Optional[Callable[..., str]] = None,
|
||||
dumps: Optional[Callable[..., AnyStr]] = None,
|
||||
) -> None:
|
||||
super().__init__(name=name)
|
||||
|
||||
@@ -153,6 +167,7 @@ class Sanic(BaseSanic):
|
||||
|
||||
self._asgi_client = None
|
||||
self._blueprint_order: List[Blueprint] = []
|
||||
self._delayed_tasks: List[str] = []
|
||||
self._test_client = None
|
||||
self._test_manager = None
|
||||
self.asgi = False
|
||||
@@ -164,7 +179,9 @@ class Sanic(BaseSanic):
|
||||
self.configure_logging = configure_logging
|
||||
self.ctx = ctx or SimpleNamespace()
|
||||
self.debug = None
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
self.error_handler = error_handler or ErrorHandler(
|
||||
fallback=self.config.FALLBACK_ERROR_FORMAT,
|
||||
)
|
||||
self.is_running = False
|
||||
self.is_stopping = False
|
||||
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
|
||||
@@ -190,9 +207,10 @@ class Sanic(BaseSanic):
|
||||
self.__class__.register_app(self)
|
||||
|
||||
self.router.ctx.app = self
|
||||
self.signal_router.ctx.app = self
|
||||
|
||||
if dumps:
|
||||
BaseHTTPResponse._dumps = dumps
|
||||
BaseHTTPResponse._dumps = dumps # type: ignore
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
@@ -230,9 +248,12 @@ class Sanic(BaseSanic):
|
||||
loop = self.loop # Will raise SanicError if loop is not started
|
||||
self._loop_add_task(task, self, loop)
|
||||
except SanicException:
|
||||
self.listener("before_server_start")(
|
||||
partial(self._loop_add_task, task)
|
||||
)
|
||||
task_name = f"sanic.delayed_task.{hash(task)}"
|
||||
if not self._delayed_tasks:
|
||||
self.after_server_start(partial(self.dispatch_delayed_tasks))
|
||||
|
||||
self.signal(task_name)(partial(self.run_delayed_task, task=task))
|
||||
self._delayed_tasks.append(task_name)
|
||||
|
||||
def register_listener(self, listener: Callable, event: str) -> Any:
|
||||
"""
|
||||
@@ -244,12 +265,20 @@ class Sanic(BaseSanic):
|
||||
"""
|
||||
|
||||
try:
|
||||
_event = ListenerEvent(event)
|
||||
except ValueError:
|
||||
valid = ", ".join(ListenerEvent.__members__.values())
|
||||
_event = ListenerEvent[event.upper()]
|
||||
except (ValueError, AttributeError):
|
||||
valid = ", ".join(
|
||||
map(lambda x: x.lower(), ListenerEvent.__members__.keys())
|
||||
)
|
||||
raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}")
|
||||
|
||||
self.listeners[_event].append(listener)
|
||||
if "." in _event:
|
||||
self.signal(_event.value)(
|
||||
partial(self._listener, listener=listener)
|
||||
)
|
||||
else:
|
||||
self.listeners[_event.value].append(listener)
|
||||
|
||||
return listener
|
||||
|
||||
def register_middleware(self, middleware, attach_to: str = "request"):
|
||||
@@ -308,7 +337,11 @@ class Sanic(BaseSanic):
|
||||
self.named_response_middleware[_rn].appendleft(middleware)
|
||||
return middleware
|
||||
|
||||
def _apply_exception_handler(self, handler: FutureException):
|
||||
def _apply_exception_handler(
|
||||
self,
|
||||
handler: FutureException,
|
||||
route_names: Optional[List[str]] = None,
|
||||
):
|
||||
"""Decorate a function to be registered as a handler for exceptions
|
||||
|
||||
:param exceptions: exceptions
|
||||
@@ -318,9 +351,9 @@ class Sanic(BaseSanic):
|
||||
for exception in handler.exceptions:
|
||||
if isinstance(exception, (tuple, list)):
|
||||
for e in exception:
|
||||
self.error_handler.add(e, handler.handler)
|
||||
self.error_handler.add(e, handler.handler, route_names)
|
||||
else:
|
||||
self.error_handler.add(exception, handler.handler)
|
||||
self.error_handler.add(exception, handler.handler, route_names)
|
||||
return handler.handler
|
||||
|
||||
def _apply_listener(self, listener: FutureListener):
|
||||
@@ -377,11 +410,17 @@ class Sanic(BaseSanic):
|
||||
*,
|
||||
condition: Optional[Dict[str, str]] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
fail_not_found: bool = True,
|
||||
inline: bool = False,
|
||||
reverse: bool = False,
|
||||
) -> Coroutine[Any, Any, Awaitable[Any]]:
|
||||
return self.signal_router.dispatch(
|
||||
event,
|
||||
context=context,
|
||||
condition=condition,
|
||||
inline=inline,
|
||||
reverse=reverse,
|
||||
fail_not_found=fail_not_found,
|
||||
)
|
||||
|
||||
async def event(
|
||||
@@ -411,7 +450,13 @@ class Sanic(BaseSanic):
|
||||
|
||||
self.websocket_enabled = enable
|
||||
|
||||
def blueprint(self, blueprint, **options):
|
||||
def blueprint(
|
||||
self,
|
||||
blueprint: Union[
|
||||
Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup
|
||||
],
|
||||
**options: Any,
|
||||
):
|
||||
"""Register a blueprint on the application.
|
||||
|
||||
:param blueprint: Blueprint object or (list, tuple) thereof
|
||||
@@ -651,7 +696,7 @@ class Sanic(BaseSanic):
|
||||
|
||||
async def handle_exception(
|
||||
self, request: Request, exception: BaseException
|
||||
):
|
||||
): # no cov
|
||||
"""
|
||||
A handler that catches specific exceptions and outputs a response.
|
||||
|
||||
@@ -661,6 +706,12 @@ class Sanic(BaseSanic):
|
||||
:type exception: BaseException
|
||||
:raises ServerError: response 500
|
||||
"""
|
||||
await self.dispatch(
|
||||
"http.lifecycle.exception",
|
||||
inline=True,
|
||||
context={"request": request, "exception": exception},
|
||||
)
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
@@ -707,7 +758,7 @@ class Sanic(BaseSanic):
|
||||
f"Invalid response type {response!r} (need HTTPResponse)"
|
||||
)
|
||||
|
||||
async def handle_request(self, request: Request):
|
||||
async def handle_request(self, request: Request): # no cov
|
||||
"""Take a request from the HTTP Server and return a response object
|
||||
to be sent back The HTTP Server only expects a response object, so
|
||||
exception handling must be done here
|
||||
@@ -715,10 +766,22 @@ class Sanic(BaseSanic):
|
||||
:param request: HTTP Request object
|
||||
:return: Nothing
|
||||
"""
|
||||
await self.dispatch(
|
||||
"http.lifecycle.handle",
|
||||
inline=True,
|
||||
context={"request": request},
|
||||
)
|
||||
|
||||
# Define `response` var here to remove warnings about
|
||||
# allocation before assignment below.
|
||||
response = None
|
||||
try:
|
||||
|
||||
await self.dispatch(
|
||||
"http.routing.before",
|
||||
inline=True,
|
||||
context={"request": request},
|
||||
)
|
||||
# Fetch handler from router
|
||||
route, handler, kwargs = self.router.get(
|
||||
request.path,
|
||||
@@ -726,19 +789,29 @@ class Sanic(BaseSanic):
|
||||
request.headers.getone("host", None),
|
||||
)
|
||||
|
||||
request._match_info = kwargs
|
||||
request._match_info = {**kwargs}
|
||||
request.route = route
|
||||
|
||||
await self.dispatch(
|
||||
"http.routing.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"route": route,
|
||||
"kwargs": kwargs,
|
||||
"handler": handler,
|
||||
},
|
||||
)
|
||||
|
||||
if (
|
||||
request.stream.request_body # type: ignore
|
||||
request.stream
|
||||
and request.stream.request_body
|
||||
and not route.ctx.ignore_body
|
||||
):
|
||||
|
||||
if hasattr(handler, "is_stream"):
|
||||
# Streaming handler: lift the size limit
|
||||
request.stream.request_max_size = float( # type: ignore
|
||||
"inf"
|
||||
)
|
||||
request.stream.request_max_size = float("inf")
|
||||
else:
|
||||
# Non-streaming handler: preload body
|
||||
await request.receive_body()
|
||||
@@ -765,17 +838,25 @@ class Sanic(BaseSanic):
|
||||
)
|
||||
|
||||
# Run response handler
|
||||
response = handler(request, **kwargs)
|
||||
response = handler(request, **request.match_info)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
|
||||
if response:
|
||||
if response is not None:
|
||||
response = await request.respond(response)
|
||||
elif not hasattr(handler, "is_websocket"):
|
||||
response = request.stream.response # type: ignore
|
||||
|
||||
# Make sure that response is finished / run StreamingHTTP callback
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
await self.dispatch(
|
||||
"http.lifecycle.response",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
await response.send(end_stream=True)
|
||||
else:
|
||||
if not hasattr(handler, "is_websocket"):
|
||||
@@ -793,23 +874,11 @@ class Sanic(BaseSanic):
|
||||
async def _websocket_handler(
|
||||
self, handler, request, *args, subprotocols=None, **kwargs
|
||||
):
|
||||
request.app = self
|
||||
if not getattr(handler, "__blueprintname__", False):
|
||||
request._name = handler.__name__
|
||||
else:
|
||||
request._name = (
|
||||
getattr(handler, "__blueprintname__", "") + handler.__name__
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
if self.asgi:
|
||||
ws = request.transport.get_websocket_connection()
|
||||
await ws.accept(subprotocols)
|
||||
else:
|
||||
protocol = request.transport.get_protocol()
|
||||
protocol.app = self
|
||||
|
||||
ws = await protocol.websocket_handshake(request, subprotocols)
|
||||
|
||||
# schedule the application handler
|
||||
@@ -817,12 +886,18 @@ class Sanic(BaseSanic):
|
||||
# needs to be cancelled due to the server being stopped
|
||||
fut = ensure_future(handler(request, ws, *args, **kwargs))
|
||||
self.websocket_tasks.add(fut)
|
||||
cancelled = False
|
||||
try:
|
||||
await fut
|
||||
except Exception as e:
|
||||
self.error_handler.log(request, e)
|
||||
except (CancelledError, ConnectionClosed):
|
||||
pass
|
||||
cancelled = True
|
||||
finally:
|
||||
self.websocket_tasks.remove(fut)
|
||||
if cancelled:
|
||||
ws.end_connection(1000)
|
||||
else:
|
||||
await ws.close()
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
@@ -869,7 +944,7 @@ class Sanic(BaseSanic):
|
||||
*,
|
||||
debug: bool = False,
|
||||
auto_reload: Optional[bool] = None,
|
||||
ssl: Union[dict, SSLContext, None] = None,
|
||||
ssl: Union[Dict[str, str], SSLContext, None] = None,
|
||||
sock: Optional[socket] = None,
|
||||
workers: int = 1,
|
||||
protocol: Optional[Type[Protocol]] = None,
|
||||
@@ -999,7 +1074,7 @@ class Sanic(BaseSanic):
|
||||
port: Optional[int] = None,
|
||||
*,
|
||||
debug: bool = False,
|
||||
ssl: Union[dict, SSLContext, None] = None,
|
||||
ssl: Union[Dict[str, str], SSLContext, None] = None,
|
||||
sock: Optional[socket] = None,
|
||||
protocol: Type[Protocol] = None,
|
||||
backlog: int = 100,
|
||||
@@ -1071,11 +1146,6 @@ class Sanic(BaseSanic):
|
||||
run_async=return_asyncio_server,
|
||||
)
|
||||
|
||||
# Trigger before_start events
|
||||
await self.trigger_events(
|
||||
server_settings.get("before_start", []),
|
||||
server_settings.get("loop"),
|
||||
)
|
||||
main_start = server_settings.pop("main_start", None)
|
||||
main_stop = server_settings.pop("main_stop", None)
|
||||
if main_start or main_stop:
|
||||
@@ -1088,17 +1158,9 @@ class Sanic(BaseSanic):
|
||||
asyncio_server_kwargs=asyncio_server_kwargs, **server_settings
|
||||
)
|
||||
|
||||
async def trigger_events(self, events, loop):
|
||||
"""Trigger events (functions or async)
|
||||
:param events: one or more sync or async functions to execute
|
||||
:param loop: event loop
|
||||
"""
|
||||
for event in events:
|
||||
result = event(loop)
|
||||
if isawaitable(result):
|
||||
await result
|
||||
|
||||
async def _run_request_middleware(self, request, request_name=None):
|
||||
async def _run_request_middleware(
|
||||
self, request, request_name=None
|
||||
): # no cov
|
||||
# The if improves speed. I don't know why
|
||||
named_middleware = self.named_request_middleware.get(
|
||||
request_name, deque()
|
||||
@@ -1111,25 +1173,67 @@ class Sanic(BaseSanic):
|
||||
request.request_middleware_started = True
|
||||
|
||||
for middleware in applicable_middleware:
|
||||
await self.dispatch(
|
||||
"http.middleware.before",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": None,
|
||||
},
|
||||
condition={"attach_to": "request"},
|
||||
)
|
||||
|
||||
response = middleware(request)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
|
||||
await self.dispatch(
|
||||
"http.middleware.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": None,
|
||||
},
|
||||
condition={"attach_to": "request"},
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
return None
|
||||
|
||||
async def _run_response_middleware(
|
||||
self, request, response, request_name=None
|
||||
):
|
||||
): # no cov
|
||||
named_middleware = self.named_response_middleware.get(
|
||||
request_name, deque()
|
||||
)
|
||||
applicable_middleware = self.response_middleware + named_middleware
|
||||
if applicable_middleware:
|
||||
for middleware in applicable_middleware:
|
||||
await self.dispatch(
|
||||
"http.middleware.before",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": response,
|
||||
},
|
||||
condition={"attach_to": "response"},
|
||||
)
|
||||
|
||||
_response = middleware(request, response)
|
||||
if isawaitable(_response):
|
||||
_response = await _response
|
||||
|
||||
await self.dispatch(
|
||||
"http.middleware.after",
|
||||
inline=True,
|
||||
context={
|
||||
"request": request,
|
||||
"response": _response if _response else response,
|
||||
},
|
||||
condition={"attach_to": "response"},
|
||||
)
|
||||
|
||||
if _response:
|
||||
response = _response
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
@@ -1155,10 +1259,6 @@ class Sanic(BaseSanic):
|
||||
):
|
||||
"""Helper function used by `run` and `create_server`."""
|
||||
|
||||
self.listeners["before_server_start"] = [
|
||||
self.finalize
|
||||
] + self.listeners["before_server_start"]
|
||||
|
||||
if isinstance(ssl, dict):
|
||||
# try common aliaseses
|
||||
cert = ssl.get("cert") or ssl.get("certificate")
|
||||
@@ -1195,10 +1295,6 @@ class Sanic(BaseSanic):
|
||||
# Register start/stop events
|
||||
|
||||
for event_name, settings_name, reverse in (
|
||||
("before_server_start", "before_start", False),
|
||||
("after_server_start", "after_start", False),
|
||||
("before_server_stop", "before_stop", True),
|
||||
("after_server_stop", "after_stop", True),
|
||||
("main_process_start", "main_start", False),
|
||||
("main_process_stop", "main_stop", True),
|
||||
):
|
||||
@@ -1236,7 +1332,8 @@ class Sanic(BaseSanic):
|
||||
logger.info(f"Goin' Fast @ {proto}://{host}:{port}")
|
||||
|
||||
debug_mode = "enabled" if self.debug else "disabled"
|
||||
logger.debug("Sanic auto-reload: enabled")
|
||||
reload_mode = "enabled" if auto_reload else "disabled"
|
||||
logger.debug(f"Sanic auto-reload: {reload_mode}")
|
||||
logger.debug(f"Sanic debug mode: {debug_mode}")
|
||||
|
||||
return server_settings
|
||||
@@ -1246,20 +1343,44 @@ class Sanic(BaseSanic):
|
||||
return ".".join(parts)
|
||||
|
||||
@classmethod
|
||||
def _loop_add_task(cls, task, app, loop):
|
||||
def _prep_task(cls, task, app, loop):
|
||||
if callable(task):
|
||||
try:
|
||||
loop.create_task(task(app))
|
||||
task = task(app)
|
||||
except TypeError:
|
||||
loop.create_task(task())
|
||||
else:
|
||||
loop.create_task(task)
|
||||
task = task()
|
||||
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def _loop_add_task(cls, task, app, loop):
|
||||
prepped = cls._prep_task(task, app, loop)
|
||||
loop.create_task(prepped)
|
||||
|
||||
@classmethod
|
||||
def _cancel_websocket_tasks(cls, app, loop):
|
||||
for task in app.websocket_tasks:
|
||||
task.cancel()
|
||||
|
||||
@staticmethod
|
||||
async def dispatch_delayed_tasks(app, loop):
|
||||
for name in app._delayed_tasks:
|
||||
await app.dispatch(name, context={"app": app, "loop": loop})
|
||||
app._delayed_tasks.clear()
|
||||
|
||||
@staticmethod
|
||||
async def run_delayed_task(app, loop, task):
|
||||
prepped = app._prep_task(task, app, loop)
|
||||
await prepped
|
||||
|
||||
@staticmethod
|
||||
async def _listener(
|
||||
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
|
||||
):
|
||||
maybe_coro = listener(app, loop)
|
||||
if maybe_coro and isawaitable(maybe_coro):
|
||||
await maybe_coro
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# ASGI
|
||||
# -------------------------------------------------------------------- #
|
||||
@@ -1333,15 +1454,52 @@ class Sanic(BaseSanic):
|
||||
raise SanicException(f'Sanic app name "{name}" not found.')
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Static methods
|
||||
# Lifecycle
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
@staticmethod
|
||||
async def finalize(app, _):
|
||||
def finalize(self):
|
||||
try:
|
||||
app.router.finalize()
|
||||
if app.signal_router.routes:
|
||||
app.signal_router.finalize() # noqa
|
||||
self.router.finalize()
|
||||
except FinalizationError as e:
|
||||
if not Sanic.test_mode:
|
||||
raise e # noqa
|
||||
raise e
|
||||
|
||||
def signalize(self):
|
||||
try:
|
||||
self.signal_router.finalize()
|
||||
except FinalizationError as e:
|
||||
if not Sanic.test_mode:
|
||||
raise e
|
||||
|
||||
async def _startup(self):
|
||||
self.signalize()
|
||||
self.finalize()
|
||||
ErrorHandler.finalize(self.error_handler)
|
||||
TouchUp.run(self)
|
||||
|
||||
async def _server_event(
|
||||
self,
|
||||
concern: str,
|
||||
action: str,
|
||||
loop: Optional[AbstractEventLoop] = None,
|
||||
) -> None:
|
||||
event = f"server.{concern}.{action}"
|
||||
if action not in ("before", "after") or concern not in (
|
||||
"init",
|
||||
"shutdown",
|
||||
):
|
||||
raise SanicException(f"Invalid server event: {event}")
|
||||
logger.debug(f"Triggering server events: {event}")
|
||||
reverse = concern == "shutdown"
|
||||
if loop is None:
|
||||
loop = self.loop
|
||||
await self.dispatch(
|
||||
event,
|
||||
fail_not_found=False,
|
||||
reverse=reverse,
|
||||
inline=True,
|
||||
context={
|
||||
"app": self,
|
||||
"loop": loop,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import warnings
|
||||
|
||||
from inspect import isawaitable
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -11,21 +10,27 @@ from sanic.exceptions import ServerError
|
||||
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
|
||||
from sanic.request import Request
|
||||
from sanic.server import ConnInfo
|
||||
from sanic.websocket import WebSocketConnection
|
||||
from sanic.server.websockets.connection import WebSocketConnection
|
||||
|
||||
|
||||
class Lifespan:
|
||||
def __init__(self, asgi_app: "ASGIApp") -> None:
|
||||
self.asgi_app = asgi_app
|
||||
|
||||
if "before_server_start" in self.asgi_app.sanic_app.listeners:
|
||||
if (
|
||||
"server.init.before"
|
||||
in self.asgi_app.sanic_app.signal_router.name_index
|
||||
):
|
||||
warnings.warn(
|
||||
'You have set a listener for "before_server_start" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as early as possible, but not before "
|
||||
"the ASGI server is started."
|
||||
)
|
||||
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
|
||||
if (
|
||||
"server.shutdown.after"
|
||||
in self.asgi_app.sanic_app.signal_router.name_index
|
||||
):
|
||||
warnings.warn(
|
||||
'You have set a listener for "after_server_stop" '
|
||||
"in ASGI mode. "
|
||||
@@ -42,19 +47,9 @@ class Lifespan:
|
||||
in sequence since the ASGI lifespan protocol only supports a single
|
||||
startup event.
|
||||
"""
|
||||
self.asgi_app.sanic_app.router.finalize()
|
||||
if self.asgi_app.sanic_app.signal_router.routes:
|
||||
self.asgi_app.sanic_app.signal_router.finalize()
|
||||
listeners = self.asgi_app.sanic_app.listeners.get(
|
||||
"before_server_start", []
|
||||
) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
|
||||
|
||||
for handler in listeners:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if response and isawaitable(response):
|
||||
await response
|
||||
await self.asgi_app.sanic_app._startup()
|
||||
await self.asgi_app.sanic_app._server_event("init", "before")
|
||||
await self.asgi_app.sanic_app._server_event("init", "after")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
@@ -65,16 +60,8 @@ class Lifespan:
|
||||
in sequence since the ASGI lifespan protocol only supports a single
|
||||
shutdown event.
|
||||
"""
|
||||
listeners = self.asgi_app.sanic_app.listeners.get(
|
||||
"before_server_stop", []
|
||||
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
|
||||
|
||||
for handler in listeners:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if response and isawaitable(response):
|
||||
await response
|
||||
await self.asgi_app.sanic_app._server_event("shutdown", "before")
|
||||
await self.asgi_app.sanic_app._server_event("shutdown", "after")
|
||||
|
||||
async def __call__(
|
||||
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
||||
@@ -207,4 +194,7 @@ class ASGIApp:
|
||||
"""
|
||||
Handle the incoming request.
|
||||
"""
|
||||
try:
|
||||
await self.sanic_app.handle_request(self.request)
|
||||
except Exception as e:
|
||||
await self.sanic_app.handle_exception(self.request, e)
|
||||
|
||||
@@ -58,7 +58,7 @@ class BaseSanic(
|
||||
if name not in self.__fake_slots__:
|
||||
warn(
|
||||
f"Setting variables on {self.__class__.__name__} instances is "
|
||||
"deprecated and will be removed in version 21.9. You should "
|
||||
"deprecated and will be removed in version 21.12. You should "
|
||||
f"change your {self.__class__.__name__} instance to use "
|
||||
f"instance.ctx.{name} instead.",
|
||||
DeprecationWarning,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import MutableSequence
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
|
||||
@@ -196,6 +197,27 @@ class BlueprintGroup(MutableSequence):
|
||||
"""
|
||||
self._blueprints.append(value)
|
||||
|
||||
def exception(self, *exceptions, **kwargs):
|
||||
"""
|
||||
A decorator that can be used to implement a global exception handler
|
||||
for all the Blueprints that belong to this Blueprint Group.
|
||||
|
||||
In case of nested Blueprint Groups, the same handler is applied
|
||||
across each of the Blueprints recursively.
|
||||
|
||||
:param args: List of Python exceptions to be caught by the handler
|
||||
:param kwargs: Additional optional arguments to be passed to the
|
||||
exception handler
|
||||
:return a decorated method to handle global exceptions for any
|
||||
blueprint registered under this group.
|
||||
"""
|
||||
|
||||
def register_exception_handler_for_blueprints(fn):
|
||||
for blueprint in self.blueprints:
|
||||
blueprint.exception(*exceptions, **kwargs)(fn)
|
||||
|
||||
return register_exception_handler_for_blueprints
|
||||
|
||||
def insert(self, index: int, item: Blueprint) -> None:
|
||||
"""
|
||||
The Abstract class `MutableSequence` leverages this insert method to
|
||||
@@ -229,3 +251,15 @@ class BlueprintGroup(MutableSequence):
|
||||
args = list(args)[1:]
|
||||
return register_middleware_for_blueprints(fn)
|
||||
return register_middleware_for_blueprints
|
||||
|
||||
def on_request(self, middleware=None):
|
||||
if callable(middleware):
|
||||
return self.middleware(middleware, "request")
|
||||
else:
|
||||
return partial(self.middleware, attach_to="request")
|
||||
|
||||
def on_response(self, middleware=None):
|
||||
if callable(middleware):
|
||||
return self.middleware(middleware, "response")
|
||||
else:
|
||||
return partial(self.middleware, attach_to="response")
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union
|
||||
|
||||
@@ -12,6 +13,7 @@ from sanic_routing.route import Route # type: ignore
|
||||
from sanic.base import BaseSanic
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.exceptions import SanicException
|
||||
from sanic.helpers import Default, _default
|
||||
from sanic.models.futures import FutureRoute, FutureStatic
|
||||
from sanic.models.handler_types import (
|
||||
ListenerType,
|
||||
@@ -40,7 +42,7 @@ class Blueprint(BaseSanic):
|
||||
:param host: IP Address of FQDN for the sanic server to use.
|
||||
:param version: Blueprint Version
|
||||
:param strict_slashes: Enforce the API urls are requested with a
|
||||
training */*
|
||||
trailing */*
|
||||
"""
|
||||
|
||||
__fake_slots__ = (
|
||||
@@ -76,15 +78,9 @@ class Blueprint(BaseSanic):
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
super().__init__(name=name)
|
||||
|
||||
self._apps: Set[Sanic] = set()
|
||||
self.reset()
|
||||
self.ctx = SimpleNamespace()
|
||||
self.exceptions: List[RouteHandler] = []
|
||||
self.host = host
|
||||
self.listeners: Dict[str, List[ListenerType]] = {}
|
||||
self.middlewares: List[MiddlewareType] = []
|
||||
self.routes: List[Route] = []
|
||||
self.statics: List[RouteHandler] = []
|
||||
self.strict_slashes = strict_slashes
|
||||
self.url_prefix = (
|
||||
url_prefix[:-1]
|
||||
@@ -93,7 +89,6 @@ class Blueprint(BaseSanic):
|
||||
)
|
||||
self.version = version
|
||||
self.version_prefix = version_prefix
|
||||
self.websocket_routes: List[Route] = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
args = ", ".join(
|
||||
@@ -144,12 +139,87 @@ class Blueprint(BaseSanic):
|
||||
kwargs["apply"] = False
|
||||
return super().signal(event, *args, **kwargs)
|
||||
|
||||
def reset(self):
|
||||
self._apps: Set[Sanic] = set()
|
||||
self.exceptions: List[RouteHandler] = []
|
||||
self.listeners: Dict[str, List[ListenerType]] = {}
|
||||
self.middlewares: List[MiddlewareType] = []
|
||||
self.routes: List[Route] = []
|
||||
self.statics: List[RouteHandler] = []
|
||||
self.websocket_routes: List[Route] = []
|
||||
|
||||
def copy(
|
||||
self,
|
||||
name: str,
|
||||
url_prefix: Optional[Union[str, Default]] = _default,
|
||||
version: Optional[Union[int, str, float, Default]] = _default,
|
||||
version_prefix: Union[str, Default] = _default,
|
||||
strict_slashes: Optional[Union[bool, Default]] = _default,
|
||||
with_registration: bool = True,
|
||||
with_ctx: bool = False,
|
||||
):
|
||||
"""
|
||||
Copy a blueprint instance with some optional parameters to
|
||||
override the values of attributes in the old instance.
|
||||
|
||||
:param name: unique name of the blueprint
|
||||
:param url_prefix: URL to be prefixed before all route URLs
|
||||
:param version: Blueprint Version
|
||||
:param version_prefix: the prefix of the version number shown in the
|
||||
URL.
|
||||
:param strict_slashes: Enforce the API urls are requested with a
|
||||
trailing */*
|
||||
:param with_registration: whether register new blueprint instance with
|
||||
sanic apps that were registered with the old instance or not.
|
||||
:param with_ctx: whether ``ctx`` will be copied or not.
|
||||
"""
|
||||
|
||||
attrs_backup = {
|
||||
"_apps": self._apps,
|
||||
"routes": self.routes,
|
||||
"websocket_routes": self.websocket_routes,
|
||||
"middlewares": self.middlewares,
|
||||
"exceptions": self.exceptions,
|
||||
"listeners": self.listeners,
|
||||
"statics": self.statics,
|
||||
}
|
||||
|
||||
self.reset()
|
||||
new_bp = deepcopy(self)
|
||||
new_bp.name = name
|
||||
|
||||
if not isinstance(url_prefix, Default):
|
||||
new_bp.url_prefix = url_prefix
|
||||
if not isinstance(version, Default):
|
||||
new_bp.version = version
|
||||
if not isinstance(strict_slashes, Default):
|
||||
new_bp.strict_slashes = strict_slashes
|
||||
if not isinstance(version_prefix, Default):
|
||||
new_bp.version_prefix = version_prefix
|
||||
|
||||
for key, value in attrs_backup.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if with_registration and self._apps:
|
||||
if new_bp._future_statics:
|
||||
raise SanicException(
|
||||
"Static routes registered with the old blueprint instance,"
|
||||
" cannot be registered again."
|
||||
)
|
||||
for app in self._apps:
|
||||
app.blueprint(new_bp)
|
||||
|
||||
if not with_ctx:
|
||||
new_bp.ctx = SimpleNamespace()
|
||||
|
||||
return new_bp
|
||||
|
||||
@staticmethod
|
||||
def group(
|
||||
*blueprints,
|
||||
url_prefix="",
|
||||
version=None,
|
||||
strict_slashes=None,
|
||||
*blueprints: Union[Blueprint, BlueprintGroup],
|
||||
url_prefix: Optional[str] = None,
|
||||
version: Optional[Union[int, str, float]] = None,
|
||||
strict_slashes: Optional[bool] = None,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
"""
|
||||
@@ -196,6 +266,9 @@ class Blueprint(BaseSanic):
|
||||
opt_version = options.get("version", None)
|
||||
opt_strict_slashes = options.get("strict_slashes", None)
|
||||
opt_version_prefix = options.get("version_prefix", self.version_prefix)
|
||||
error_format = options.get(
|
||||
"error_format", app.config.FALLBACK_ERROR_FORMAT
|
||||
)
|
||||
|
||||
routes = []
|
||||
middleware = []
|
||||
@@ -243,6 +316,7 @@ class Blueprint(BaseSanic):
|
||||
future.unquote,
|
||||
future.static,
|
||||
version_prefix,
|
||||
error_format,
|
||||
)
|
||||
|
||||
route = app._apply_route(apply_route)
|
||||
@@ -261,19 +335,22 @@ class Blueprint(BaseSanic):
|
||||
|
||||
route_names = [route.name for route in routes if route]
|
||||
|
||||
# Middleware
|
||||
if route_names:
|
||||
# Middleware
|
||||
for future in self._future_middleware:
|
||||
middleware.append(app._apply_middleware(future, route_names))
|
||||
|
||||
# Exceptions
|
||||
for future in self._future_exceptions:
|
||||
exception_handlers.append(app._apply_exception_handler(future))
|
||||
exception_handlers.append(
|
||||
app._apply_exception_handler(future, route_names)
|
||||
)
|
||||
|
||||
# Event listeners
|
||||
for listener in self._future_listeners:
|
||||
listeners[listener.event].append(app._apply_listener(listener))
|
||||
|
||||
# Signals
|
||||
for signal in self._future_signals:
|
||||
signal.condition.update({"blueprint": self.name})
|
||||
app._apply_signal(signal)
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from warnings import warn
|
||||
|
||||
from sanic.errorpages import check_error_format
|
||||
from sanic.http import Http
|
||||
|
||||
from .utils import load_module_from_file_location, str_to_bool
|
||||
@@ -20,7 +21,7 @@ BASE_LOGO = """
|
||||
DEFAULT_CONFIG = {
|
||||
"ACCESS_LOG": True,
|
||||
"EVENT_AUTOREGISTER": False,
|
||||
"FALLBACK_ERROR_FORMAT": "html",
|
||||
"FALLBACK_ERROR_FORMAT": "auto",
|
||||
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
|
||||
"FORWARDED_SECRET": None,
|
||||
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
|
||||
@@ -35,12 +36,9 @@ DEFAULT_CONFIG = {
|
||||
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes
|
||||
"REQUEST_TIMEOUT": 60, # 60 seconds
|
||||
"RESPONSE_TIMEOUT": 60, # 60 seconds
|
||||
"WEBSOCKET_MAX_QUEUE": 32,
|
||||
"WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte
|
||||
"WEBSOCKET_PING_INTERVAL": 20,
|
||||
"WEBSOCKET_PING_TIMEOUT": 20,
|
||||
"WEBSOCKET_READ_LIMIT": 2 ** 16,
|
||||
"WEBSOCKET_WRITE_LIMIT": 2 ** 16,
|
||||
}
|
||||
|
||||
|
||||
@@ -62,12 +60,10 @@ class Config(dict):
|
||||
REQUEST_MAX_SIZE: int
|
||||
REQUEST_TIMEOUT: int
|
||||
RESPONSE_TIMEOUT: int
|
||||
WEBSOCKET_MAX_QUEUE: int
|
||||
SERVER_NAME: str
|
||||
WEBSOCKET_MAX_SIZE: int
|
||||
WEBSOCKET_PING_INTERVAL: int
|
||||
WEBSOCKET_PING_TIMEOUT: int
|
||||
WEBSOCKET_READ_LIMIT: int
|
||||
WEBSOCKET_WRITE_LIMIT: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -100,6 +96,7 @@ class Config(dict):
|
||||
self.load_environment_vars(SANIC_PREFIX)
|
||||
|
||||
self._configure_header_size()
|
||||
self._check_error_format()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
try:
|
||||
@@ -115,6 +112,8 @@ class Config(dict):
|
||||
"REQUEST_MAX_SIZE",
|
||||
):
|
||||
self._configure_header_size()
|
||||
elif attr == "FALLBACK_ERROR_FORMAT":
|
||||
self._check_error_format()
|
||||
|
||||
def _configure_header_size(self):
|
||||
Http.set_header_max_size(
|
||||
@@ -123,6 +122,9 @@ class Config(dict):
|
||||
self.REQUEST_MAX_SIZE,
|
||||
)
|
||||
|
||||
def _check_error_format(self):
|
||||
check_error_format(self.FALLBACK_ERROR_FORMAT)
|
||||
|
||||
def load_environment_vars(self, prefix=SANIC_PREFIX):
|
||||
"""
|
||||
Looks for prefixed environment variables and applies
|
||||
|
||||
@@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = {
|
||||
}
|
||||
|
||||
RENDERERS_BY_CONTENT_TYPE = {
|
||||
"multipart/form-data": HTMLRenderer,
|
||||
"application/json": JSONRenderer,
|
||||
"text/plain": TextRenderer,
|
||||
"application/json": JSONRenderer,
|
||||
"multipart/form-data": HTMLRenderer,
|
||||
"text/html": HTMLRenderer,
|
||||
}
|
||||
CONTENT_TYPE_BY_RENDERERS = {
|
||||
v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items()
|
||||
}
|
||||
|
||||
RESPONSE_MAPPING = {
|
||||
"empty": "html",
|
||||
"json": "json",
|
||||
"text": "text",
|
||||
"raw": "text",
|
||||
"html": "html",
|
||||
"file": "html",
|
||||
"file_stream": "text",
|
||||
"stream": "text",
|
||||
"redirect": "html",
|
||||
"text/plain": "text",
|
||||
"text/html": "html",
|
||||
"application/json": "json",
|
||||
}
|
||||
|
||||
|
||||
def check_error_format(format):
|
||||
if format not in RENDERERS_BY_CONFIG and format != "auto":
|
||||
raise SanicException(f"Unknown format: {format}")
|
||||
|
||||
|
||||
def exception_response(
|
||||
request: Request,
|
||||
exception: Exception,
|
||||
debug: bool,
|
||||
fallback: str,
|
||||
base: t.Type[BaseRenderer],
|
||||
renderer: t.Type[t.Optional[BaseRenderer]] = None,
|
||||
) -> HTTPResponse:
|
||||
"""
|
||||
Render a response for the default FALLBACK exception handler.
|
||||
"""
|
||||
content_type = None
|
||||
|
||||
if not renderer:
|
||||
renderer = HTMLRenderer
|
||||
# Make sure we have something set
|
||||
renderer = base
|
||||
render_format = fallback
|
||||
|
||||
if request:
|
||||
if request.app.config.FALLBACK_ERROR_FORMAT == "auto":
|
||||
# If there is a request, try and get the format
|
||||
# from the route
|
||||
if request.route:
|
||||
try:
|
||||
renderer = JSONRenderer if request.json else HTMLRenderer
|
||||
except InvalidUsage:
|
||||
render_format = request.route.ctx.error_format
|
||||
except AttributeError:
|
||||
...
|
||||
|
||||
content_type = request.headers.getone("content-type", "").split(
|
||||
";"
|
||||
)[0]
|
||||
|
||||
acceptable = request.accept
|
||||
|
||||
# If the format is auto still, make a guess
|
||||
if render_format == "auto":
|
||||
# First, if there is an Accept header, check if text/html
|
||||
# is the first option
|
||||
# According to MDN Web Docs, all major browsers use text/html
|
||||
# as the primary value in Accept (with the exception of IE 8,
|
||||
# and, well, if you are supporting IE 8, then you have bigger
|
||||
# problems to concern yourself with than what default exception
|
||||
# renderer is used)
|
||||
# Source:
|
||||
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values
|
||||
|
||||
if acceptable and acceptable[0].match(
|
||||
"text/html",
|
||||
allow_type_wildcard=False,
|
||||
allow_subtype_wildcard=False,
|
||||
):
|
||||
renderer = HTMLRenderer
|
||||
|
||||
content_type, *_ = request.headers.getone(
|
||||
"content-type", ""
|
||||
).split(";")
|
||||
renderer = RENDERERS_BY_CONTENT_TYPE.get(
|
||||
content_type, renderer
|
||||
# Second, if there is an Accept header, check if
|
||||
# application/json is an option, or if the content-type
|
||||
# is application/json
|
||||
elif (
|
||||
acceptable
|
||||
and acceptable.match(
|
||||
"application/json",
|
||||
allow_type_wildcard=False,
|
||||
allow_subtype_wildcard=False,
|
||||
)
|
||||
or content_type == "application/json"
|
||||
):
|
||||
renderer = JSONRenderer
|
||||
|
||||
# Third, if there is no Accept header, assume we want text.
|
||||
# The likely use case here is a raw socket.
|
||||
elif not acceptable:
|
||||
renderer = TextRenderer
|
||||
else:
|
||||
# Fourth, look to see if there was a JSON body
|
||||
# When in this situation, the request is probably coming
|
||||
# from curl, an API client like Postman or Insomnia, or a
|
||||
# package like requests or httpx
|
||||
try:
|
||||
# Give them the benefit of the doubt if they did:
|
||||
# $ curl localhost:8000 -d '{"foo": "bar"}'
|
||||
# And provide them with JSONRenderer
|
||||
renderer = JSONRenderer if request.json else base
|
||||
except InvalidUsage:
|
||||
renderer = base
|
||||
else:
|
||||
render_format = request.app.config.FALLBACK_ERROR_FORMAT
|
||||
renderer = RENDERERS_BY_CONFIG.get(render_format, renderer)
|
||||
|
||||
# Lastly, if there is an Accept header, make sure
|
||||
# our choice is okay
|
||||
if acceptable:
|
||||
type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore
|
||||
if type_ and type_ not in acceptable:
|
||||
# If the renderer selected is not in the Accept header
|
||||
# look through what is in the Accept header, and select
|
||||
# the first option that matches. Otherwise, just drop back
|
||||
# to the original default
|
||||
for accept in acceptable:
|
||||
mtype = f"{accept.type_}/{accept.subtype}"
|
||||
maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype)
|
||||
if maybe:
|
||||
renderer = maybe
|
||||
break
|
||||
else:
|
||||
renderer = base
|
||||
|
||||
renderer = t.cast(t.Type[BaseRenderer], renderer)
|
||||
return renderer(request, exception, debug).render()
|
||||
|
||||
@@ -4,14 +4,18 @@ from sanic.helpers import STATUS_CODES
|
||||
|
||||
|
||||
class SanicException(Exception):
|
||||
message: str = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: Optional[Union[str, bytes]] = None,
|
||||
status_code: Optional[int] = None,
|
||||
quiet: Optional[bool] = None,
|
||||
) -> None:
|
||||
|
||||
if message is None and status_code is not None:
|
||||
if message is None:
|
||||
if self.message:
|
||||
message = self.message
|
||||
elif status_code is not None:
|
||||
msg: bytes = STATUS_CODES.get(status_code, b"")
|
||||
message = msg.decode("utf8")
|
||||
|
||||
@@ -31,6 +35,7 @@ class NotFound(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 404
|
||||
quiet = True
|
||||
|
||||
|
||||
class InvalidUsage(SanicException):
|
||||
@@ -39,6 +44,7 @@ class InvalidUsage(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 400
|
||||
quiet = True
|
||||
|
||||
|
||||
class MethodNotSupported(SanicException):
|
||||
@@ -47,6 +53,7 @@ class MethodNotSupported(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 405
|
||||
quiet = True
|
||||
|
||||
def __init__(self, message, method, allowed_methods):
|
||||
super().__init__(message)
|
||||
@@ -70,6 +77,7 @@ class ServiceUnavailable(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 503
|
||||
quiet = True
|
||||
|
||||
|
||||
class URLBuildError(ServerError):
|
||||
@@ -101,6 +109,7 @@ class RequestTimeout(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 408
|
||||
quiet = True
|
||||
|
||||
|
||||
class PayloadTooLarge(SanicException):
|
||||
@@ -109,6 +118,7 @@ class PayloadTooLarge(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 413
|
||||
quiet = True
|
||||
|
||||
|
||||
class HeaderNotFound(InvalidUsage):
|
||||
@@ -116,7 +126,11 @@ class HeaderNotFound(InvalidUsage):
|
||||
**Status**: 400 Bad Request
|
||||
"""
|
||||
|
||||
status_code = 400
|
||||
|
||||
class InvalidHeader(InvalidUsage):
|
||||
"""
|
||||
**Status**: 400 Bad Request
|
||||
"""
|
||||
|
||||
|
||||
class ContentRangeError(SanicException):
|
||||
@@ -125,6 +139,7 @@ class ContentRangeError(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 416
|
||||
quiet = True
|
||||
|
||||
def __init__(self, message, content_range):
|
||||
super().__init__(message)
|
||||
@@ -137,6 +152,7 @@ class HeaderExpectationFailed(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 417
|
||||
quiet = True
|
||||
|
||||
|
||||
class Forbidden(SanicException):
|
||||
@@ -145,6 +161,7 @@ class Forbidden(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 403
|
||||
quiet = True
|
||||
|
||||
|
||||
class InvalidRangeType(ContentRangeError):
|
||||
@@ -153,6 +170,7 @@ class InvalidRangeType(ContentRangeError):
|
||||
"""
|
||||
|
||||
status_code = 416
|
||||
quiet = True
|
||||
|
||||
|
||||
class PyFileError(Exception):
|
||||
@@ -196,6 +214,7 @@ class Unauthorized(SanicException):
|
||||
"""
|
||||
|
||||
status_code = 401
|
||||
quiet = True
|
||||
|
||||
def __init__(self, message, status_code=None, scheme=None, **kwargs):
|
||||
super().__init__(message, status_code)
|
||||
@@ -218,6 +237,11 @@ class InvalidSignal(SanicException):
|
||||
pass
|
||||
|
||||
|
||||
class WebsocketClosed(SanicException):
|
||||
quiet = True
|
||||
message = "Client has closed the websocket connection"
|
||||
|
||||
|
||||
def abort(status_code: int, message: Optional[Union[str, bytes]] = None):
|
||||
"""
|
||||
Raise an exception based on SanicException. Returns the HTTP response
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from traceback import format_exc
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from sanic.errorpages import exception_response
|
||||
from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response
|
||||
from sanic.exceptions import (
|
||||
ContentRangeError,
|
||||
HeaderNotFound,
|
||||
InvalidRangeType,
|
||||
)
|
||||
from sanic.log import error_logger
|
||||
from sanic.models.handler_types import RouteHandler
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
@@ -23,15 +25,47 @@ class ErrorHandler:
|
||||
|
||||
"""
|
||||
|
||||
handlers = None
|
||||
cached_handlers = None
|
||||
|
||||
def __init__(self):
|
||||
self.handlers = []
|
||||
self.cached_handlers = {}
|
||||
# Beginning in v22.3, the base renderer will be TextRenderer
|
||||
def __init__(
|
||||
self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer
|
||||
):
|
||||
self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = []
|
||||
self.cached_handlers: Dict[
|
||||
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
|
||||
] = {}
|
||||
self.debug = False
|
||||
self.fallback = fallback
|
||||
self.base = base
|
||||
|
||||
def add(self, exception, handler):
|
||||
@classmethod
|
||||
def finalize(cls, error_handler):
|
||||
if not isinstance(error_handler, cls):
|
||||
error_logger.warning(
|
||||
f"Error handler is non-conforming: {type(error_handler)}"
|
||||
)
|
||||
|
||||
sig = signature(error_handler.lookup)
|
||||
if len(sig.parameters) == 1:
|
||||
error_logger.warning(
|
||||
DeprecationWarning(
|
||||
"You are using a deprecated error handler. The lookup "
|
||||
"method should accept two positional parameters: "
|
||||
"(exception, route_name: Optional[str]). "
|
||||
"Until you upgrade your ErrorHandler.lookup, Blueprint "
|
||||
"specific exceptions will not work properly. Beginning "
|
||||
"in v22.3, the legacy style lookup method will not "
|
||||
"work at all."
|
||||
),
|
||||
)
|
||||
error_handler._lookup = error_handler._legacy_lookup
|
||||
|
||||
def _full_lookup(self, exception, route_name: Optional[str] = None):
|
||||
return self.lookup(exception, route_name)
|
||||
|
||||
def _legacy_lookup(self, exception, route_name: Optional[str] = None):
|
||||
return self.lookup(exception)
|
||||
|
||||
def add(self, exception, handler, route_names: Optional[List[str]] = None):
|
||||
"""
|
||||
Add a new exception handler to an already existing handler object.
|
||||
|
||||
@@ -44,11 +78,16 @@ class ErrorHandler:
|
||||
|
||||
:return: None
|
||||
"""
|
||||
# self.handlers to be deprecated and removed in version 21.12
|
||||
# self.handlers is deprecated and will be removed in version 22.3
|
||||
self.handlers.append((exception, handler))
|
||||
self.cached_handlers[exception] = handler
|
||||
|
||||
def lookup(self, exception):
|
||||
if route_names:
|
||||
for route in route_names:
|
||||
self.cached_handlers[(exception, route)] = handler
|
||||
else:
|
||||
self.cached_handlers[(exception, None)] = handler
|
||||
|
||||
def lookup(self, exception, route_name: Optional[str] = None):
|
||||
"""
|
||||
Lookup the existing instance of :class:`ErrorHandler` and fetch the
|
||||
registered handler for a specific type of exception.
|
||||
@@ -63,20 +102,31 @@ class ErrorHandler:
|
||||
:return: Registered function if found ``None`` otherwise
|
||||
"""
|
||||
exception_class = type(exception)
|
||||
if exception_class in self.cached_handlers:
|
||||
return self.cached_handlers[exception_class]
|
||||
|
||||
for ancestor in type.mro(exception_class):
|
||||
if ancestor in self.cached_handlers:
|
||||
handler = self.cached_handlers[ancestor]
|
||||
self.cached_handlers[exception_class] = handler
|
||||
for name in (route_name, None):
|
||||
exception_key = (exception_class, name)
|
||||
handler = self.cached_handlers.get(exception_key)
|
||||
if handler:
|
||||
return handler
|
||||
|
||||
for name in (route_name, None):
|
||||
for ancestor in type.mro(exception_class):
|
||||
exception_key = (ancestor, name)
|
||||
if exception_key in self.cached_handlers:
|
||||
handler = self.cached_handlers[exception_key]
|
||||
self.cached_handlers[
|
||||
(exception_class, route_name)
|
||||
] = handler
|
||||
return handler
|
||||
|
||||
if ancestor is BaseException:
|
||||
break
|
||||
self.cached_handlers[exception_class] = None
|
||||
self.cached_handlers[(exception_class, route_name)] = None
|
||||
handler = None
|
||||
return handler
|
||||
|
||||
_lookup = _full_lookup
|
||||
|
||||
def response(self, request, exception):
|
||||
"""Fetches and executes an exception handler and returns a response
|
||||
object
|
||||
@@ -91,7 +141,8 @@ class ErrorHandler:
|
||||
:return: Wrap the return value obtained from :func:`default`
|
||||
or registered handler for that type of exception.
|
||||
"""
|
||||
handler = self.lookup(exception)
|
||||
route_name = request.name if request else None
|
||||
handler = self._lookup(exception, route_name)
|
||||
response = None
|
||||
try:
|
||||
if handler:
|
||||
@@ -99,7 +150,6 @@ class ErrorHandler:
|
||||
if response is None:
|
||||
response = self.default(request, exception)
|
||||
except Exception:
|
||||
self.log(format_exc())
|
||||
try:
|
||||
url = repr(request.url)
|
||||
except AttributeError:
|
||||
@@ -115,11 +165,6 @@ class ErrorHandler:
|
||||
return text("An error occurred while handling an error", 500)
|
||||
return response
|
||||
|
||||
def log(self, message, level="error"):
|
||||
"""
|
||||
Deprecated, do not use.
|
||||
"""
|
||||
|
||||
def default(self, request, exception):
|
||||
"""
|
||||
Provide a default behavior for the objects of :class:`ErrorHandler`.
|
||||
@@ -135,6 +180,17 @@ class ErrorHandler:
|
||||
:class:`Exception`
|
||||
:return:
|
||||
"""
|
||||
self.log(request, exception)
|
||||
return exception_response(
|
||||
request,
|
||||
exception,
|
||||
debug=self.debug,
|
||||
base=self.base,
|
||||
fallback=self.fallback,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log(request, exception):
|
||||
quiet = getattr(exception, "quiet", False)
|
||||
if quiet is False:
|
||||
try:
|
||||
@@ -142,13 +198,10 @@ class ErrorHandler:
|
||||
except AttributeError:
|
||||
url = "unknown"
|
||||
|
||||
self.log(format_exc())
|
||||
error_logger.exception(
|
||||
"Exception occurred while handling uri: %s", url
|
||||
)
|
||||
|
||||
return exception_response(request, exception, self.debug)
|
||||
|
||||
|
||||
class ContentRangeHandler:
|
||||
"""
|
||||
|
||||
200
sanic/headers.py
200
sanic/headers.py
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from urllib.parse import unquote
|
||||
|
||||
from sanic.exceptions import InvalidHeader
|
||||
from sanic.helpers import STATUS_CODES
|
||||
|
||||
|
||||
@@ -30,6 +33,175 @@ _host_re = re.compile(
|
||||
# For more information, consult ../tests/test_requests.py
|
||||
|
||||
|
||||
def parse_arg_as_accept(f):
|
||||
def func(self, other, *args, **kwargs):
|
||||
if not isinstance(other, Accept) and other:
|
||||
other = Accept.parse(other)
|
||||
return f(self, other, *args, **kwargs)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class MediaType(str):
|
||||
def __new__(cls, value: str):
|
||||
return str.__new__(cls, value)
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
self.is_wildcard = self.check_if_wildcard(value)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.is_wildcard:
|
||||
return True
|
||||
|
||||
if self.match(other):
|
||||
return True
|
||||
|
||||
other_is_wildcard = (
|
||||
other.is_wildcard
|
||||
if isinstance(other, MediaType)
|
||||
else self.check_if_wildcard(other)
|
||||
)
|
||||
|
||||
return other_is_wildcard
|
||||
|
||||
def match(self, other):
|
||||
other_value = other.value if isinstance(other, MediaType) else other
|
||||
return self.value == other_value
|
||||
|
||||
@staticmethod
|
||||
def check_if_wildcard(value):
|
||||
return value == "*"
|
||||
|
||||
|
||||
class Accept(str):
|
||||
def __new__(cls, value: str, *args, **kwargs):
|
||||
return str.__new__(cls, value)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: str,
|
||||
type_: MediaType,
|
||||
subtype: MediaType,
|
||||
*,
|
||||
q: str = "1.0",
|
||||
**kwargs: str,
|
||||
):
|
||||
qvalue = float(q)
|
||||
if qvalue > 1 or qvalue < 0:
|
||||
raise InvalidHeader(
|
||||
f"Accept header qvalue must be between 0 and 1, not: {qvalue}"
|
||||
)
|
||||
self.value = value
|
||||
self.type_ = type_
|
||||
self.subtype = subtype
|
||||
self.qvalue = qvalue
|
||||
self.params = kwargs
|
||||
|
||||
def _compare(self, other, method):
|
||||
try:
|
||||
return method(self.qvalue, other.qvalue)
|
||||
except (AttributeError, TypeError):
|
||||
return NotImplemented
|
||||
|
||||
@parse_arg_as_accept
|
||||
def __lt__(self, other: Union[str, Accept]):
|
||||
return self._compare(other, lambda s, o: s < o)
|
||||
|
||||
@parse_arg_as_accept
|
||||
def __le__(self, other: Union[str, Accept]):
|
||||
return self._compare(other, lambda s, o: s <= o)
|
||||
|
||||
@parse_arg_as_accept
|
||||
def __eq__(self, other: Union[str, Accept]): # type: ignore
|
||||
return self._compare(other, lambda s, o: s == o)
|
||||
|
||||
@parse_arg_as_accept
|
||||
def __ge__(self, other: Union[str, Accept]):
|
||||
return self._compare(other, lambda s, o: s >= o)
|
||||
|
||||
@parse_arg_as_accept
|
||||
def __gt__(self, other: Union[str, Accept]):
|
||||
return self._compare(other, lambda s, o: s > o)
|
||||
|
||||
@parse_arg_as_accept
|
||||
def __ne__(self, other: Union[str, Accept]): # type: ignore
|
||||
return self._compare(other, lambda s, o: s != o)
|
||||
|
||||
@parse_arg_as_accept
|
||||
def match(
|
||||
self,
|
||||
other,
|
||||
*,
|
||||
allow_type_wildcard: bool = True,
|
||||
allow_subtype_wildcard: bool = True,
|
||||
) -> bool:
|
||||
type_match = (
|
||||
self.type_ == other.type_
|
||||
if allow_type_wildcard
|
||||
else (
|
||||
self.type_.match(other.type_)
|
||||
and not self.type_.is_wildcard
|
||||
and not other.type_.is_wildcard
|
||||
)
|
||||
)
|
||||
subtype_match = (
|
||||
self.subtype == other.subtype
|
||||
if allow_subtype_wildcard
|
||||
else (
|
||||
self.subtype.match(other.subtype)
|
||||
and not self.subtype.is_wildcard
|
||||
and not other.subtype.is_wildcard
|
||||
)
|
||||
)
|
||||
|
||||
return type_match and subtype_match
|
||||
|
||||
@classmethod
|
||||
def parse(cls, raw: str) -> Accept:
|
||||
invalid = False
|
||||
mtype = raw.strip()
|
||||
|
||||
try:
|
||||
media, *raw_params = mtype.split(";")
|
||||
type_, subtype = media.split("/")
|
||||
except ValueError:
|
||||
invalid = True
|
||||
|
||||
if invalid or not type_ or not subtype:
|
||||
raise InvalidHeader(f"Header contains invalid Accept value: {raw}")
|
||||
|
||||
params = dict(
|
||||
[
|
||||
(key.strip(), value.strip())
|
||||
for key, value in (param.split("=", 1) for param in raw_params)
|
||||
]
|
||||
)
|
||||
|
||||
return cls(mtype, MediaType(type_), MediaType(subtype), **params)
|
||||
|
||||
|
||||
class AcceptContainer(list):
|
||||
def __contains__(self, o: object) -> bool:
|
||||
return any(item.match(o) for item in self)
|
||||
|
||||
def match(
|
||||
self,
|
||||
o: object,
|
||||
*,
|
||||
allow_type_wildcard: bool = True,
|
||||
allow_subtype_wildcard: bool = True,
|
||||
) -> bool:
|
||||
return any(
|
||||
item.match(
|
||||
o,
|
||||
allow_type_wildcard=allow_type_wildcard,
|
||||
allow_subtype_wildcard=allow_subtype_wildcard,
|
||||
)
|
||||
for item in self
|
||||
)
|
||||
|
||||
|
||||
def parse_content_header(value: str) -> Tuple[str, Options]:
|
||||
"""Parse content-type and content-disposition header values.
|
||||
|
||||
@@ -194,3 +366,31 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
|
||||
ret += b"%b: %b\r\n" % h
|
||||
ret += b"\r\n"
|
||||
return ret
|
||||
|
||||
|
||||
def _sort_accept_value(accept: Accept):
|
||||
return (
|
||||
accept.qvalue,
|
||||
len(accept.params),
|
||||
accept.subtype != "*",
|
||||
accept.type_ != "*",
|
||||
)
|
||||
|
||||
|
||||
def parse_accept(accept: str) -> AcceptContainer:
|
||||
"""Parse an Accept header and order the acceptable media types in
|
||||
accorsing to RFC 7231, s. 5.3.2
|
||||
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
|
||||
"""
|
||||
media_types = accept.split(",")
|
||||
accept_list: List[Accept] = []
|
||||
|
||||
for mtype in media_types:
|
||||
if not mtype:
|
||||
continue
|
||||
|
||||
accept_list.append(Accept.parse(mtype))
|
||||
|
||||
return AcceptContainer(
|
||||
sorted(accept_list, key=_sort_accept_value, reverse=True)
|
||||
)
|
||||
|
||||
@@ -155,3 +155,17 @@ def import_string(module_name, package=None):
|
||||
if ismodule(obj):
|
||||
return obj
|
||||
return obj()
|
||||
|
||||
|
||||
class Default:
|
||||
"""
|
||||
It is used to replace `None` or `object()` as a sentinel
|
||||
that represents a default value. Sometimes we want to set
|
||||
a value to `None` so we cannot use `None` to represent the
|
||||
default value, and `object()` is hard to be typed.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
_default = Default()
|
||||
|
||||
@@ -21,6 +21,7 @@ from sanic.exceptions import (
|
||||
from sanic.headers import format_http1_response
|
||||
from sanic.helpers import has_message_body
|
||||
from sanic.log import access_logger, error_logger, logger
|
||||
from sanic.touchup import TouchUpMeta
|
||||
|
||||
|
||||
class Stage(Enum):
|
||||
@@ -45,7 +46,7 @@ class Stage(Enum):
|
||||
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
|
||||
|
||||
|
||||
class Http:
|
||||
class Http(metaclass=TouchUpMeta):
|
||||
"""
|
||||
Internal helper for managing the HTTP request/response cycle
|
||||
|
||||
@@ -67,9 +68,15 @@ class Http:
|
||||
HEADER_CEILING = 16_384
|
||||
HEADER_MAX_SIZE = 0
|
||||
|
||||
__touchup__ = (
|
||||
"http1_request_header",
|
||||
"http1_response_header",
|
||||
"read",
|
||||
)
|
||||
__slots__ = [
|
||||
"_send",
|
||||
"_receive_more",
|
||||
"dispatch",
|
||||
"recv_buffer",
|
||||
"protocol",
|
||||
"expecting_continue",
|
||||
@@ -95,19 +102,24 @@ class Http:
|
||||
self._receive_more = protocol.receive_more
|
||||
self.recv_buffer = protocol.recv_buffer
|
||||
self.protocol = protocol
|
||||
self.expecting_continue: bool = False
|
||||
self.keep_alive = True
|
||||
self.stage: Stage = Stage.IDLE
|
||||
self.dispatch = self.protocol.app.dispatch
|
||||
self.init_for_request()
|
||||
|
||||
def init_for_request(self):
|
||||
"""Init/reset all per-request variables."""
|
||||
self.exception = None
|
||||
self.expecting_continue: bool = False
|
||||
self.head_only = None
|
||||
self.request_body = None
|
||||
self.request_bytes = None
|
||||
self.request_bytes_left = None
|
||||
self.request_max_size = protocol.request_max_size
|
||||
self.keep_alive = True
|
||||
self.head_only = None
|
||||
self.request_max_size = self.protocol.request_max_size
|
||||
self.request: Request = None
|
||||
self.response: BaseHTTPResponse = None
|
||||
self.exception = None
|
||||
self.url = None
|
||||
self.upgrade_websocket = False
|
||||
self.url = None
|
||||
|
||||
def __bool__(self):
|
||||
"""Test if request handling is in progress"""
|
||||
@@ -136,6 +148,12 @@ class Http:
|
||||
await self.response.send(end_stream=True)
|
||||
except CancelledError:
|
||||
# Write an appropriate response before exiting
|
||||
if not self.protocol.transport:
|
||||
logger.info(
|
||||
f"Request: {self.request.method} {self.request.url} "
|
||||
"stopped. Transport is closed."
|
||||
)
|
||||
return
|
||||
e = self.exception or ServiceUnavailable("Cancelled")
|
||||
self.exception = None
|
||||
self.keep_alive = False
|
||||
@@ -148,7 +166,10 @@ class Http:
|
||||
if self.request_body:
|
||||
if self.response and 200 <= self.response.status < 300:
|
||||
error_logger.error(f"{self.request} body not consumed.")
|
||||
|
||||
# Limit the size because the handler may have set it infinite
|
||||
self.request_max_size = min(
|
||||
self.request_max_size, self.protocol.request_max_size
|
||||
)
|
||||
try:
|
||||
async for _ in self:
|
||||
pass
|
||||
@@ -160,15 +181,23 @@ class Http:
|
||||
await sleep(0.001)
|
||||
self.keep_alive = False
|
||||
|
||||
# Clean up to free memory and for the next request
|
||||
if self.request:
|
||||
self.request.stream = None
|
||||
if self.response:
|
||||
self.response.stream = None
|
||||
|
||||
# Exit and disconnect if no more requests can be taken
|
||||
if self.stage is not Stage.IDLE or not self.keep_alive:
|
||||
break
|
||||
|
||||
# Wait for next request
|
||||
self.init_for_request()
|
||||
|
||||
# Wait for the next request
|
||||
if not self.recv_buffer:
|
||||
await self._receive_more()
|
||||
|
||||
async def http1_request_header(self):
|
||||
async def http1_request_header(self): # no cov
|
||||
"""
|
||||
Receive and parse request header into self.request.
|
||||
"""
|
||||
@@ -197,6 +226,12 @@ class Http:
|
||||
reqline, *split_headers = raw_headers.split("\r\n")
|
||||
method, self.url, protocol = reqline.split(" ")
|
||||
|
||||
await self.dispatch(
|
||||
"http.lifecycle.read_head",
|
||||
inline=True,
|
||||
context={"head": bytes(head)},
|
||||
)
|
||||
|
||||
if protocol == "HTTP/1.1":
|
||||
self.keep_alive = True
|
||||
elif protocol == "HTTP/1.0":
|
||||
@@ -235,6 +270,11 @@ class Http:
|
||||
transport=self.protocol.transport,
|
||||
app=self.protocol.app,
|
||||
)
|
||||
await self.dispatch(
|
||||
"http.lifecycle.request",
|
||||
inline=True,
|
||||
context={"request": request},
|
||||
)
|
||||
|
||||
# Prepare for request body
|
||||
self.request_bytes_left = self.request_bytes = 0
|
||||
@@ -265,7 +305,7 @@ class Http:
|
||||
|
||||
async def http1_response_header(
|
||||
self, data: bytes, end_stream: bool
|
||||
) -> None:
|
||||
) -> None: # no cov
|
||||
res = self.response
|
||||
|
||||
# Compatibility with simple response body
|
||||
@@ -437,8 +477,8 @@ class Http:
|
||||
"request": "nil",
|
||||
}
|
||||
if req is not None:
|
||||
if req.ip:
|
||||
extra["host"] = f"{req.ip}:{req.port}"
|
||||
if req.remote_addr or req.ip:
|
||||
extra["host"] = f"{req.remote_addr or req.ip}:{req.port}"
|
||||
extra["request"] = f"{req.method} {req.url}"
|
||||
access_logger.info("", extra=extra)
|
||||
|
||||
@@ -454,7 +494,7 @@ class Http:
|
||||
if data:
|
||||
yield data
|
||||
|
||||
async def read(self) -> Optional[bytes]:
|
||||
async def read(self) -> Optional[bytes]: # no cov
|
||||
"""
|
||||
Read some bytes of request body.
|
||||
"""
|
||||
@@ -486,8 +526,6 @@ class Http:
|
||||
self.keep_alive = False
|
||||
raise InvalidUsage("Bad chunked encoding")
|
||||
|
||||
del buf[: pos + 2]
|
||||
|
||||
if size <= 0:
|
||||
self.request_body = None
|
||||
|
||||
@@ -495,8 +533,17 @@ class Http:
|
||||
self.keep_alive = False
|
||||
raise InvalidUsage("Bad chunked encoding")
|
||||
|
||||
# Consume CRLF, chunk size 0 and the two CRLF that follow
|
||||
pos += 4
|
||||
# Might need to wait for the final CRLF
|
||||
while len(buf) < pos:
|
||||
await self._receive_more()
|
||||
del buf[:pos]
|
||||
return None
|
||||
|
||||
# Remove CRLF, chunk size and the CRLF that follows
|
||||
del buf[: pos + 2]
|
||||
|
||||
self.request_bytes_left = size
|
||||
self.request_bytes += size
|
||||
|
||||
@@ -521,6 +568,12 @@ class Http:
|
||||
|
||||
self.request_bytes_left -= size
|
||||
|
||||
await self.dispatch(
|
||||
"http.lifecycle.read_body",
|
||||
inline=True,
|
||||
context={"body": data},
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
# Response methods
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Coroutine, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sanic.models.futures import FutureListener
|
||||
from sanic.models.handler_types import ListenerType
|
||||
|
||||
|
||||
class ListenerEvent(str, Enum):
|
||||
def _generate_next_value_(name: str, *args) -> str: # type: ignore
|
||||
return name.lower()
|
||||
|
||||
BEFORE_SERVER_START = auto()
|
||||
AFTER_SERVER_START = auto()
|
||||
BEFORE_SERVER_STOP = auto()
|
||||
AFTER_SERVER_STOP = auto()
|
||||
BEFORE_SERVER_START = "server.init.before"
|
||||
AFTER_SERVER_START = "server.init.after"
|
||||
BEFORE_SERVER_STOP = "server.shutdown.before"
|
||||
AFTER_SERVER_STOP = "server.shutdown.after"
|
||||
MAIN_PROCESS_START = auto()
|
||||
MAIN_PROCESS_STOP = auto()
|
||||
|
||||
@@ -26,9 +27,7 @@ class ListenerMixin:
|
||||
|
||||
def listener(
|
||||
self,
|
||||
listener_or_event: Union[
|
||||
Callable[..., Coroutine[Any, Any, None]], str
|
||||
],
|
||||
listener_or_event: Union[ListenerType, str],
|
||||
event_or_none: Optional[str] = None,
|
||||
apply: bool = True,
|
||||
):
|
||||
@@ -63,20 +62,20 @@ class ListenerMixin:
|
||||
else:
|
||||
return partial(register_listener, event=listener_or_event)
|
||||
|
||||
def main_process_start(self, listener):
|
||||
def main_process_start(self, listener: ListenerType) -> ListenerType:
|
||||
return self.listener(listener, "main_process_start")
|
||||
|
||||
def main_process_stop(self, listener):
|
||||
def main_process_stop(self, listener: ListenerType) -> ListenerType:
|
||||
return self.listener(listener, "main_process_stop")
|
||||
|
||||
def before_server_start(self, listener):
|
||||
def before_server_start(self, listener: ListenerType) -> ListenerType:
|
||||
return self.listener(listener, "before_server_start")
|
||||
|
||||
def after_server_start(self, listener):
|
||||
def after_server_start(self, listener: ListenerType) -> ListenerType:
|
||||
return self.listener(listener, "after_server_start")
|
||||
|
||||
def before_server_stop(self, listener):
|
||||
def before_server_stop(self, listener: ListenerType) -> ListenerType:
|
||||
return self.listener(listener, "before_server_stop")
|
||||
|
||||
def after_server_stop(self, listener):
|
||||
def after_server_stop(self, listener: ListenerType) -> ListenerType:
|
||||
return self.listener(listener, "after_server_stop")
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
from ast import NodeVisitor, Return, parse
|
||||
from functools import partial, wraps
|
||||
from inspect import signature
|
||||
from inspect import getsource, signature
|
||||
from mimetypes import guess_type
|
||||
from os import path
|
||||
from pathlib import PurePath
|
||||
from re import sub
|
||||
from textwrap import dedent
|
||||
from time import gmtime, strftime
|
||||
from typing import Iterable, List, Optional, Set, Union
|
||||
from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union
|
||||
from urllib.parse import unquote
|
||||
|
||||
from sanic_routing.route import Route # type: ignore
|
||||
|
||||
from sanic.compat import stat_async
|
||||
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
|
||||
from sanic.errorpages import RESPONSE_MAPPING
|
||||
from sanic.exceptions import (
|
||||
ContentRangeError,
|
||||
FileNotFound,
|
||||
@@ -21,10 +24,16 @@ from sanic.exceptions import (
|
||||
from sanic.handlers import ContentRangeHandler
|
||||
from sanic.log import error_logger
|
||||
from sanic.models.futures import FutureRoute, FutureStatic
|
||||
from sanic.models.handler_types import RouteHandler
|
||||
from sanic.response import HTTPResponse, file, file_stream
|
||||
from sanic.views import CompositionView
|
||||
|
||||
|
||||
RouteWrapper = Callable[
|
||||
[RouteHandler], Union[RouteHandler, Tuple[Route, RouteHandler]]
|
||||
]
|
||||
|
||||
|
||||
class RouteMixin:
|
||||
name: str
|
||||
|
||||
@@ -55,7 +64,8 @@ class RouteMixin:
|
||||
unquote: bool = False,
|
||||
static: bool = False,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Decorate a function to be registered as a route
|
||||
|
||||
@@ -97,6 +107,7 @@ class RouteMixin:
|
||||
nonlocal websocket
|
||||
nonlocal static
|
||||
nonlocal version_prefix
|
||||
nonlocal error_format
|
||||
|
||||
if isinstance(handler, tuple):
|
||||
# if a handler fn is already wrapped in a route, the handler
|
||||
@@ -115,10 +126,16 @@ class RouteMixin:
|
||||
"Expected either string or Iterable of host strings, "
|
||||
"not %s" % host
|
||||
)
|
||||
|
||||
if isinstance(subprotocols, (list, tuple, set)):
|
||||
if isinstance(subprotocols, list):
|
||||
# Ordered subprotocols, maintain order
|
||||
subprotocols = tuple(subprotocols)
|
||||
elif isinstance(subprotocols, set):
|
||||
# subprotocol is unordered, keep it unordered
|
||||
subprotocols = frozenset(subprotocols)
|
||||
|
||||
if not error_format or error_format == "auto":
|
||||
error_format = self._determine_error_format(handler)
|
||||
|
||||
route = FutureRoute(
|
||||
handler,
|
||||
uri,
|
||||
@@ -134,6 +151,7 @@ class RouteMixin:
|
||||
unquote,
|
||||
static,
|
||||
version_prefix,
|
||||
error_format,
|
||||
)
|
||||
|
||||
self._future_routes.add(route)
|
||||
@@ -168,7 +186,7 @@ class RouteMixin:
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
handler,
|
||||
handler: RouteHandler,
|
||||
uri: str,
|
||||
methods: Iterable[str] = frozenset({"GET"}),
|
||||
host: Optional[str] = None,
|
||||
@@ -177,7 +195,8 @@ class RouteMixin:
|
||||
name: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteHandler:
|
||||
"""A helper method to register class instance or
|
||||
functions as a handler to the application url
|
||||
routes.
|
||||
@@ -200,7 +219,8 @@ class RouteMixin:
|
||||
methods = set()
|
||||
|
||||
for method in HTTP_METHODS:
|
||||
_handler = getattr(handler.view_class, method.lower(), None)
|
||||
view_class = getattr(handler, "view_class")
|
||||
_handler = getattr(view_class, method.lower(), None)
|
||||
if _handler:
|
||||
methods.add(method)
|
||||
if hasattr(_handler, "is_stream"):
|
||||
@@ -226,6 +246,7 @@ class RouteMixin:
|
||||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)(handler)
|
||||
return handler
|
||||
|
||||
@@ -239,7 +260,8 @@ class RouteMixin:
|
||||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **GET** *HTTP* method
|
||||
|
||||
@@ -262,6 +284,7 @@ class RouteMixin:
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def post(
|
||||
@@ -273,7 +296,8 @@ class RouteMixin:
|
||||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **POST** *HTTP* method
|
||||
|
||||
@@ -296,6 +320,7 @@ class RouteMixin:
|
||||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def put(
|
||||
@@ -307,7 +332,8 @@ class RouteMixin:
|
||||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **PUT** *HTTP* method
|
||||
|
||||
@@ -330,6 +356,7 @@ class RouteMixin:
|
||||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def head(
|
||||
@@ -341,7 +368,8 @@ class RouteMixin:
|
||||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **HEAD** *HTTP* method
|
||||
|
||||
@@ -372,6 +400,7 @@ class RouteMixin:
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def options(
|
||||
@@ -383,7 +412,8 @@ class RouteMixin:
|
||||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **OPTIONS** *HTTP* method
|
||||
|
||||
@@ -414,6 +444,7 @@ class RouteMixin:
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def patch(
|
||||
@@ -425,7 +456,8 @@ class RouteMixin:
|
||||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **PATCH** *HTTP* method
|
||||
|
||||
@@ -458,6 +490,7 @@ class RouteMixin:
|
||||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def delete(
|
||||
@@ -469,7 +502,8 @@ class RouteMixin:
|
||||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
):
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **DELETE** *HTTP* method
|
||||
|
||||
@@ -492,6 +526,7 @@ class RouteMixin:
|
||||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def websocket(
|
||||
@@ -504,6 +539,7 @@ class RouteMixin:
|
||||
name: Optional[str] = None,
|
||||
apply: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Decorate a function to be registered as a websocket route
|
||||
@@ -530,6 +566,7 @@ class RouteMixin:
|
||||
subprotocols=subprotocols,
|
||||
websocket=True,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def add_websocket_route(
|
||||
@@ -542,6 +579,7 @@ class RouteMixin:
|
||||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
A helper method to register a function as a websocket route.
|
||||
@@ -570,6 +608,7 @@ class RouteMixin:
|
||||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)(handler)
|
||||
|
||||
def static(
|
||||
@@ -585,6 +624,7 @@ class RouteMixin:
|
||||
strict_slashes=None,
|
||||
content_type=None,
|
||||
apply=True,
|
||||
resource_type=None,
|
||||
):
|
||||
"""
|
||||
Register a root to serve files from. The input can either be a
|
||||
@@ -634,6 +674,7 @@ class RouteMixin:
|
||||
host,
|
||||
strict_slashes,
|
||||
content_type,
|
||||
resource_type,
|
||||
)
|
||||
self._future_statics.add(static)
|
||||
|
||||
@@ -777,10 +818,11 @@ class RouteMixin:
|
||||
)
|
||||
except Exception:
|
||||
error_logger.exception(
|
||||
f"Exception in static request handler:\
|
||||
path={file_or_directory}, "
|
||||
f"Exception in static request handler: "
|
||||
f"path={file_or_directory}, "
|
||||
f"relative_url={__file_uri__}"
|
||||
)
|
||||
raise
|
||||
|
||||
def _register_static(
|
||||
self,
|
||||
@@ -828,8 +870,27 @@ class RouteMixin:
|
||||
name = static.name
|
||||
# If we're not trying to match a file directly,
|
||||
# serve from the folder
|
||||
if not static.resource_type:
|
||||
if not path.isfile(file_or_directory):
|
||||
uri += "/<__file_uri__:path>"
|
||||
elif static.resource_type == "dir":
|
||||
if path.isfile(file_or_directory):
|
||||
raise TypeError(
|
||||
"Resource type improperly identified as directory. "
|
||||
f"'{file_or_directory}'"
|
||||
)
|
||||
uri += "/<__file_uri__:path>"
|
||||
elif static.resource_type == "file" and not path.isfile(
|
||||
file_or_directory
|
||||
):
|
||||
raise TypeError(
|
||||
"Resource type improperly identified as file. "
|
||||
f"'{file_or_directory}'"
|
||||
)
|
||||
elif static.resource_type != "file":
|
||||
raise ValueError(
|
||||
"The resource_type should be set to 'file' or 'dir'"
|
||||
)
|
||||
|
||||
# special prefix for static files
|
||||
# if not static.name.startswith("_static_"):
|
||||
@@ -846,7 +907,7 @@ class RouteMixin:
|
||||
)
|
||||
)
|
||||
|
||||
route, _ = self.route(
|
||||
route, _ = self.route( # type: ignore
|
||||
uri=uri,
|
||||
methods=["GET", "HEAD"],
|
||||
name=name,
|
||||
@@ -856,3 +917,43 @@ class RouteMixin:
|
||||
)(_handler)
|
||||
|
||||
return route
|
||||
|
||||
def _determine_error_format(self, handler) -> str:
|
||||
if not isinstance(handler, CompositionView):
|
||||
try:
|
||||
src = dedent(getsource(handler))
|
||||
tree = parse(src)
|
||||
http_response_types = self._get_response_types(tree)
|
||||
|
||||
if len(http_response_types) == 1:
|
||||
return next(iter(http_response_types))
|
||||
except (OSError, TypeError):
|
||||
...
|
||||
|
||||
return "auto"
|
||||
|
||||
def _get_response_types(self, node):
|
||||
types = set()
|
||||
|
||||
class HttpResponseVisitor(NodeVisitor):
|
||||
def visit_Return(self, node: Return) -> Any:
|
||||
nonlocal types
|
||||
|
||||
try:
|
||||
checks = [node.value.func.id] # type: ignore
|
||||
if node.value.keywords: # type: ignore
|
||||
checks += [
|
||||
k.value
|
||||
for k in node.value.keywords # type: ignore
|
||||
if k.arg == "content_type"
|
||||
]
|
||||
|
||||
for check in checks:
|
||||
if check in RESPONSE_MAPPING:
|
||||
types.add(RESPONSE_MAPPING[check])
|
||||
except AttributeError:
|
||||
...
|
||||
|
||||
HttpResponseVisitor().visit(node)
|
||||
|
||||
return types
|
||||
|
||||
@@ -23,7 +23,7 @@ class SignalMixin:
|
||||
*,
|
||||
apply: bool = True,
|
||||
condition: Dict[str, Any] = None,
|
||||
) -> Callable[[SignalHandler], FutureSignal]:
|
||||
) -> Callable[[SignalHandler], SignalHandler]:
|
||||
"""
|
||||
For creating a signal handler, used similar to a route handler:
|
||||
|
||||
@@ -54,7 +54,7 @@ class SignalMixin:
|
||||
if apply:
|
||||
self._apply_signal(future_signal)
|
||||
|
||||
return future_signal
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
|
||||
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.websocket import WebSocketConnection
|
||||
from sanic.server.websockets.connection import WebSocketConnection
|
||||
|
||||
|
||||
ASGIScope = MutableMapping[str, Any]
|
||||
|
||||
@@ -24,6 +24,7 @@ class FutureRoute(NamedTuple):
|
||||
unquote: bool
|
||||
static: bool
|
||||
version_prefix: str
|
||||
error_format: Optional[str]
|
||||
|
||||
|
||||
class FutureListener(NamedTuple):
|
||||
@@ -52,6 +53,7 @@ class FutureStatic(NamedTuple):
|
||||
host: Optional[str]
|
||||
strict_slashes: Optional[bool]
|
||||
content_type: Optional[bool]
|
||||
resource_type: Optional[str]
|
||||
|
||||
|
||||
class FutureSignal(NamedTuple):
|
||||
|
||||
@@ -21,5 +21,5 @@ MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType]
|
||||
ListenerType = Callable[
|
||||
[Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]]
|
||||
]
|
||||
RouteHandler = Callable[..., Coroutine[Any, Any, HTTPResponse]]
|
||||
RouteHandler = Callable[..., Coroutine[Any, Any, Optional[HTTPResponse]]]
|
||||
SignalHandler = Callable[..., Coroutine[Any, Any, None]]
|
||||
|
||||
52
sanic/models/server_types.py
Normal file
52
sanic/models/server_types.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sanic.models.protocol_types import TransportProtocol
|
||||
|
||||
|
||||
class Signal:
|
||||
stopped = False
|
||||
|
||||
|
||||
class ConnInfo:
|
||||
"""
|
||||
Local and remote addresses and SSL status info.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"client_port",
|
||||
"client",
|
||||
"client_ip",
|
||||
"ctx",
|
||||
"peername",
|
||||
"server_port",
|
||||
"server",
|
||||
"sockname",
|
||||
"ssl",
|
||||
)
|
||||
|
||||
def __init__(self, transport: TransportProtocol, unix=None):
|
||||
self.ctx = SimpleNamespace()
|
||||
self.peername = None
|
||||
self.server = self.client = ""
|
||||
self.server_port = self.client_port = 0
|
||||
self.client_ip = ""
|
||||
self.sockname = addr = transport.get_extra_info("sockname")
|
||||
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
if isinstance(addr, str): # UNIX socket
|
||||
self.server = unix or addr
|
||||
return
|
||||
|
||||
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
|
||||
if isinstance(addr, tuple):
|
||||
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
|
||||
self.server_port = addr[1]
|
||||
# self.server gets non-standard port appended
|
||||
if addr[1] != (443 if self.ssl else 80):
|
||||
self.server = f"{self.server}:{addr[1]}"
|
||||
self.peername = addr = transport.get_extra_info("peername")
|
||||
|
||||
if isinstance(addr, tuple):
|
||||
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
|
||||
self.client_ip = addr[0]
|
||||
self.client_port = addr[1]
|
||||
@@ -34,7 +34,9 @@ from sanic.compat import CancelledErrors, Header
|
||||
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.headers import (
|
||||
AcceptContainer,
|
||||
Options,
|
||||
parse_accept,
|
||||
parse_content_header,
|
||||
parse_forwarded,
|
||||
parse_host,
|
||||
@@ -94,6 +96,7 @@ class Request:
|
||||
"head",
|
||||
"headers",
|
||||
"method",
|
||||
"parsed_accept",
|
||||
"parsed_args",
|
||||
"parsed_not_grouped_args",
|
||||
"parsed_files",
|
||||
@@ -136,6 +139,7 @@ class Request:
|
||||
self.conn_info: Optional[ConnInfo] = None
|
||||
self.ctx = SimpleNamespace()
|
||||
self.parsed_forwarded: Optional[Options] = None
|
||||
self.parsed_accept: Optional[AcceptContainer] = None
|
||||
self.parsed_json = None
|
||||
self.parsed_form = None
|
||||
self.parsed_files = None
|
||||
@@ -296,6 +300,13 @@ class Request:
|
||||
|
||||
return self.parsed_json
|
||||
|
||||
@property
|
||||
def accept(self) -> AcceptContainer:
|
||||
if self.parsed_accept is None:
|
||||
accept_header = self.headers.getone("accept", "")
|
||||
self.parsed_accept = parse_accept(accept_header)
|
||||
return self.parsed_accept
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
"""Attempt to return the auth header token.
|
||||
@@ -497,6 +508,10 @@ class Request:
|
||||
"""
|
||||
return self._match_info
|
||||
|
||||
@match_info.setter
|
||||
def match_info(self, value):
|
||||
self._match_info = value
|
||||
|
||||
# Transport properties (obtained from local interface only)
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from sanic_routing import BaseRouter # type: ignore
|
||||
from sanic_routing.exceptions import NoMethod # type: ignore
|
||||
@@ -9,6 +13,7 @@ from sanic_routing.exceptions import (
|
||||
from sanic_routing.route import Route # type: ignore
|
||||
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.errorpages import check_error_format
|
||||
from sanic.exceptions import MethodNotSupported, NotFound, SanicException
|
||||
from sanic.models.handler_types import RouteHandler
|
||||
|
||||
@@ -74,6 +79,7 @@ class Router(BaseRouter):
|
||||
unquote: bool = False,
|
||||
static: bool = False,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> Union[Route, List[Route]]:
|
||||
"""
|
||||
Add a handler to the router
|
||||
@@ -106,6 +112,8 @@ class Router(BaseRouter):
|
||||
version = str(version).strip("/").lstrip("v")
|
||||
uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")])
|
||||
|
||||
uri = self._normalize(uri, handler)
|
||||
|
||||
params = dict(
|
||||
path=uri,
|
||||
handler=handler,
|
||||
@@ -131,6 +139,11 @@ class Router(BaseRouter):
|
||||
route.ctx.stream = stream
|
||||
route.ctx.hosts = hosts
|
||||
route.ctx.static = static
|
||||
route.ctx.error_format = (
|
||||
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
|
||||
)
|
||||
|
||||
check_error_format(route.ctx.error_format)
|
||||
|
||||
routes.append(route)
|
||||
|
||||
@@ -187,3 +200,24 @@ class Router(BaseRouter):
|
||||
raise SanicException(
|
||||
f"Invalid route: {route}. Parameter names cannot use '__'."
|
||||
)
|
||||
|
||||
def _normalize(self, uri: str, handler: RouteHandler) -> str:
|
||||
if "<" not in uri:
|
||||
return uri
|
||||
|
||||
sig = signature(handler)
|
||||
mapping = {
|
||||
param.name: param.annotation.__name__.lower()
|
||||
for param in sig.parameters.values()
|
||||
if param.annotation in (str, int, float, UUID)
|
||||
}
|
||||
|
||||
reconstruction = []
|
||||
for part in uri.split("/"):
|
||||
if part.startswith("<") and ":" not in part:
|
||||
name = part[1:-1]
|
||||
annotation = mapping.get(name)
|
||||
if annotation:
|
||||
part = f"<{name}:{annotation}>"
|
||||
reconstruction.append(part)
|
||||
return "/".join(reconstruction)
|
||||
|
||||
793
sanic/server.py
793
sanic/server.py
@@ -1,793 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ssl import SSLContext
|
||||
from types import SimpleNamespace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from sanic.models.handler_types import ListenerType
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.app import Sanic
|
||||
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import stat
|
||||
|
||||
from asyncio import CancelledError
|
||||
from asyncio.transports import Transport
|
||||
from functools import partial
|
||||
from inspect import isawaitable
|
||||
from ipaddress import ip_address
|
||||
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
||||
from signal import signal as signal_func
|
||||
from time import monotonic as current_time
|
||||
|
||||
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
|
||||
from sanic.config import Config
|
||||
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
||||
from sanic.http import Http, Stage
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.models.protocol_types import TransportProtocol
|
||||
from sanic.request import Request
|
||||
|
||||
|
||||
try:
|
||||
import uvloop # type: ignore
|
||||
|
||||
if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy):
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class Signal:
|
||||
stopped = False
|
||||
|
||||
|
||||
class ConnInfo:
|
||||
"""
|
||||
Local and remote addresses and SSL status info.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"client_port",
|
||||
"client",
|
||||
"client_ip",
|
||||
"ctx",
|
||||
"peername",
|
||||
"server_port",
|
||||
"server",
|
||||
"sockname",
|
||||
"ssl",
|
||||
)
|
||||
|
||||
def __init__(self, transport: TransportProtocol, unix=None):
|
||||
self.ctx = SimpleNamespace()
|
||||
self.peername = None
|
||||
self.server = self.client = ""
|
||||
self.server_port = self.client_port = 0
|
||||
self.client_ip = ""
|
||||
self.sockname = addr = transport.get_extra_info("sockname")
|
||||
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
if isinstance(addr, str): # UNIX socket
|
||||
self.server = unix or addr
|
||||
return
|
||||
|
||||
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
|
||||
if isinstance(addr, tuple):
|
||||
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
|
||||
self.server_port = addr[1]
|
||||
# self.server gets non-standard port appended
|
||||
if addr[1] != (443 if self.ssl else 80):
|
||||
self.server = f"{self.server}:{addr[1]}"
|
||||
self.peername = addr = transport.get_extra_info("peername")
|
||||
|
||||
if isinstance(addr, tuple):
|
||||
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
|
||||
self.client_ip = addr[0]
|
||||
self.client_port = addr[1]
|
||||
|
||||
|
||||
class HttpProtocol(asyncio.Protocol):
|
||||
"""
|
||||
This class provides a basic HTTP implementation of the sanic framework.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
# app
|
||||
"app",
|
||||
# event loop, connection
|
||||
"loop",
|
||||
"transport",
|
||||
"connections",
|
||||
"signal",
|
||||
"conn_info",
|
||||
"ctx",
|
||||
# request params
|
||||
"request",
|
||||
# request config
|
||||
"request_handler",
|
||||
"request_timeout",
|
||||
"response_timeout",
|
||||
"keep_alive_timeout",
|
||||
"request_max_size",
|
||||
"request_class",
|
||||
"error_handler",
|
||||
# enable or disable access log purpose
|
||||
"access_log",
|
||||
# connection management
|
||||
"state",
|
||||
"url",
|
||||
"_handler_task",
|
||||
"_can_write",
|
||||
"_data_received",
|
||||
"_time",
|
||||
"_task",
|
||||
"_http",
|
||||
"_exception",
|
||||
"recv_buffer",
|
||||
"_unix",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop,
|
||||
app: Sanic,
|
||||
signal=None,
|
||||
connections=None,
|
||||
state=None,
|
||||
unix=None,
|
||||
**kwargs,
|
||||
):
|
||||
asyncio.set_event_loop(loop)
|
||||
self.loop = loop
|
||||
self.app: Sanic = app
|
||||
self.url = None
|
||||
self.transport: Optional[Transport] = None
|
||||
self.conn_info: Optional[ConnInfo] = None
|
||||
self.request: Optional[Request] = None
|
||||
self.signal = signal or Signal()
|
||||
self.access_log = self.app.config.ACCESS_LOG
|
||||
self.connections = connections if connections is not None else set()
|
||||
self.request_handler = self.app.handle_request
|
||||
self.error_handler = self.app.error_handler
|
||||
self.request_timeout = self.app.config.REQUEST_TIMEOUT
|
||||
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
|
||||
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
|
||||
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
|
||||
self.request_class = self.app.request_class or Request
|
||||
self.state = state if state else {}
|
||||
if "requests_count" not in self.state:
|
||||
self.state["requests_count"] = 0
|
||||
self._data_received = asyncio.Event()
|
||||
self._can_write = asyncio.Event()
|
||||
self._can_write.set()
|
||||
self._exception = None
|
||||
self._unix = unix
|
||||
|
||||
def _setup_connection(self):
|
||||
self._http = Http(self)
|
||||
self._time = current_time()
|
||||
self.check_timeouts()
|
||||
|
||||
async def connection_task(self):
|
||||
"""
|
||||
Run a HTTP connection.
|
||||
|
||||
Timeouts and some additional error handling occur here, while most of
|
||||
everything else happens in class Http or in code called from there.
|
||||
"""
|
||||
try:
|
||||
self._setup_connection()
|
||||
await self._http.http1()
|
||||
except CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
error_logger.exception("protocol.connection_task uncaught")
|
||||
finally:
|
||||
if self.app.debug and self._http:
|
||||
ip = self.transport.get_extra_info("peername")
|
||||
error_logger.error(
|
||||
"Connection lost before response written"
|
||||
f" @ {ip} {self._http.request}"
|
||||
)
|
||||
self._http = None
|
||||
self._task = None
|
||||
try:
|
||||
self.close()
|
||||
except BaseException:
|
||||
error_logger.exception("Closing failed")
|
||||
|
||||
async def receive_more(self):
|
||||
"""
|
||||
Wait until more data is received into the Server protocol's buffer
|
||||
"""
|
||||
self.transport.resume_reading()
|
||||
self._data_received.clear()
|
||||
await self._data_received.wait()
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Runs itself periodically to enforce any expired timeouts.
|
||||
"""
|
||||
try:
|
||||
if not self._task:
|
||||
return
|
||||
duration = current_time() - self._time
|
||||
stage = self._http.stage
|
||||
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
|
||||
logger.debug("KeepAlive Timeout. Closing connection.")
|
||||
elif stage is Stage.REQUEST and duration > self.request_timeout:
|
||||
logger.debug("Request Timeout. Closing connection.")
|
||||
self._http.exception = RequestTimeout("Request Timeout")
|
||||
elif stage is Stage.HANDLER and self._http.upgrade_websocket:
|
||||
logger.debug("Handling websocket. Timeouts disabled.")
|
||||
return
|
||||
elif (
|
||||
stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
|
||||
and duration > self.response_timeout
|
||||
):
|
||||
logger.debug("Response Timeout. Closing connection.")
|
||||
self._http.exception = ServiceUnavailable("Response Timeout")
|
||||
else:
|
||||
interval = (
|
||||
min(
|
||||
self.keep_alive_timeout,
|
||||
self.request_timeout,
|
||||
self.response_timeout,
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
self.loop.call_later(max(0.1, interval), self.check_timeouts)
|
||||
return
|
||||
self._task.cancel()
|
||||
except Exception:
|
||||
error_logger.exception("protocol.check_timeouts")
|
||||
|
||||
async def send(self, data):
|
||||
"""
|
||||
Writes data with backpressure control.
|
||||
"""
|
||||
await self._can_write.wait()
|
||||
if self.transport.is_closing():
|
||||
raise CancelledError
|
||||
self.transport.write(data)
|
||||
self._time = current_time()
|
||||
|
||||
def close_if_idle(self) -> bool:
|
||||
"""
|
||||
Close the connection if a request is not being sent or received
|
||||
|
||||
:return: boolean - True if closed, false if staying open
|
||||
"""
|
||||
if self._http is None or self._http.stage is Stage.IDLE:
|
||||
self.close()
|
||||
return True
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Force close the connection.
|
||||
"""
|
||||
# Cause a call to connection_lost where further cleanup occurs
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Only asyncio.Protocol callbacks below this
|
||||
# -------------------------------------------- #
|
||||
|
||||
def connection_made(self, transport):
|
||||
try:
|
||||
# TODO: Benchmark to find suitable write buffer limits
|
||||
transport.set_write_buffer_limits(low=16384, high=65536)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self._task = self.loop.create_task(self.connection_task())
|
||||
self.recv_buffer = bytearray()
|
||||
self.conn_info = ConnInfo(self.transport, unix=self._unix)
|
||||
except Exception:
|
||||
error_logger.exception("protocol.connect_made")
|
||||
|
||||
def connection_lost(self, exc):
|
||||
try:
|
||||
self.connections.discard(self)
|
||||
self.resume_writing()
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
except Exception:
|
||||
error_logger.exception("protocol.connection_lost")
|
||||
|
||||
def pause_writing(self):
|
||||
self._can_write.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self._can_write.set()
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
try:
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
self.recv_buffer += data
|
||||
|
||||
if (
|
||||
len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE
|
||||
and self.transport
|
||||
):
|
||||
self.transport.pause_reading()
|
||||
|
||||
if self._data_received:
|
||||
self._data_received.set()
|
||||
except Exception:
|
||||
error_logger.exception("protocol.data_received")
|
||||
|
||||
|
||||
def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
|
||||
"""
|
||||
Trigger event callbacks (functions or async)
|
||||
|
||||
:param events: one or more sync or async functions to execute
|
||||
:param loop: event loop
|
||||
"""
|
||||
if events:
|
||||
for event in events:
|
||||
result = event(loop)
|
||||
if isawaitable(result):
|
||||
loop.run_until_complete(result)
|
||||
|
||||
|
||||
class AsyncioServer:
|
||||
"""
|
||||
Wraps an asyncio server with functionality that might be useful to
|
||||
a user who needs to manage the server lifecycle manually.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"loop",
|
||||
"serve_coro",
|
||||
"_after_start",
|
||||
"_before_stop",
|
||||
"_after_stop",
|
||||
"server",
|
||||
"connections",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop,
|
||||
serve_coro,
|
||||
connections,
|
||||
after_start: Optional[Iterable[ListenerType]],
|
||||
before_stop: Optional[Iterable[ListenerType]],
|
||||
after_stop: Optional[Iterable[ListenerType]],
|
||||
):
|
||||
# Note, Sanic already called "before_server_start" events
|
||||
# before this helper was even created. So we don't need it here.
|
||||
self.loop = loop
|
||||
self.serve_coro = serve_coro
|
||||
self._after_start = after_start
|
||||
self._before_stop = before_stop
|
||||
self._after_stop = after_stop
|
||||
self.server = None
|
||||
self.connections = connections
|
||||
|
||||
def after_start(self):
|
||||
"""
|
||||
Trigger "after_server_start" events
|
||||
"""
|
||||
trigger_events(self._after_start, self.loop)
|
||||
|
||||
def before_stop(self):
|
||||
"""
|
||||
Trigger "before_server_stop" events
|
||||
"""
|
||||
trigger_events(self._before_stop, self.loop)
|
||||
|
||||
def after_stop(self):
|
||||
"""
|
||||
Trigger "after_server_stop" events
|
||||
"""
|
||||
trigger_events(self._after_stop, self.loop)
|
||||
|
||||
def is_serving(self) -> bool:
|
||||
if self.server:
|
||||
return self.server.is_serving()
|
||||
return False
|
||||
|
||||
def wait_closed(self):
|
||||
if self.server:
|
||||
return self.server.wait_closed()
|
||||
|
||||
def close(self):
|
||||
if self.server:
|
||||
self.server.close()
|
||||
coro = self.wait_closed()
|
||||
task = asyncio.ensure_future(coro, loop=self.loop)
|
||||
return task
|
||||
|
||||
def start_serving(self):
|
||||
if self.server:
|
||||
try:
|
||||
return self.server.start_serving()
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
"server.start_serving not available in this version "
|
||||
"of asyncio or uvloop."
|
||||
)
|
||||
|
||||
def serve_forever(self):
|
||||
if self.server:
|
||||
try:
|
||||
return self.server.serve_forever()
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
"server.serve_forever not available in this version "
|
||||
"of asyncio or uvloop."
|
||||
)
|
||||
|
||||
def __await__(self):
|
||||
"""
|
||||
Starts the asyncio server, returns AsyncServerCoro
|
||||
"""
|
||||
task = asyncio.ensure_future(self.serve_coro)
|
||||
while not task.done():
|
||||
yield
|
||||
self.server = task.result()
|
||||
return self
|
||||
|
||||
|
||||
def serve(
|
||||
host,
|
||||
port,
|
||||
app,
|
||||
before_start: Optional[Iterable[ListenerType]] = None,
|
||||
after_start: Optional[Iterable[ListenerType]] = None,
|
||||
before_stop: Optional[Iterable[ListenerType]] = None,
|
||||
after_stop: Optional[Iterable[ListenerType]] = None,
|
||||
ssl: Optional[SSLContext] = None,
|
||||
sock: Optional[socket.socket] = None,
|
||||
unix: Optional[str] = None,
|
||||
reuse_port: bool = False,
|
||||
loop=None,
|
||||
protocol: Type[asyncio.Protocol] = HttpProtocol,
|
||||
backlog: int = 100,
|
||||
register_sys_signals: bool = True,
|
||||
run_multiple: bool = False,
|
||||
run_async: bool = False,
|
||||
connections=None,
|
||||
signal=Signal(),
|
||||
state=None,
|
||||
asyncio_server_kwargs=None,
|
||||
):
|
||||
"""Start asynchronous HTTP Server on an individual process.
|
||||
|
||||
:param host: Address to host on
|
||||
:param port: Port to host on
|
||||
:param before_start: function to be executed before the server starts
|
||||
listening. Takes arguments `app` instance and `loop`
|
||||
:param after_start: function to be executed after the server starts
|
||||
listening. Takes arguments `app` instance and `loop`
|
||||
:param before_stop: function to be executed when a stop signal is
|
||||
received before it is respected. Takes arguments
|
||||
`app` instance and `loop`
|
||||
:param after_stop: function to be executed when a stop signal is
|
||||
received after it is respected. Takes arguments
|
||||
`app` instance and `loop`
|
||||
:param ssl: SSLContext
|
||||
:param sock: Socket for the server to accept connections from
|
||||
:param unix: Unix socket to listen on instead of TCP port
|
||||
:param reuse_port: `True` for multiple workers
|
||||
:param loop: asyncio compatible event loop
|
||||
:param run_async: bool: Do not create a new event loop for the server,
|
||||
and return an AsyncServer object rather than running it
|
||||
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
|
||||
create_server method
|
||||
:return: Nothing
|
||||
"""
|
||||
if not run_async and not loop:
|
||||
# create new event_loop after fork
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if app.debug:
|
||||
loop.set_debug(app.debug)
|
||||
|
||||
app.asgi = False
|
||||
|
||||
connections = connections if connections is not None else set()
|
||||
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
|
||||
server = partial(
|
||||
protocol,
|
||||
loop=loop,
|
||||
connections=connections,
|
||||
signal=signal,
|
||||
app=app,
|
||||
state=state,
|
||||
unix=unix,
|
||||
**protocol_kwargs,
|
||||
)
|
||||
asyncio_server_kwargs = (
|
||||
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
||||
)
|
||||
# UNIX sockets are always bound by us (to preserve semantics between modes)
|
||||
if unix:
|
||||
sock = bind_unix_socket(unix, backlog=backlog)
|
||||
server_coroutine = loop.create_server(
|
||||
server,
|
||||
None if sock else host,
|
||||
None if sock else port,
|
||||
ssl=ssl,
|
||||
reuse_port=reuse_port,
|
||||
sock=sock,
|
||||
backlog=backlog,
|
||||
**asyncio_server_kwargs,
|
||||
)
|
||||
|
||||
if run_async:
|
||||
return AsyncioServer(
|
||||
loop=loop,
|
||||
serve_coro=server_coroutine,
|
||||
connections=connections,
|
||||
after_start=after_start,
|
||||
before_stop=before_stop,
|
||||
after_stop=after_stop,
|
||||
)
|
||||
|
||||
trigger_events(before_start, loop)
|
||||
|
||||
try:
|
||||
http_server = loop.run_until_complete(server_coroutine)
|
||||
except BaseException:
|
||||
error_logger.exception("Unable to start server")
|
||||
return
|
||||
|
||||
trigger_events(after_start, loop)
|
||||
|
||||
# Ignore SIGINT when run_multiple
|
||||
if run_multiple:
|
||||
signal_func(SIGINT, SIG_IGN)
|
||||
|
||||
# Register signals for graceful termination
|
||||
if register_sys_signals:
|
||||
if OS_IS_WINDOWS:
|
||||
ctrlc_workaround_for_windows(app)
|
||||
else:
|
||||
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
|
||||
loop.add_signal_handler(_signal, app.stop)
|
||||
pid = os.getpid()
|
||||
try:
|
||||
logger.info("Starting worker [%s]", pid)
|
||||
loop.run_forever()
|
||||
finally:
|
||||
logger.info("Stopping worker [%s]", pid)
|
||||
|
||||
# Run the on_stop function if provided
|
||||
trigger_events(before_stop, loop)
|
||||
|
||||
# Wait for event loop to finish and all connections to drain
|
||||
http_server.close()
|
||||
loop.run_until_complete(http_server.wait_closed())
|
||||
|
||||
# Complete all tasks on the loop
|
||||
signal.stopped = True
|
||||
for connection in connections:
|
||||
connection.close_if_idle()
|
||||
|
||||
# Gracefully shutdown timeout.
|
||||
# We should provide graceful_shutdown_timeout,
|
||||
# instead of letting connection hangs forever.
|
||||
# Let's roughly calcucate time.
|
||||
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
start_shutdown: float = 0
|
||||
while connections and (start_shutdown < graceful):
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
start_shutdown = start_shutdown + 0.1
|
||||
|
||||
# Force close non-idle connection after waiting for
|
||||
# graceful_shutdown_timeout
|
||||
coros = []
|
||||
for conn in connections:
|
||||
if hasattr(conn, "websocket") and conn.websocket:
|
||||
coros.append(conn.websocket.close_connection())
|
||||
else:
|
||||
conn.close()
|
||||
|
||||
_shutdown = asyncio.gather(*coros)
|
||||
loop.run_until_complete(_shutdown)
|
||||
|
||||
trigger_events(after_stop, loop)
|
||||
|
||||
remove_unix_socket(unix)
|
||||
|
||||
|
||||
def _build_protocol_kwargs(
|
||||
protocol: Type[asyncio.Protocol], config: Config
|
||||
) -> Dict[str, Union[int, float]]:
|
||||
if hasattr(protocol, "websocket_handshake"):
|
||||
return {
|
||||
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
|
||||
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
|
||||
"websocket_read_limit": config.WEBSOCKET_READ_LIMIT,
|
||||
"websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT,
|
||||
"websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT,
|
||||
"websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL,
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
|
||||
"""Create TCP server socket.
|
||||
:param host: IPv4, IPv6 or hostname may be specified
|
||||
:param port: TCP port number
|
||||
:param backlog: Maximum number of connections to queue
|
||||
:return: socket.socket object
|
||||
"""
|
||||
try: # IP address: family must be specified for IPv6 at least
|
||||
ip = ip_address(host)
|
||||
host = str(ip)
|
||||
sock = socket.socket(
|
||||
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
|
||||
)
|
||||
except ValueError: # Hostname, may become AF_INET or AF_INET6
|
||||
sock = socket.socket()
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind((host, port))
|
||||
sock.listen(backlog)
|
||||
return sock
|
||||
|
||||
|
||||
def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
|
||||
"""Create unix socket.
|
||||
:param path: filesystem path
|
||||
:param backlog: Maximum number of connections to queue
|
||||
:return: socket.socket object
|
||||
"""
|
||||
"""Open or atomically replace existing socket with zero downtime."""
|
||||
# Sanitise and pre-verify socket path
|
||||
path = os.path.abspath(path)
|
||||
folder = os.path.dirname(path)
|
||||
if not os.path.isdir(folder):
|
||||
raise FileNotFoundError(f"Socket folder does not exist: {folder}")
|
||||
try:
|
||||
if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
|
||||
raise FileExistsError(f"Existing file is not a socket: {path}")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
# Create new socket with a random temporary name
|
||||
tmp_path = f"{path}.{secrets.token_urlsafe()}"
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
try:
|
||||
# Critical section begins (filename races)
|
||||
sock.bind(tmp_path)
|
||||
try:
|
||||
os.chmod(tmp_path, mode)
|
||||
# Start listening before rename to avoid connection failures
|
||||
sock.listen(backlog)
|
||||
os.rename(tmp_path, path)
|
||||
except: # noqa: E722
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
finally:
|
||||
raise
|
||||
except: # noqa: E722
|
||||
try:
|
||||
sock.close()
|
||||
finally:
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
def remove_unix_socket(path: Optional[str]) -> None:
|
||||
"""Remove dead unix socket during server exit."""
|
||||
if not path:
|
||||
return
|
||||
try:
|
||||
if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
|
||||
# Is it actually dead (doesn't belong to a new server instance)?
|
||||
with socket.socket(socket.AF_UNIX) as testsock:
|
||||
try:
|
||||
testsock.connect(path)
|
||||
except ConnectionRefusedError:
|
||||
os.unlink(path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def serve_single(server_settings):
|
||||
main_start = server_settings.pop("main_start", None)
|
||||
main_stop = server_settings.pop("main_stop", None)
|
||||
|
||||
if not server_settings.get("run_async"):
|
||||
# create new event_loop after fork
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_settings["loop"] = loop
|
||||
|
||||
trigger_events(main_start, server_settings["loop"])
|
||||
serve(**server_settings)
|
||||
trigger_events(main_stop, server_settings["loop"])
|
||||
|
||||
server_settings["loop"].close()
|
||||
|
||||
|
||||
def serve_multiple(server_settings, workers):
|
||||
"""Start multiple server processes simultaneously. Stop on interrupt
|
||||
and terminate signals, and drain connections when complete.
|
||||
|
||||
:param server_settings: kw arguments to be passed to the serve function
|
||||
:param workers: number of workers to launch
|
||||
:param stop_event: if provided, is used as a stop signal
|
||||
:return:
|
||||
"""
|
||||
server_settings["reuse_port"] = True
|
||||
server_settings["run_multiple"] = True
|
||||
|
||||
main_start = server_settings.pop("main_start", None)
|
||||
main_stop = server_settings.pop("main_stop", None)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
trigger_events(main_start, loop)
|
||||
|
||||
# Create a listening socket or use the one in settings
|
||||
sock = server_settings.get("sock")
|
||||
unix = server_settings["unix"]
|
||||
backlog = server_settings["backlog"]
|
||||
if unix:
|
||||
sock = bind_unix_socket(unix, backlog=backlog)
|
||||
server_settings["unix"] = unix
|
||||
if sock is None:
|
||||
sock = bind_socket(
|
||||
server_settings["host"], server_settings["port"], backlog=backlog
|
||||
)
|
||||
sock.set_inheritable(True)
|
||||
server_settings["sock"] = sock
|
||||
server_settings["host"] = None
|
||||
server_settings["port"] = None
|
||||
|
||||
processes = []
|
||||
|
||||
def sig_handler(signal, frame):
|
||||
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
|
||||
for process in processes:
|
||||
os.kill(process.pid, SIGTERM)
|
||||
|
||||
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
|
||||
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
|
||||
mp = multiprocessing.get_context("fork")
|
||||
|
||||
for _ in range(workers):
|
||||
process = mp.Process(target=serve, kwargs=server_settings)
|
||||
process.daemon = True
|
||||
process.start()
|
||||
processes.append(process)
|
||||
|
||||
for process in processes:
|
||||
process.join()
|
||||
|
||||
# the above processes will block this until they're stopped
|
||||
for process in processes:
|
||||
process.terminate()
|
||||
|
||||
trigger_events(main_stop, loop)
|
||||
|
||||
sock.close()
|
||||
loop.close()
|
||||
remove_unix_socket(unix)
|
||||
26
sanic/server/__init__.py
Normal file
26
sanic/server/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import asyncio
|
||||
|
||||
from sanic.models.server_types import ConnInfo, Signal
|
||||
from sanic.server.async_server import AsyncioServer
|
||||
from sanic.server.protocols.http_protocol import HttpProtocol
|
||||
from sanic.server.runners import serve, serve_multiple, serve_single
|
||||
|
||||
|
||||
try:
|
||||
import uvloop # type: ignore
|
||||
|
||||
if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy):
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = (
|
||||
"AsyncioServer",
|
||||
"ConnInfo",
|
||||
"HttpProtocol",
|
||||
"Signal",
|
||||
"serve",
|
||||
"serve_multiple",
|
||||
"serve_single",
|
||||
)
|
||||
115
sanic/server/async_server.py
Normal file
115
sanic/server/async_server.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from sanic.exceptions import SanicException
|
||||
|
||||
|
||||
class AsyncioServer:
|
||||
"""
|
||||
Wraps an asyncio server with functionality that might be useful to
|
||||
a user who needs to manage the server lifecycle manually.
|
||||
"""
|
||||
|
||||
__slots__ = ("app", "connections", "loop", "serve_coro", "server", "init")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
loop,
|
||||
serve_coro,
|
||||
connections,
|
||||
):
|
||||
# Note, Sanic already called "before_server_start" events
|
||||
# before this helper was even created. So we don't need it here.
|
||||
self.app = app
|
||||
self.connections = connections
|
||||
self.loop = loop
|
||||
self.serve_coro = serve_coro
|
||||
self.server = None
|
||||
self.init = False
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
Trigger "before_server_start" events
|
||||
"""
|
||||
self.init = True
|
||||
return self.app._startup()
|
||||
|
||||
def before_start(self):
|
||||
"""
|
||||
Trigger "before_server_start" events
|
||||
"""
|
||||
return self._server_event("init", "before")
|
||||
|
||||
def after_start(self):
|
||||
"""
|
||||
Trigger "after_server_start" events
|
||||
"""
|
||||
return self._server_event("init", "after")
|
||||
|
||||
def before_stop(self):
|
||||
"""
|
||||
Trigger "before_server_stop" events
|
||||
"""
|
||||
return self._server_event("shutdown", "before")
|
||||
|
||||
def after_stop(self):
|
||||
"""
|
||||
Trigger "after_server_stop" events
|
||||
"""
|
||||
return self._server_event("shutdown", "after")
|
||||
|
||||
def is_serving(self) -> bool:
|
||||
if self.server:
|
||||
return self.server.is_serving()
|
||||
return False
|
||||
|
||||
def wait_closed(self):
|
||||
if self.server:
|
||||
return self.server.wait_closed()
|
||||
|
||||
def close(self):
|
||||
if self.server:
|
||||
self.server.close()
|
||||
coro = self.wait_closed()
|
||||
task = asyncio.ensure_future(coro, loop=self.loop)
|
||||
return task
|
||||
|
||||
def start_serving(self):
|
||||
if self.server:
|
||||
try:
|
||||
return self.server.start_serving()
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
"server.start_serving not available in this version "
|
||||
"of asyncio or uvloop."
|
||||
)
|
||||
|
||||
def serve_forever(self):
|
||||
if self.server:
|
||||
try:
|
||||
return self.server.serve_forever()
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
"server.serve_forever not available in this version "
|
||||
"of asyncio or uvloop."
|
||||
)
|
||||
|
||||
def _server_event(self, concern: str, action: str):
|
||||
if not self.init:
|
||||
raise SanicException(
|
||||
"Cannot dispatch server event without "
|
||||
"first running server.startup()"
|
||||
)
|
||||
return self.app._server_event(concern, action, loop=self.loop)
|
||||
|
||||
def __await__(self):
|
||||
"""
|
||||
Starts the asyncio server, returns AsyncServerCoro
|
||||
"""
|
||||
task = asyncio.ensure_future(self.serve_coro)
|
||||
while not task.done():
|
||||
yield
|
||||
self.server = task.result()
|
||||
return self
|
||||
16
sanic/server/events.py
Normal file
16
sanic/server/events.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from inspect import isawaitable
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
|
||||
def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
|
||||
"""
|
||||
Trigger event callbacks (functions or async)
|
||||
|
||||
:param events: one or more sync or async functions to execute
|
||||
:param loop: event loop
|
||||
"""
|
||||
if events:
|
||||
for event in events:
|
||||
result = event(loop)
|
||||
if isawaitable(result):
|
||||
loop.run_until_complete(result)
|
||||
0
sanic/server/protocols/__init__.py
Normal file
0
sanic/server/protocols/__init__.py
Normal file
143
sanic/server/protocols/base_protocol.py
Normal file
143
sanic/server/protocols/base_protocol.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.app import Sanic
|
||||
|
||||
import asyncio
|
||||
|
||||
from asyncio import CancelledError
|
||||
from asyncio.transports import Transport
|
||||
from time import monotonic as current_time
|
||||
|
||||
from sanic.log import error_logger
|
||||
from sanic.models.server_types import ConnInfo, Signal
|
||||
|
||||
|
||||
class SanicProtocol(asyncio.Protocol):
|
||||
__slots__ = (
|
||||
"app",
|
||||
# event loop, connection
|
||||
"loop",
|
||||
"transport",
|
||||
"connections",
|
||||
"conn_info",
|
||||
"signal",
|
||||
"_can_write",
|
||||
"_time",
|
||||
"_task",
|
||||
"_unix",
|
||||
"_data_received",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop,
|
||||
app: Sanic,
|
||||
signal=None,
|
||||
connections=None,
|
||||
unix=None,
|
||||
**kwargs,
|
||||
):
|
||||
asyncio.set_event_loop(loop)
|
||||
self.loop = loop
|
||||
self.app: Sanic = app
|
||||
self.signal = signal or Signal()
|
||||
self.transport: Optional[Transport] = None
|
||||
self.connections = connections if connections is not None else set()
|
||||
self.conn_info: Optional[ConnInfo] = None
|
||||
self._can_write = asyncio.Event()
|
||||
self._can_write.set()
|
||||
self._unix = unix
|
||||
self._time = 0.0 # type: float
|
||||
self._task = None # type: Optional[asyncio.Task]
|
||||
self._data_received = asyncio.Event()
|
||||
|
||||
@property
|
||||
def ctx(self):
|
||||
if self.conn_info is not None:
|
||||
return self.conn_info.ctx
|
||||
else:
|
||||
return None
|
||||
|
||||
async def send(self, data):
|
||||
"""
|
||||
Generic data write implementation with backpressure control.
|
||||
"""
|
||||
await self._can_write.wait()
|
||||
if self.transport.is_closing():
|
||||
raise CancelledError
|
||||
self.transport.write(data)
|
||||
self._time = current_time()
|
||||
|
||||
async def receive_more(self):
|
||||
"""
|
||||
Wait until more data is received into the Server protocol's buffer
|
||||
"""
|
||||
self.transport.resume_reading()
|
||||
self._data_received.clear()
|
||||
await self._data_received.wait()
|
||||
|
||||
def close(self, timeout: Optional[float] = None):
|
||||
"""
|
||||
Attempt close the connection.
|
||||
"""
|
||||
# Cause a call to connection_lost where further cleanup occurs
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
if timeout is None:
|
||||
timeout = self.app.config.GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
self.loop.call_later(timeout, self.abort)
|
||||
|
||||
def abort(self):
|
||||
"""
|
||||
Force close the connection.
|
||||
"""
|
||||
# Cause a call to connection_lost where further cleanup occurs
|
||||
if self.transport:
|
||||
self.transport.abort()
|
||||
self.transport = None
|
||||
|
||||
# asyncio.Protocol API Callbacks #
|
||||
# ------------------------------ #
|
||||
def connection_made(self, transport):
|
||||
"""
|
||||
Generic connection-made, with no connection_task, and no recv_buffer.
|
||||
Override this for protocol-specific connection implementations.
|
||||
"""
|
||||
try:
|
||||
transport.set_write_buffer_limits(low=16384, high=65536)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.conn_info = ConnInfo(self.transport, unix=self._unix)
|
||||
except Exception:
|
||||
error_logger.exception("protocol.connect_made")
|
||||
|
||||
def connection_lost(self, exc):
|
||||
try:
|
||||
self.connections.discard(self)
|
||||
self.resume_writing()
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
except BaseException:
|
||||
error_logger.exception("protocol.connection_lost")
|
||||
|
||||
def pause_writing(self):
|
||||
self._can_write.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
self._can_write.set()
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
try:
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
|
||||
if self._data_received:
|
||||
self._data_received.set()
|
||||
except BaseException:
|
||||
error_logger.exception("protocol.data_received")
|
||||
238
sanic/server/protocols/http_protocol.py
Normal file
238
sanic/server/protocols/http_protocol.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sanic.touchup.meta import TouchUpMeta
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.app import Sanic
|
||||
|
||||
from asyncio import CancelledError
|
||||
from time import monotonic as current_time
|
||||
|
||||
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
||||
from sanic.http import Http, Stage
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.models.server_types import ConnInfo
|
||||
from sanic.request import Request
|
||||
from sanic.server.protocols.base_protocol import SanicProtocol
|
||||
|
||||
|
||||
class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
|
||||
"""
|
||||
This class provides implements the HTTP 1.1 protocol on top of our
|
||||
Sanic Server transport
|
||||
"""
|
||||
|
||||
__touchup__ = (
|
||||
"send",
|
||||
"connection_task",
|
||||
)
|
||||
__slots__ = (
|
||||
# request params
|
||||
"request",
|
||||
# request config
|
||||
"request_handler",
|
||||
"request_timeout",
|
||||
"response_timeout",
|
||||
"keep_alive_timeout",
|
||||
"request_max_size",
|
||||
"request_class",
|
||||
"error_handler",
|
||||
# enable or disable access log purpose
|
||||
"access_log",
|
||||
# connection management
|
||||
"state",
|
||||
"url",
|
||||
"_handler_task",
|
||||
"_http",
|
||||
"_exception",
|
||||
"recv_buffer",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop,
|
||||
app: Sanic,
|
||||
signal=None,
|
||||
connections=None,
|
||||
state=None,
|
||||
unix=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
loop=loop,
|
||||
app=app,
|
||||
signal=signal,
|
||||
connections=connections,
|
||||
unix=unix,
|
||||
)
|
||||
self.url = None
|
||||
self.request: Optional[Request] = None
|
||||
self.access_log = self.app.config.ACCESS_LOG
|
||||
self.request_handler = self.app.handle_request
|
||||
self.error_handler = self.app.error_handler
|
||||
self.request_timeout = self.app.config.REQUEST_TIMEOUT
|
||||
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
|
||||
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
|
||||
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
|
||||
self.request_class = self.app.request_class or Request
|
||||
self.state = state if state else {}
|
||||
if "requests_count" not in self.state:
|
||||
self.state["requests_count"] = 0
|
||||
self._exception = None
|
||||
|
||||
def _setup_connection(self):
|
||||
self._http = Http(self)
|
||||
self._time = current_time()
|
||||
self.check_timeouts()
|
||||
|
||||
async def connection_task(self): # no cov
|
||||
"""
|
||||
Run a HTTP connection.
|
||||
|
||||
Timeouts and some additional error handling occur here, while most of
|
||||
everything else happens in class Http or in code called from there.
|
||||
"""
|
||||
try:
|
||||
self._setup_connection()
|
||||
await self.app.dispatch(
|
||||
"http.lifecycle.begin",
|
||||
inline=True,
|
||||
context={"conn_info": self.conn_info},
|
||||
)
|
||||
await self._http.http1()
|
||||
except CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
error_logger.exception("protocol.connection_task uncaught")
|
||||
finally:
|
||||
if (
|
||||
self.app.debug
|
||||
and self._http
|
||||
and self.transport
|
||||
and not self._http.upgrade_websocket
|
||||
):
|
||||
ip = self.transport.get_extra_info("peername")
|
||||
error_logger.error(
|
||||
"Connection lost before response written"
|
||||
f" @ {ip} {self._http.request}"
|
||||
)
|
||||
self._http = None
|
||||
self._task = None
|
||||
try:
|
||||
self.close()
|
||||
except BaseException:
|
||||
error_logger.exception("Closing failed")
|
||||
finally:
|
||||
await self.app.dispatch(
|
||||
"http.lifecycle.complete",
|
||||
inline=True,
|
||||
context={"conn_info": self.conn_info},
|
||||
)
|
||||
# Important to keep this Ellipsis here for the TouchUp module
|
||||
...
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Runs itself periodically to enforce any expired timeouts.
|
||||
"""
|
||||
try:
|
||||
if not self._task:
|
||||
return
|
||||
duration = current_time() - self._time
|
||||
stage = self._http.stage
|
||||
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
|
||||
logger.debug("KeepAlive Timeout. Closing connection.")
|
||||
elif stage is Stage.REQUEST and duration > self.request_timeout:
|
||||
logger.debug("Request Timeout. Closing connection.")
|
||||
self._http.exception = RequestTimeout("Request Timeout")
|
||||
elif stage is Stage.HANDLER and self._http.upgrade_websocket:
|
||||
logger.debug("Handling websocket. Timeouts disabled.")
|
||||
return
|
||||
elif (
|
||||
stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
|
||||
and duration > self.response_timeout
|
||||
):
|
||||
logger.debug("Response Timeout. Closing connection.")
|
||||
self._http.exception = ServiceUnavailable("Response Timeout")
|
||||
else:
|
||||
interval = (
|
||||
min(
|
||||
self.keep_alive_timeout,
|
||||
self.request_timeout,
|
||||
self.response_timeout,
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
self.loop.call_later(max(0.1, interval), self.check_timeouts)
|
||||
return
|
||||
self._task.cancel()
|
||||
except Exception:
|
||||
error_logger.exception("protocol.check_timeouts")
|
||||
|
||||
async def send(self, data): # no cov
|
||||
"""
|
||||
Writes HTTP data with backpressure control.
|
||||
"""
|
||||
await self._can_write.wait()
|
||||
if self.transport.is_closing():
|
||||
raise CancelledError
|
||||
await self.app.dispatch(
|
||||
"http.lifecycle.send",
|
||||
inline=True,
|
||||
context={"data": data},
|
||||
)
|
||||
self.transport.write(data)
|
||||
self._time = current_time()
|
||||
|
||||
def close_if_idle(self) -> bool:
|
||||
"""
|
||||
Close the connection if a request is not being sent or received
|
||||
|
||||
:return: boolean - True if closed, false if staying open
|
||||
"""
|
||||
if self._http is None or self._http.stage is Stage.IDLE:
|
||||
self.close()
|
||||
return True
|
||||
return False
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Only asyncio.Protocol callbacks below this
|
||||
# -------------------------------------------- #
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""
|
||||
HTTP-protocol-specific new connection handler
|
||||
"""
|
||||
try:
|
||||
# TODO: Benchmark to find suitable write buffer limits
|
||||
transport.set_write_buffer_limits(low=16384, high=65536)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self._task = self.loop.create_task(self.connection_task())
|
||||
self.recv_buffer = bytearray()
|
||||
self.conn_info = ConnInfo(self.transport, unix=self._unix)
|
||||
except Exception:
|
||||
error_logger.exception("protocol.connect_made")
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
|
||||
try:
|
||||
self._time = current_time()
|
||||
if not data:
|
||||
return self.close()
|
||||
self.recv_buffer += data
|
||||
|
||||
if (
|
||||
len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE
|
||||
and self.transport
|
||||
):
|
||||
self.transport.pause_reading()
|
||||
|
||||
if self._data_received:
|
||||
self._data_received.set()
|
||||
except Exception:
|
||||
error_logger.exception("protocol.data_received")
|
||||
164
sanic/server/protocols/websocket_protocol.py
Normal file
164
sanic/server/protocols/websocket_protocol.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from typing import TYPE_CHECKING, Optional, Sequence
|
||||
|
||||
from websockets.connection import CLOSED, CLOSING, OPEN
|
||||
from websockets.server import ServerConnection
|
||||
|
||||
from sanic.exceptions import ServerError
|
||||
from sanic.log import error_logger
|
||||
from sanic.server import HttpProtocol
|
||||
|
||||
from ..websockets.impl import WebsocketImplProtocol
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from websockets import http11
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
|
||||
websocket: Optional[WebsocketImplProtocol]
|
||||
websocket_timeout: float
|
||||
websocket_max_size = Optional[int]
|
||||
websocket_ping_interval = Optional[float]
|
||||
websocket_ping_timeout = Optional[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
websocket_timeout: float = 10.0,
|
||||
websocket_max_size: Optional[int] = None,
|
||||
websocket_max_queue: Optional[int] = None, # max_queue is deprecated
|
||||
websocket_read_limit: Optional[int] = None, # read_limit is deprecated
|
||||
websocket_write_limit: Optional[int] = None, # write_limit deprecated
|
||||
websocket_ping_interval: Optional[float] = 20.0,
|
||||
websocket_ping_timeout: Optional[float] = 20.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_max_size = websocket_max_size
|
||||
if websocket_max_queue is not None and websocket_max_queue > 0:
|
||||
# TODO: Reminder remove this warning in v22.3
|
||||
error_logger.warning(
|
||||
DeprecationWarning(
|
||||
"Websocket no longer uses queueing, so websocket_max_queue"
|
||||
" is no longer required."
|
||||
)
|
||||
)
|
||||
if websocket_read_limit is not None and websocket_read_limit > 0:
|
||||
# TODO: Reminder remove this warning in v22.3
|
||||
error_logger.warning(
|
||||
DeprecationWarning(
|
||||
"Websocket no longer uses read buffers, so "
|
||||
"websocket_read_limit is not required."
|
||||
)
|
||||
)
|
||||
if websocket_write_limit is not None and websocket_write_limit > 0:
|
||||
# TODO: Reminder remove this warning in v22.3
|
||||
error_logger.warning(
|
||||
DeprecationWarning(
|
||||
"Websocket no longer uses write buffers, so "
|
||||
"websocket_write_limit is not required."
|
||||
)
|
||||
)
|
||||
self.websocket_ping_interval = websocket_ping_interval
|
||||
self.websocket_ping_timeout = websocket_ping_timeout
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self.websocket is not None:
|
||||
self.websocket.connection_lost(exc)
|
||||
super().connection_lost(exc)
|
||||
|
||||
def data_received(self, data):
|
||||
if self.websocket is not None:
|
||||
self.websocket.data_received(data)
|
||||
else:
|
||||
# Pass it to HttpProtocol handler first
|
||||
# That will (hopefully) upgrade it to a websocket.
|
||||
super().data_received(data)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
|
||||
if self.websocket is not None:
|
||||
return self.websocket.eof_received()
|
||||
else:
|
||||
return False
|
||||
|
||||
def close(self, timeout: Optional[float] = None):
|
||||
# Called by HttpProtocol at the end of connection_task
|
||||
# If we've upgraded to websocket, we do our own closing
|
||||
if self.websocket is not None:
|
||||
# Note, we don't want to use websocket.close()
|
||||
# That is used for user's application code to send a
|
||||
# websocket close packet. This is different.
|
||||
self.websocket.end_connection(1001)
|
||||
else:
|
||||
super().close()
|
||||
|
||||
def close_if_idle(self):
|
||||
# Called by Sanic Server when shutting down
|
||||
# If we've upgraded to websocket, shut it down
|
||||
if self.websocket is not None:
|
||||
if self.websocket.connection.state in (CLOSING, CLOSED):
|
||||
return True
|
||||
elif self.websocket.loop is not None:
|
||||
self.websocket.loop.create_task(self.websocket.close(1001))
|
||||
else:
|
||||
self.websocket.end_connection(1001)
|
||||
else:
|
||||
return super().close_if_idle()
|
||||
|
||||
async def websocket_handshake(
|
||||
self, request, subprotocols=Optional[Sequence[str]]
|
||||
):
|
||||
# let the websockets package do the handshake with the client
|
||||
try:
|
||||
if subprotocols is not None:
|
||||
# subprotocols can be a set or frozenset,
|
||||
# but ServerConnection needs a list
|
||||
subprotocols = list(subprotocols)
|
||||
ws_conn = ServerConnection(
|
||||
max_size=self.websocket_max_size,
|
||||
subprotocols=subprotocols,
|
||||
state=OPEN,
|
||||
logger=error_logger,
|
||||
)
|
||||
resp: "http11.Response" = ws_conn.accept(request)
|
||||
except Exception:
|
||||
msg = (
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
)
|
||||
raise ServerError(msg, status_code=500)
|
||||
if 100 <= resp.status_code <= 299:
|
||||
rbody = "".join(
|
||||
[
|
||||
"HTTP/1.1 ",
|
||||
str(resp.status_code),
|
||||
" ",
|
||||
resp.reason_phrase,
|
||||
"\r\n",
|
||||
]
|
||||
)
|
||||
rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
|
||||
if resp.body is not None:
|
||||
rbody += f"\r\n{resp.body}\r\n\r\n"
|
||||
else:
|
||||
rbody += "\r\n"
|
||||
await super().send(rbody.encode())
|
||||
else:
|
||||
raise ServerError(resp.body, resp.status_code)
|
||||
self.websocket = WebsocketImplProtocol(
|
||||
ws_conn,
|
||||
ping_interval=self.websocket_ping_interval,
|
||||
ping_timeout=self.websocket_ping_timeout,
|
||||
close_timeout=self.websocket_timeout,
|
||||
)
|
||||
loop = (
|
||||
request.transport.loop
|
||||
if hasattr(request, "transport")
|
||||
and hasattr(request.transport, "loop")
|
||||
else None
|
||||
)
|
||||
await self.websocket.connection_made(self, loop=loop)
|
||||
return self.websocket
|
||||
280
sanic/server/runners.py
Normal file
280
sanic/server/runners.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ssl import SSLContext
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
|
||||
|
||||
from sanic.config import Config
|
||||
from sanic.server.events import trigger_events
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.app import Sanic
|
||||
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import os
|
||||
import socket
|
||||
|
||||
from functools import partial
|
||||
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
||||
from signal import signal as signal_func
|
||||
|
||||
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.models.server_types import Signal
|
||||
from sanic.server.async_server import AsyncioServer
|
||||
from sanic.server.protocols.http_protocol import HttpProtocol
|
||||
from sanic.server.socket import (
|
||||
bind_socket,
|
||||
bind_unix_socket,
|
||||
remove_unix_socket,
|
||||
)
|
||||
|
||||
|
||||
def serve(
|
||||
host,
|
||||
port,
|
||||
app: Sanic,
|
||||
ssl: Optional[SSLContext] = None,
|
||||
sock: Optional[socket.socket] = None,
|
||||
unix: Optional[str] = None,
|
||||
reuse_port: bool = False,
|
||||
loop=None,
|
||||
protocol: Type[asyncio.Protocol] = HttpProtocol,
|
||||
backlog: int = 100,
|
||||
register_sys_signals: bool = True,
|
||||
run_multiple: bool = False,
|
||||
run_async: bool = False,
|
||||
connections=None,
|
||||
signal=Signal(),
|
||||
state=None,
|
||||
asyncio_server_kwargs=None,
|
||||
):
|
||||
"""Start asynchronous HTTP Server on an individual process.
|
||||
|
||||
:param host: Address to host on
|
||||
:param port: Port to host on
|
||||
:param before_start: function to be executed before the server starts
|
||||
listening. Takes arguments `app` instance and `loop`
|
||||
:param after_start: function to be executed after the server starts
|
||||
listening. Takes arguments `app` instance and `loop`
|
||||
:param before_stop: function to be executed when a stop signal is
|
||||
received before it is respected. Takes arguments
|
||||
`app` instance and `loop`
|
||||
:param after_stop: function to be executed when a stop signal is
|
||||
received after it is respected. Takes arguments
|
||||
`app` instance and `loop`
|
||||
:param ssl: SSLContext
|
||||
:param sock: Socket for the server to accept connections from
|
||||
:param unix: Unix socket to listen on instead of TCP port
|
||||
:param reuse_port: `True` for multiple workers
|
||||
:param loop: asyncio compatible event loop
|
||||
:param run_async: bool: Do not create a new event loop for the server,
|
||||
and return an AsyncServer object rather than running it
|
||||
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
|
||||
create_server method
|
||||
:return: Nothing
|
||||
"""
|
||||
if not run_async and not loop:
|
||||
# create new event_loop after fork
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if app.debug:
|
||||
loop.set_debug(app.debug)
|
||||
|
||||
app.asgi = False
|
||||
|
||||
connections = connections if connections is not None else set()
|
||||
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
|
||||
server = partial(
|
||||
protocol,
|
||||
loop=loop,
|
||||
connections=connections,
|
||||
signal=signal,
|
||||
app=app,
|
||||
state=state,
|
||||
unix=unix,
|
||||
**protocol_kwargs,
|
||||
)
|
||||
asyncio_server_kwargs = (
|
||||
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
||||
)
|
||||
# UNIX sockets are always bound by us (to preserve semantics between modes)
|
||||
if unix:
|
||||
sock = bind_unix_socket(unix, backlog=backlog)
|
||||
server_coroutine = loop.create_server(
|
||||
server,
|
||||
None if sock else host,
|
||||
None if sock else port,
|
||||
ssl=ssl,
|
||||
reuse_port=reuse_port,
|
||||
sock=sock,
|
||||
backlog=backlog,
|
||||
**asyncio_server_kwargs,
|
||||
)
|
||||
|
||||
if run_async:
|
||||
return AsyncioServer(
|
||||
app=app,
|
||||
loop=loop,
|
||||
serve_coro=server_coroutine,
|
||||
connections=connections,
|
||||
)
|
||||
|
||||
loop.run_until_complete(app._startup())
|
||||
loop.run_until_complete(app._server_event("init", "before"))
|
||||
|
||||
try:
|
||||
http_server = loop.run_until_complete(server_coroutine)
|
||||
except BaseException:
|
||||
error_logger.exception("Unable to start server")
|
||||
return
|
||||
|
||||
# Ignore SIGINT when run_multiple
|
||||
if run_multiple:
|
||||
signal_func(SIGINT, SIG_IGN)
|
||||
|
||||
# Register signals for graceful termination
|
||||
if register_sys_signals:
|
||||
if OS_IS_WINDOWS:
|
||||
ctrlc_workaround_for_windows(app)
|
||||
else:
|
||||
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
|
||||
loop.add_signal_handler(_signal, app.stop)
|
||||
|
||||
loop.run_until_complete(app._server_event("init", "after"))
|
||||
pid = os.getpid()
|
||||
try:
|
||||
logger.info("Starting worker [%s]", pid)
|
||||
loop.run_forever()
|
||||
finally:
|
||||
logger.info("Stopping worker [%s]", pid)
|
||||
|
||||
# Run the on_stop function if provided
|
||||
loop.run_until_complete(app._server_event("shutdown", "before"))
|
||||
|
||||
# Wait for event loop to finish and all connections to drain
|
||||
http_server.close()
|
||||
loop.run_until_complete(http_server.wait_closed())
|
||||
|
||||
# Complete all tasks on the loop
|
||||
signal.stopped = True
|
||||
for connection in connections:
|
||||
connection.close_if_idle()
|
||||
|
||||
# Gracefully shutdown timeout.
|
||||
# We should provide graceful_shutdown_timeout,
|
||||
# instead of letting connection hangs forever.
|
||||
# Let's roughly calcucate time.
|
||||
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
start_shutdown: float = 0
|
||||
while connections and (start_shutdown < graceful):
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
start_shutdown = start_shutdown + 0.1
|
||||
|
||||
# Force close non-idle connection after waiting for
|
||||
# graceful_shutdown_timeout
|
||||
for conn in connections:
|
||||
if hasattr(conn, "websocket") and conn.websocket:
|
||||
conn.websocket.fail_connection(code=1001)
|
||||
else:
|
||||
conn.abort()
|
||||
loop.run_until_complete(app._server_event("shutdown", "after"))
|
||||
|
||||
remove_unix_socket(unix)
|
||||
|
||||
|
||||
def serve_single(server_settings):
|
||||
main_start = server_settings.pop("main_start", None)
|
||||
main_stop = server_settings.pop("main_stop", None)
|
||||
|
||||
if not server_settings.get("run_async"):
|
||||
# create new event_loop after fork
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_settings["loop"] = loop
|
||||
|
||||
trigger_events(main_start, server_settings["loop"])
|
||||
serve(**server_settings)
|
||||
trigger_events(main_stop, server_settings["loop"])
|
||||
|
||||
server_settings["loop"].close()
|
||||
|
||||
|
||||
def serve_multiple(server_settings, workers):
|
||||
"""Start multiple server processes simultaneously. Stop on interrupt
|
||||
and terminate signals, and drain connections when complete.
|
||||
|
||||
:param server_settings: kw arguments to be passed to the serve function
|
||||
:param workers: number of workers to launch
|
||||
:param stop_event: if provided, is used as a stop signal
|
||||
:return:
|
||||
"""
|
||||
server_settings["reuse_port"] = True
|
||||
server_settings["run_multiple"] = True
|
||||
|
||||
main_start = server_settings.pop("main_start", None)
|
||||
main_stop = server_settings.pop("main_stop", None)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
trigger_events(main_start, loop)
|
||||
|
||||
# Create a listening socket or use the one in settings
|
||||
sock = server_settings.get("sock")
|
||||
unix = server_settings["unix"]
|
||||
backlog = server_settings["backlog"]
|
||||
if unix:
|
||||
sock = bind_unix_socket(unix, backlog=backlog)
|
||||
server_settings["unix"] = unix
|
||||
if sock is None:
|
||||
sock = bind_socket(
|
||||
server_settings["host"], server_settings["port"], backlog=backlog
|
||||
)
|
||||
sock.set_inheritable(True)
|
||||
server_settings["sock"] = sock
|
||||
server_settings["host"] = None
|
||||
server_settings["port"] = None
|
||||
|
||||
processes = []
|
||||
|
||||
def sig_handler(signal, frame):
|
||||
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
|
||||
for process in processes:
|
||||
os.kill(process.pid, SIGTERM)
|
||||
|
||||
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
|
||||
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
|
||||
mp = multiprocessing.get_context("fork")
|
||||
|
||||
for _ in range(workers):
|
||||
process = mp.Process(target=serve, kwargs=server_settings)
|
||||
process.daemon = True
|
||||
process.start()
|
||||
processes.append(process)
|
||||
|
||||
for process in processes:
|
||||
process.join()
|
||||
|
||||
# the above processes will block this until they're stopped
|
||||
for process in processes:
|
||||
process.terminate()
|
||||
|
||||
trigger_events(main_stop, loop)
|
||||
|
||||
sock.close()
|
||||
loop.close()
|
||||
remove_unix_socket(unix)
|
||||
|
||||
|
||||
def _build_protocol_kwargs(
|
||||
protocol: Type[asyncio.Protocol], config: Config
|
||||
) -> Dict[str, Union[int, float]]:
|
||||
if hasattr(protocol, "websocket_handshake"):
|
||||
return {
|
||||
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
|
||||
"websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT,
|
||||
"websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL,
|
||||
}
|
||||
return {}
|
||||
87
sanic/server/socket.py
Normal file
87
sanic/server/socket.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import stat
|
||||
|
||||
from ipaddress import ip_address
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
|
||||
"""Create TCP server socket.
|
||||
:param host: IPv4, IPv6 or hostname may be specified
|
||||
:param port: TCP port number
|
||||
:param backlog: Maximum number of connections to queue
|
||||
:return: socket.socket object
|
||||
"""
|
||||
try: # IP address: family must be specified for IPv6 at least
|
||||
ip = ip_address(host)
|
||||
host = str(ip)
|
||||
sock = socket.socket(
|
||||
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
|
||||
)
|
||||
except ValueError: # Hostname, may become AF_INET or AF_INET6
|
||||
sock = socket.socket()
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind((host, port))
|
||||
sock.listen(backlog)
|
||||
return sock
|
||||
|
||||
|
||||
def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
|
||||
"""Create unix socket.
|
||||
:param path: filesystem path
|
||||
:param backlog: Maximum number of connections to queue
|
||||
:return: socket.socket object
|
||||
"""
|
||||
"""Open or atomically replace existing socket with zero downtime."""
|
||||
# Sanitise and pre-verify socket path
|
||||
path = os.path.abspath(path)
|
||||
folder = os.path.dirname(path)
|
||||
if not os.path.isdir(folder):
|
||||
raise FileNotFoundError(f"Socket folder does not exist: {folder}")
|
||||
try:
|
||||
if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
|
||||
raise FileExistsError(f"Existing file is not a socket: {path}")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
# Create new socket with a random temporary name
|
||||
tmp_path = f"{path}.{secrets.token_urlsafe()}"
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
try:
|
||||
# Critical section begins (filename races)
|
||||
sock.bind(tmp_path)
|
||||
try:
|
||||
os.chmod(tmp_path, mode)
|
||||
# Start listening before rename to avoid connection failures
|
||||
sock.listen(backlog)
|
||||
os.rename(tmp_path, path)
|
||||
except: # noqa: E722
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
finally:
|
||||
raise
|
||||
except: # noqa: E722
|
||||
try:
|
||||
sock.close()
|
||||
finally:
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
def remove_unix_socket(path: Optional[str]) -> None:
|
||||
"""Remove dead unix socket during server exit."""
|
||||
if not path:
|
||||
return
|
||||
try:
|
||||
if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
|
||||
# Is it actually dead (doesn't belong to a new server instance)?
|
||||
with socket.socket(socket.AF_UNIX) as testsock:
|
||||
try:
|
||||
testsock.connect(path)
|
||||
except ConnectionRefusedError:
|
||||
os.unlink(path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
0
sanic/server/websockets/__init__.py
Normal file
0
sanic/server/websockets/__init__.py
Normal file
82
sanic/server/websockets/connection.py
Normal file
82
sanic/server/websockets/connection.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
ASIMessage = MutableMapping[str, Any]
|
||||
|
||||
|
||||
class WebSocketConnection:
|
||||
"""
|
||||
This is for ASGI Connections.
|
||||
It provides an interface similar to WebsocketProtocol, but
|
||||
sends/receives over an ASGI connection.
|
||||
"""
|
||||
|
||||
# TODO
|
||||
# - Implement ping/pong
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send: Callable[[ASIMessage], Awaitable[None]],
|
||||
receive: Callable[[], Awaitable[ASIMessage]],
|
||||
subprotocols: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self._send = send
|
||||
self._receive = receive
|
||||
self._subprotocols = subprotocols or []
|
||||
|
||||
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
||||
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
|
||||
|
||||
if isinstance(data, bytes):
|
||||
message.update({"bytes": data})
|
||||
else:
|
||||
message.update({"text": str(data)})
|
||||
|
||||
await self._send(message)
|
||||
|
||||
async def recv(self, *args, **kwargs) -> Optional[str]:
|
||||
message = await self._receive()
|
||||
|
||||
if message["type"] == "websocket.receive":
|
||||
return message["text"]
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
receive = recv
|
||||
|
||||
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
|
||||
subprotocol = None
|
||||
if subprotocols:
|
||||
for subp in subprotocols:
|
||||
if subp in self.subprotocols:
|
||||
subprotocol = subp
|
||||
break
|
||||
|
||||
await self._send(
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": subprotocol,
|
||||
}
|
||||
)
|
||||
|
||||
async def close(self, code: int = 1000, reason: str = "") -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def subprotocols(self):
|
||||
return self._subprotocols
|
||||
|
||||
@subprotocols.setter
|
||||
def subprotocols(self, subprotocols: Optional[List[str]] = None):
|
||||
self._subprotocols = subprotocols or []
|
||||
294
sanic/server/websockets/frame.py
Normal file
294
sanic/server/websockets/frame.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
|
||||
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
|
||||
|
||||
from websockets.frames import Frame, Opcode
|
||||
from websockets.typing import Data
|
||||
|
||||
from sanic.exceptions import ServerError
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .impl import WebsocketImplProtocol
|
||||
|
||||
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
|
||||
|
||||
|
||||
class WebsocketFrameAssembler:
|
||||
"""
|
||||
Assemble a message from frames.
|
||||
Code borrowed from aaugustin/websockets project:
|
||||
https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"protocol",
|
||||
"read_mutex",
|
||||
"write_mutex",
|
||||
"message_complete",
|
||||
"message_fetched",
|
||||
"get_in_progress",
|
||||
"decoder",
|
||||
"completed_queue",
|
||||
"chunks",
|
||||
"chunks_queue",
|
||||
"paused",
|
||||
"get_id",
|
||||
"put_id",
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
protocol: "WebsocketImplProtocol"
|
||||
read_mutex: asyncio.Lock
|
||||
write_mutex: asyncio.Lock
|
||||
message_complete: asyncio.Event
|
||||
message_fetched: asyncio.Event
|
||||
completed_queue: asyncio.Queue
|
||||
get_in_progress: bool
|
||||
decoder: Optional[codecs.IncrementalDecoder]
|
||||
# For streaming chunks rather than messages:
|
||||
chunks: List[Data]
|
||||
chunks_queue: Optional[asyncio.Queue[Optional[Data]]]
|
||||
paused: bool
|
||||
|
||||
def __init__(self, protocol) -> None:
|
||||
|
||||
self.protocol = protocol
|
||||
|
||||
self.read_mutex = asyncio.Lock()
|
||||
self.write_mutex = asyncio.Lock()
|
||||
|
||||
self.completed_queue = asyncio.Queue(
|
||||
maxsize=1
|
||||
) # type: asyncio.Queue[Data]
|
||||
|
||||
# put() sets this event to tell get() that a message can be fetched.
|
||||
self.message_complete = asyncio.Event()
|
||||
# get() sets this event to let put()
|
||||
self.message_fetched = asyncio.Event()
|
||||
|
||||
# This flag prevents concurrent calls to get() by user code.
|
||||
self.get_in_progress = False
|
||||
|
||||
# Decoder for text frames, None for binary frames.
|
||||
self.decoder = None
|
||||
|
||||
# Buffer data from frames belonging to the same message.
|
||||
self.chunks = []
|
||||
|
||||
# When switching from "buffering" to "streaming", we use a thread-safe
|
||||
# queue for transferring frames from the writing thread (library code)
|
||||
# to the reading thread (user code). We're buffering when chunks_queue
|
||||
# is None and streaming when it's a Queue. None is a sentinel
|
||||
# value marking the end of the stream, superseding message_complete.
|
||||
|
||||
# Stream data from frames belonging to the same message.
|
||||
self.chunks_queue = None
|
||||
|
||||
# Flag to indicate we've paused the protocol
|
||||
self.paused = False
|
||||
|
||||
async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
|
||||
"""
|
||||
Read the next message.
|
||||
:meth:`get` returns a single :class:`str` or :class:`bytes`.
|
||||
If the :message was fragmented, :meth:`get` waits until the last frame
|
||||
is received, then it reassembles the message.
|
||||
If ``timeout`` is set and elapses before a complete message is
|
||||
received, :meth:`get` returns ``None``.
|
||||
"""
|
||||
async with self.read_mutex:
|
||||
if timeout is not None and timeout <= 0:
|
||||
if not self.message_complete.is_set():
|
||||
return None
|
||||
if self.get_in_progress:
|
||||
# This should be guarded against with the read_mutex,
|
||||
# exception is only here as a failsafe
|
||||
raise ServerError(
|
||||
"Called get() on Websocket frame assembler "
|
||||
"while asynchronous get is already in progress."
|
||||
)
|
||||
self.get_in_progress = True
|
||||
|
||||
# If the message_complete event isn't set yet, release the lock to
|
||||
# allow put() to run and eventually set it.
|
||||
# Locking with get_in_progress ensures only one task can get here.
|
||||
if timeout is None:
|
||||
completed = await self.message_complete.wait()
|
||||
elif timeout <= 0:
|
||||
completed = self.message_complete.is_set()
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.message_complete.wait(), timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
...
|
||||
finally:
|
||||
completed = self.message_complete.is_set()
|
||||
|
||||
# Unpause the transport, if its paused
|
||||
if self.paused:
|
||||
self.protocol.resume_frames()
|
||||
self.paused = False
|
||||
if not self.get_in_progress:
|
||||
# This should be guarded against with the read_mutex,
|
||||
# exception is here as a failsafe
|
||||
raise ServerError(
|
||||
"State of Websocket frame assembler was modified while an "
|
||||
"asynchronous get was in progress."
|
||||
)
|
||||
self.get_in_progress = False
|
||||
|
||||
# Waiting for a complete message timed out.
|
||||
if not completed:
|
||||
return None
|
||||
if not self.message_complete.is_set():
|
||||
return None
|
||||
|
||||
self.message_complete.clear()
|
||||
|
||||
joiner: Data = b"" if self.decoder is None else ""
|
||||
# mypy cannot figure out that chunks have the proper type.
|
||||
message: Data = joiner.join(self.chunks) # type: ignore
|
||||
if self.message_fetched.is_set():
|
||||
# This should be guarded against with the read_mutex,
|
||||
# and get_in_progress check, this exception is here
|
||||
# as a failsafe
|
||||
raise ServerError(
|
||||
"Websocket get() found a message when "
|
||||
"state was already fetched."
|
||||
)
|
||||
self.message_fetched.set()
|
||||
self.chunks = []
|
||||
# this should already be None, but set it here for safety
|
||||
self.chunks_queue = None
|
||||
return message
|
||||
|
||||
async def get_iter(self) -> AsyncIterator[Data]:
|
||||
"""
|
||||
Stream the next message.
|
||||
Iterating the return value of :meth:`get_iter` yields a :class:`str`
|
||||
or :class:`bytes` for each frame in the message.
|
||||
"""
|
||||
async with self.read_mutex:
|
||||
if self.get_in_progress:
|
||||
# This should be guarded against with the read_mutex,
|
||||
# exception is only here as a failsafe
|
||||
raise ServerError(
|
||||
"Called get_iter on Websocket frame assembler "
|
||||
"while asynchronous get is already in progress."
|
||||
)
|
||||
self.get_in_progress = True
|
||||
|
||||
chunks = self.chunks
|
||||
self.chunks = []
|
||||
self.chunks_queue = asyncio.Queue()
|
||||
|
||||
# Sending None in chunk_queue supersedes setting message_complete
|
||||
# when switching to "streaming". If message is already complete
|
||||
# when the switch happens, put() didn't send None, so we have to.
|
||||
if self.message_complete.is_set():
|
||||
await self.chunks_queue.put(None)
|
||||
|
||||
# Locking with get_in_progress ensures only one task can get here
|
||||
for c in chunks:
|
||||
yield c
|
||||
while True:
|
||||
chunk = await self.chunks_queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
# Unpause the transport, if its paused
|
||||
if self.paused:
|
||||
self.protocol.resume_frames()
|
||||
self.paused = False
|
||||
if not self.get_in_progress:
|
||||
# This should be guarded against with the read_mutex,
|
||||
# exception is here as a failsafe
|
||||
raise ServerError(
|
||||
"State of Websocket frame assembler was modified while an "
|
||||
"asynchronous get was in progress."
|
||||
)
|
||||
self.get_in_progress = False
|
||||
if not self.message_complete.is_set():
|
||||
# This should be guarded against with the read_mutex,
|
||||
# exception is here as a failsafe
|
||||
raise ServerError(
|
||||
"Websocket frame assembler chunks queue ended before "
|
||||
"message was complete."
|
||||
)
|
||||
self.message_complete.clear()
|
||||
if self.message_fetched.is_set():
|
||||
# This should be guarded against with the read_mutex,
|
||||
# and get_in_progress check, this exception is
|
||||
# here as a failsafe
|
||||
raise ServerError(
|
||||
"Websocket get_iter() found a message when state was "
|
||||
"already fetched."
|
||||
)
|
||||
|
||||
self.message_fetched.set()
|
||||
# this should already be empty, but set it here for safety
|
||||
self.chunks = []
|
||||
self.chunks_queue = None
|
||||
|
||||
async def put(self, frame: Frame) -> None:
|
||||
"""
|
||||
Add ``frame`` to the next message.
|
||||
When ``frame`` is the final frame in a message, :meth:`put` waits
|
||||
until the message is fetched, either by calling :meth:`get` or by
|
||||
iterating the return value of :meth:`get_iter`.
|
||||
:meth:`put` assumes that the stream of frames respects the protocol.
|
||||
If it doesn't, the behavior is undefined.
|
||||
"""
|
||||
|
||||
async with self.write_mutex:
|
||||
if frame.opcode is Opcode.TEXT:
|
||||
self.decoder = UTF8Decoder(errors="strict")
|
||||
elif frame.opcode is Opcode.BINARY:
|
||||
self.decoder = None
|
||||
elif frame.opcode is Opcode.CONT:
|
||||
pass
|
||||
else:
|
||||
# Ignore control frames.
|
||||
return
|
||||
data: Data
|
||||
if self.decoder is not None:
|
||||
data = self.decoder.decode(frame.data, frame.fin)
|
||||
else:
|
||||
data = frame.data
|
||||
if self.chunks_queue is None:
|
||||
self.chunks.append(data)
|
||||
else:
|
||||
await self.chunks_queue.put(data)
|
||||
|
||||
if not frame.fin:
|
||||
return
|
||||
if not self.get_in_progress:
|
||||
# nobody is waiting for this frame, so try to pause subsequent
|
||||
# frames at the protocol level
|
||||
self.paused = self.protocol.pause_frames()
|
||||
# Message is complete. Wait until it's fetched to return.
|
||||
|
||||
if self.chunks_queue is not None:
|
||||
await self.chunks_queue.put(None)
|
||||
if self.message_complete.is_set():
|
||||
# This should be guarded against with the write_mutex
|
||||
raise ServerError(
|
||||
"Websocket put() got a new message when a message was "
|
||||
"already in its chamber."
|
||||
)
|
||||
self.message_complete.set() # Signal to get() it can serve the
|
||||
if self.message_fetched.is_set():
|
||||
# This should be guarded against with the write_mutex
|
||||
raise ServerError(
|
||||
"Websocket put() got a new message when the previous "
|
||||
"message was not yet fetched."
|
||||
)
|
||||
|
||||
# Allow get() to run and eventually set the event.
|
||||
await self.message_fetched.wait()
|
||||
self.message_fetched.clear()
|
||||
self.decoder = None
|
||||
834
sanic/server/websockets/impl.py
Normal file
834
sanic/server/websockets/impl.py
Normal file
@@ -0,0 +1,834 @@
|
||||
import asyncio
|
||||
import random
|
||||
import struct
|
||||
|
||||
from typing import (
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from websockets.connection import CLOSED, CLOSING, OPEN, Event
|
||||
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
||||
from websockets.frames import Frame, Opcode
|
||||
from websockets.server import ServerConnection
|
||||
from websockets.typing import Data
|
||||
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.server.protocols.base_protocol import SanicProtocol
|
||||
|
||||
from ...exceptions import ServerError, WebsocketClosed
|
||||
from .frame import WebsocketFrameAssembler
|
||||
|
||||
|
||||
class WebsocketImplProtocol:
|
||||
connection: ServerConnection
|
||||
io_proto: Optional[SanicProtocol]
|
||||
loop: Optional[asyncio.AbstractEventLoop]
|
||||
max_queue: int
|
||||
close_timeout: float
|
||||
ping_interval: Optional[float]
|
||||
ping_timeout: Optional[float]
|
||||
assembler: WebsocketFrameAssembler
|
||||
# Dict[bytes, asyncio.Future[None]]
|
||||
pings: Dict[bytes, asyncio.Future]
|
||||
conn_mutex: asyncio.Lock
|
||||
recv_lock: asyncio.Lock
|
||||
recv_cancel: Optional[asyncio.Future]
|
||||
process_event_mutex: asyncio.Lock
|
||||
can_pause: bool
|
||||
# Optional[asyncio.Future[None]]
|
||||
data_finished_fut: Optional[asyncio.Future]
|
||||
# Optional[asyncio.Future[None]]
|
||||
pause_frame_fut: Optional[asyncio.Future]
|
||||
# Optional[asyncio.Future[None]]
|
||||
connection_lost_waiter: Optional[asyncio.Future]
|
||||
keepalive_ping_task: Optional[asyncio.Task]
|
||||
auto_closer_task: Optional[asyncio.Task]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection,
|
||||
max_queue=None,
|
||||
ping_interval: Optional[float] = 20,
|
||||
ping_timeout: Optional[float] = 20,
|
||||
close_timeout: float = 10,
|
||||
loop=None,
|
||||
):
|
||||
self.connection = connection
|
||||
self.io_proto = None
|
||||
self.loop = None
|
||||
self.max_queue = max_queue
|
||||
self.close_timeout = close_timeout
|
||||
self.ping_interval = ping_interval
|
||||
self.ping_timeout = ping_timeout
|
||||
self.assembler = WebsocketFrameAssembler(self)
|
||||
self.pings = {}
|
||||
self.conn_mutex = asyncio.Lock()
|
||||
self.recv_lock = asyncio.Lock()
|
||||
self.recv_cancel = None
|
||||
self.process_event_mutex = asyncio.Lock()
|
||||
self.data_finished_fut = None
|
||||
self.can_pause = True
|
||||
self.pause_frame_fut = None
|
||||
self.keepalive_ping_task = None
|
||||
self.auto_closer_task = None
|
||||
self.connection_lost_waiter = None
|
||||
|
||||
@property
|
||||
def subprotocol(self):
|
||||
return self.connection.subprotocol
|
||||
|
||||
def pause_frames(self):
|
||||
if not self.can_pause:
|
||||
return False
|
||||
if self.pause_frame_fut:
|
||||
logger.debug("Websocket connection already paused.")
|
||||
return False
|
||||
if (not self.loop) or (not self.io_proto):
|
||||
return False
|
||||
if self.io_proto.transport:
|
||||
self.io_proto.transport.pause_reading()
|
||||
self.pause_frame_fut = self.loop.create_future()
|
||||
logger.debug("Websocket connection paused.")
|
||||
return True
|
||||
|
||||
def resume_frames(self):
|
||||
if not self.pause_frame_fut:
|
||||
logger.debug("Websocket connection not paused.")
|
||||
return False
|
||||
if (not self.loop) or (not self.io_proto):
|
||||
logger.debug(
|
||||
"Websocket attempting to resume reading frames, "
|
||||
"but connection is gone."
|
||||
)
|
||||
return False
|
||||
if self.io_proto.transport:
|
||||
self.io_proto.transport.resume_reading()
|
||||
self.pause_frame_fut.set_result(None)
|
||||
self.pause_frame_fut = None
|
||||
logger.debug("Websocket connection unpaused.")
|
||||
return True
|
||||
|
||||
async def connection_made(
|
||||
self,
|
||||
io_proto: SanicProtocol,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
if not loop:
|
||||
try:
|
||||
loop = getattr(io_proto, "loop")
|
||||
except AttributeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
if not loop:
|
||||
# This catch is for mypy type checker
|
||||
# to assert loop is not None here.
|
||||
raise ServerError("Connection received with no asyncio loop.")
|
||||
if self.auto_closer_task:
|
||||
raise ServerError(
|
||||
"Cannot call connection_made more than once "
|
||||
"on a websocket connection."
|
||||
)
|
||||
self.loop = loop
|
||||
self.io_proto = io_proto
|
||||
self.connection_lost_waiter = self.loop.create_future()
|
||||
self.data_finished_fut = asyncio.shield(self.loop.create_future())
|
||||
|
||||
if self.ping_interval:
|
||||
self.keepalive_ping_task = asyncio.create_task(
|
||||
self.keepalive_ping()
|
||||
)
|
||||
self.auto_closer_task = asyncio.create_task(
|
||||
self.auto_close_connection()
|
||||
)
|
||||
|
||||
async def wait_for_connection_lost(self, timeout=None) -> bool:
|
||||
"""
|
||||
Wait until the TCP connection is closed or ``timeout`` elapses.
|
||||
If timeout is None, wait forever.
|
||||
Recommend you should pass in self.close_timeout as timeout
|
||||
|
||||
Return ``True`` if the connection is closed and ``False`` otherwise.
|
||||
|
||||
"""
|
||||
if not self.connection_lost_waiter:
|
||||
return False
|
||||
if self.connection_lost_waiter.done():
|
||||
return True
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(self.connection_lost_waiter), timeout
|
||||
)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
# Re-check self.connection_lost_waiter.done() synchronously
|
||||
# because connection_lost() could run between the moment the
|
||||
# timeout occurs and the moment this coroutine resumes running
|
||||
return self.connection_lost_waiter.done()
|
||||
|
||||
async def process_events(self, events: Sequence[Event]) -> None:
|
||||
"""
|
||||
Process a list of incoming events.
|
||||
"""
|
||||
# Wrapped in a mutex lock, to prevent other incoming events
|
||||
# from processing at the same time
|
||||
async with self.process_event_mutex:
|
||||
for event in events:
|
||||
if not isinstance(event, Frame):
|
||||
# Event is not a frame. Ignore it.
|
||||
continue
|
||||
if event.opcode == Opcode.PONG:
|
||||
await self.process_pong(event)
|
||||
elif event.opcode == Opcode.CLOSE:
|
||||
if self.recv_cancel:
|
||||
self.recv_cancel.cancel()
|
||||
else:
|
||||
await self.assembler.put(event)
|
||||
|
||||
async def process_pong(self, frame: Frame) -> None:
|
||||
if frame.data in self.pings:
|
||||
# Acknowledge all pings up to the one matching this pong.
|
||||
ping_ids = []
|
||||
for ping_id, ping in self.pings.items():
|
||||
ping_ids.append(ping_id)
|
||||
if not ping.done():
|
||||
ping.set_result(None)
|
||||
if ping_id == frame.data:
|
||||
break
|
||||
else: # noqa
|
||||
raise ServerError("ping_id is not in self.pings")
|
||||
# Remove acknowledged pings from self.pings.
|
||||
for ping_id in ping_ids:
|
||||
del self.pings[ping_id]
|
||||
|
||||
async def keepalive_ping(self) -> None:
|
||||
"""
|
||||
Send a Ping frame and wait for a Pong frame at regular intervals.
|
||||
This coroutine exits when the connection terminates and one of the
|
||||
following happens:
|
||||
- :meth:`ping` raises :exc:`ConnectionClosed`, or
|
||||
- :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`.
|
||||
"""
|
||||
if self.ping_interval is None:
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(self.ping_interval)
|
||||
|
||||
# ping() raises CancelledError if the connection is closed,
|
||||
# when auto_close_connection() cancels keepalive_ping_task.
|
||||
|
||||
# ping() raises ConnectionClosed if the connection is lost,
|
||||
# when connection_lost() calls abort_pings().
|
||||
|
||||
ping_waiter = await self.ping()
|
||||
|
||||
if self.ping_timeout is not None:
|
||||
try:
|
||||
await asyncio.wait_for(ping_waiter, self.ping_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
error_logger.warning(
|
||||
"Websocket timed out waiting for pong"
|
||||
)
|
||||
self.fail_connection(1011)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
# It is expected for this task to be cancelled during during
|
||||
# normal operation, when the connection is closed.
|
||||
logger.debug("Websocket keepalive ping task was cancelled.")
|
||||
except (ConnectionClosed, WebsocketClosed):
|
||||
logger.debug("Websocket closed. Keepalive ping task exiting.")
|
||||
except Exception as e:
|
||||
error_logger.warning(
|
||||
"Unexpected exception in websocket keepalive ping task."
|
||||
)
|
||||
logger.debug(str(e))
|
||||
|
||||
def _force_disconnect(self) -> bool:
|
||||
"""
|
||||
Internal methdod used by end_connection and fail_connection
|
||||
only when the graceful auto-closer cannot be used
|
||||
"""
|
||||
if self.auto_closer_task and not self.auto_closer_task.done():
|
||||
self.auto_closer_task.cancel()
|
||||
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||
self.data_finished_fut.cancel()
|
||||
self.data_finished_fut = None
|
||||
if self.keepalive_ping_task and not self.keepalive_ping_task.done():
|
||||
self.keepalive_ping_task.cancel()
|
||||
self.keepalive_ping_task = None
|
||||
if self.loop and self.io_proto and self.io_proto.transport:
|
||||
self.io_proto.transport.close()
|
||||
self.loop.call_later(
|
||||
self.close_timeout, self.io_proto.transport.abort
|
||||
)
|
||||
# We were never open, or already closed
|
||||
return True
|
||||
|
||||
def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
|
||||
"""
|
||||
Fail the WebSocket Connection
|
||||
This requires:
|
||||
1. Stopping all processing of incoming data, which means cancelling
|
||||
pausing the underlying io protocol. The close code will be 1006
|
||||
unless a close frame was received earlier.
|
||||
2. Sending a close frame with an appropriate code if the opening
|
||||
handshake succeeded and the other side is likely to process it.
|
||||
3. Closing the connection. :meth:`auto_close_connection` takes care
|
||||
of this.
|
||||
(The specification describes these steps in the opposite order.)
|
||||
"""
|
||||
if self.io_proto and self.io_proto.transport:
|
||||
# Stop new data coming in
|
||||
# In Python Version 3.7: pause_reading is idempotent
|
||||
# ut can be called when the transport is already paused or closed
|
||||
self.io_proto.transport.pause_reading()
|
||||
|
||||
# Keeping fail_connection() synchronous guarantees it can't
|
||||
# get stuck and simplifies the implementation of the callers.
|
||||
# Not draining the write buffer is acceptable in this context.
|
||||
|
||||
# clear the send buffer
|
||||
_ = self.connection.data_to_send()
|
||||
# If we're not already CLOSED or CLOSING, then send the close.
|
||||
if self.connection.state is OPEN:
|
||||
if code in (1000, 1001):
|
||||
self.connection.send_close(code, reason)
|
||||
else:
|
||||
self.connection.fail(code, reason)
|
||||
try:
|
||||
data_to_send = self.connection.data_to_send()
|
||||
while (
|
||||
len(data_to_send)
|
||||
and self.io_proto
|
||||
and self.io_proto.transport
|
||||
):
|
||||
frame_data = data_to_send.pop(0)
|
||||
self.io_proto.transport.write(frame_data)
|
||||
except Exception:
|
||||
# sending close frames may fail if the
|
||||
# transport closes during this period
|
||||
...
|
||||
if code == 1006:
|
||||
# Special case: 1006 consider the transport already closed
|
||||
self.connection.state = CLOSED
|
||||
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||
# We have a graceful auto-closer. Use it to close the connection.
|
||||
self.data_finished_fut.cancel()
|
||||
self.data_finished_fut = None
|
||||
if (not self.auto_closer_task) or self.auto_closer_task.done():
|
||||
return self._force_disconnect()
|
||||
return False
|
||||
|
||||
def end_connection(self, code=1000, reason=""):
|
||||
# This is like slightly more graceful form of fail_connection
|
||||
# Use this instead of close() when you need an immediate
|
||||
# close and cannot await websocket.close() handshake.
|
||||
|
||||
if code == 1006 or not self.io_proto or not self.io_proto.transport:
|
||||
return self.fail_connection(code, reason)
|
||||
|
||||
# Stop new data coming in
|
||||
# In Python Version 3.7: pause_reading is idempotent
|
||||
# i.e. it can be called when the transport is already paused or closed.
|
||||
self.io_proto.transport.pause_reading()
|
||||
if self.connection.state == OPEN:
|
||||
data_to_send = self.connection.data_to_send()
|
||||
self.connection.send_close(code, reason)
|
||||
data_to_send.extend(self.connection.data_to_send())
|
||||
try:
|
||||
while (
|
||||
len(data_to_send)
|
||||
and self.io_proto
|
||||
and self.io_proto.transport
|
||||
):
|
||||
frame_data = data_to_send.pop(0)
|
||||
self.io_proto.transport.write(frame_data)
|
||||
except Exception:
|
||||
# sending close frames may fail if the
|
||||
# transport closes during this period
|
||||
# But that doesn't matter at this point
|
||||
...
|
||||
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||
# We have the ability to signal the auto-closer
|
||||
# try to trigger it to auto-close the connection
|
||||
self.data_finished_fut.cancel()
|
||||
self.data_finished_fut = None
|
||||
if (not self.auto_closer_task) or self.auto_closer_task.done():
|
||||
# Auto-closer is not running, do force disconnect
|
||||
return self._force_disconnect()
|
||||
return False
|
||||
|
||||
async def auto_close_connection(self) -> None:
|
||||
"""
|
||||
Close the WebSocket Connection
|
||||
When the opening handshake succeeds, :meth:`connection_open` starts
|
||||
this coroutine in a task. It waits for the data transfer phase to
|
||||
complete then it closes the TCP connection cleanly.
|
||||
When the opening handshake fails, :meth:`fail_connection` does the
|
||||
same. There's no data transfer phase in that case.
|
||||
"""
|
||||
try:
|
||||
# Wait for the data transfer phase to complete.
|
||||
if self.data_finished_fut:
|
||||
try:
|
||||
await self.data_finished_fut
|
||||
logger.debug(
|
||||
"Websocket task finished. Closing the connection."
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
# Cancelled error is called when data phase is cancelled
|
||||
# if an error occurred or the client closed the connection
|
||||
logger.debug(
|
||||
"Websocket handler cancelled. Closing the connection."
|
||||
)
|
||||
|
||||
# Cancel the keepalive ping task.
|
||||
if self.keepalive_ping_task:
|
||||
self.keepalive_ping_task.cancel()
|
||||
self.keepalive_ping_task = None
|
||||
|
||||
# Half-close the TCP connection if possible (when there's no TLS).
|
||||
if (
|
||||
self.io_proto
|
||||
and self.io_proto.transport
|
||||
and self.io_proto.transport.can_write_eof()
|
||||
):
|
||||
logger.debug("Websocket half-closing TCP connection")
|
||||
self.io_proto.transport.write_eof()
|
||||
if self.connection_lost_waiter:
|
||||
if await self.wait_for_connection_lost(timeout=0):
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
...
|
||||
finally:
|
||||
# The try/finally ensures that the transport never remains open,
|
||||
# even if this coroutine is cancelled (for example).
|
||||
if (not self.io_proto) or (not self.io_proto.transport):
|
||||
# we were never open, or done. Can't do any finalization.
|
||||
return
|
||||
elif (
|
||||
self.connection_lost_waiter
|
||||
and self.connection_lost_waiter.done()
|
||||
):
|
||||
# connection confirmed closed already, proceed to abort waiter
|
||||
...
|
||||
elif self.io_proto.transport.is_closing():
|
||||
# Connection is already closing (due to half-close above)
|
||||
# proceed to abort waiter
|
||||
...
|
||||
else:
|
||||
self.io_proto.transport.close()
|
||||
if not self.connection_lost_waiter:
|
||||
# Our connection monitor task isn't running.
|
||||
try:
|
||||
await asyncio.sleep(self.close_timeout)
|
||||
except asyncio.CancelledError:
|
||||
...
|
||||
if self.io_proto and self.io_proto.transport:
|
||||
self.io_proto.transport.abort()
|
||||
else:
|
||||
if await self.wait_for_connection_lost(
|
||||
timeout=self.close_timeout
|
||||
):
|
||||
# Connection aborted before the timeout expired.
|
||||
return
|
||||
error_logger.warning(
|
||||
"Timeout waiting for TCP connection to close. Aborting"
|
||||
)
|
||||
if self.io_proto and self.io_proto.transport:
|
||||
self.io_proto.transport.abort()
|
||||
|
||||
def abort_pings(self) -> None:
|
||||
"""
|
||||
Raise ConnectionClosed in pending keepalive pings.
|
||||
They'll never receive a pong once the connection is closed.
|
||||
"""
|
||||
if self.connection.state is not CLOSED:
|
||||
raise ServerError(
|
||||
"Webscoket about_pings should only be called "
|
||||
"after connection state is changed to CLOSED"
|
||||
)
|
||||
|
||||
for ping in self.pings.values():
|
||||
ping.set_exception(ConnectionClosedError(None, None))
|
||||
# If the exception is never retrieved, it will be logged when ping
|
||||
# is garbage-collected. This is confusing for users.
|
||||
# Given that ping is done (with an exception), canceling it does
|
||||
# nothing, but it prevents logging the exception.
|
||||
ping.cancel()
|
||||
|
||||
async def close(self, code: int = 1000, reason: str = "") -> None:
|
||||
"""
|
||||
Perform the closing handshake.
|
||||
This is a websocket-protocol level close.
|
||||
:meth:`close` waits for the other end to complete the handshake and
|
||||
for the TCP connection to terminate.
|
||||
:meth:`close` is idempotent: it doesn't do anything once the
|
||||
connection is closed.
|
||||
:param code: WebSocket close code
|
||||
:param reason: WebSocket close reason
|
||||
"""
|
||||
if code == 1006:
|
||||
self.fail_connection(code, reason)
|
||||
return
|
||||
async with self.conn_mutex:
|
||||
if self.connection.state is OPEN:
|
||||
self.connection.send_close(code, reason)
|
||||
data_to_send = self.connection.data_to_send()
|
||||
await self.send_data(data_to_send)
|
||||
|
||||
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
|
||||
"""
|
||||
Receive the next message.
|
||||
Return a :class:`str` for a text frame and :class:`bytes` for a binary
|
||||
frame.
|
||||
When the end of the message stream is reached, :meth:`recv` raises
|
||||
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
|
||||
raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
|
||||
connection closure and
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
|
||||
error or a network failure.
|
||||
If ``timeout`` is ``None``, block until a message is received. Else,
|
||||
if no message is received within ``timeout`` seconds, return ``None``.
|
||||
Set ``timeout`` to ``0`` to check if a message was already received.
|
||||
:raises ~websockets.exceptions.ConnectionClosed: when the
|
||||
connection is closed
|
||||
:raises asyncio.CancelledError: if the websocket closes while waiting
|
||||
:raises ServerError: if two tasks call :meth:`recv` or
|
||||
:meth:`recv_streaming` concurrently
|
||||
"""
|
||||
|
||||
if self.recv_lock.locked():
|
||||
raise ServerError(
|
||||
"cannot call recv while another task is "
|
||||
"already waiting for the next message"
|
||||
)
|
||||
await self.recv_lock.acquire()
|
||||
if self.connection.state is CLOSED:
|
||||
self.recv_lock.release()
|
||||
raise WebsocketClosed(
|
||||
"Cannot receive from websocket interface after it is closed."
|
||||
)
|
||||
try:
|
||||
self.recv_cancel = asyncio.Future()
|
||||
done, pending = await asyncio.wait(
|
||||
(self.recv_cancel, self.assembler.get(timeout)),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
done_task = next(iter(done))
|
||||
if done_task is self.recv_cancel:
|
||||
# recv was cancelled
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
raise asyncio.CancelledError()
|
||||
else:
|
||||
self.recv_cancel.cancel()
|
||||
return done_task.result()
|
||||
finally:
|
||||
self.recv_cancel = None
|
||||
self.recv_lock.release()
|
||||
|
||||
async def recv_burst(self, max_recv=256) -> Sequence[Data]:
|
||||
"""
|
||||
Receive the messages which have arrived since last checking.
|
||||
Return a :class:`list` containing :class:`str` for a text frame
|
||||
and :class:`bytes` for a binary frame.
|
||||
When the end of the message stream is reached, :meth:`recv_burst`
|
||||
raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically,
|
||||
it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a
|
||||
normal connection closure and
|
||||
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
|
||||
error or a network failure.
|
||||
:raises ~websockets.exceptions.ConnectionClosed: when the
|
||||
connection is closed
|
||||
:raises ServerError: if two tasks call :meth:`recv_burst` or
|
||||
:meth:`recv_streaming` concurrently
|
||||
"""
|
||||
|
||||
if self.recv_lock.locked():
|
||||
raise ServerError(
|
||||
"cannot call recv_burst while another task is already waiting "
|
||||
"for the next message"
|
||||
)
|
||||
await self.recv_lock.acquire()
|
||||
if self.connection.state is CLOSED:
|
||||
self.recv_lock.release()
|
||||
raise WebsocketClosed(
|
||||
"Cannot receive from websocket interface after it is closed."
|
||||
)
|
||||
messages = []
|
||||
try:
|
||||
# Prevent pausing the transport when we're
|
||||
# receiving a burst of messages
|
||||
self.can_pause = False
|
||||
self.recv_cancel = asyncio.Future()
|
||||
while True:
|
||||
done, pending = await asyncio.wait(
|
||||
(self.recv_cancel, self.assembler.get(timeout=0)),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
done_task = next(iter(done))
|
||||
if done_task is self.recv_cancel:
|
||||
# recv_burst was cancelled
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
raise asyncio.CancelledError()
|
||||
m = done_task.result()
|
||||
if m is None:
|
||||
# None left in the burst. This is good!
|
||||
break
|
||||
messages.append(m)
|
||||
if len(messages) >= max_recv:
|
||||
# Too much data in the pipe. Hit our burst limit.
|
||||
break
|
||||
# Allow an eventloop iteration for the
|
||||
# next message to pass into the Assembler
|
||||
await asyncio.sleep(0)
|
||||
self.recv_cancel.cancel()
|
||||
finally:
|
||||
self.recv_cancel = None
|
||||
self.can_pause = True
|
||||
self.recv_lock.release()
|
||||
return messages
|
||||
|
||||
async def recv_streaming(self) -> AsyncIterator[Data]:
|
||||
"""
|
||||
Receive the next message frame by frame.
|
||||
Return an iterator of :class:`str` for a text frame and :class:`bytes`
|
||||
for a binary frame. The iterator should be exhausted, or else the
|
||||
connection will become unusable.
|
||||
With the exception of the return value, :meth:`recv_streaming` behaves
|
||||
like :meth:`recv`.
|
||||
"""
|
||||
if self.recv_lock.locked():
|
||||
raise ServerError(
|
||||
"Cannot call recv_streaming while another task "
|
||||
"is already waiting for the next message"
|
||||
)
|
||||
await self.recv_lock.acquire()
|
||||
if self.connection.state is CLOSED:
|
||||
self.recv_lock.release()
|
||||
raise WebsocketClosed(
|
||||
"Cannot receive from websocket interface after it is closed."
|
||||
)
|
||||
try:
|
||||
cancelled = False
|
||||
self.recv_cancel = asyncio.Future()
|
||||
self.can_pause = False
|
||||
async for m in self.assembler.get_iter():
|
||||
if self.recv_cancel.done():
|
||||
cancelled = True
|
||||
break
|
||||
yield m
|
||||
if cancelled:
|
||||
raise asyncio.CancelledError()
|
||||
finally:
|
||||
self.can_pause = True
|
||||
self.recv_cancel = None
|
||||
self.recv_lock.release()
|
||||
|
||||
async def send(self, message: Union[Data, Iterable[Data]]) -> None:
|
||||
"""
|
||||
Send a message.
|
||||
A string (:class:`str`) is sent as a `Text frame`_. A bytestring or
|
||||
bytes-like object (:class:`bytes`, :class:`bytearray`, or
|
||||
:class:`memoryview`) is sent as a `Binary frame`_.
|
||||
.. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6
|
||||
.. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6
|
||||
:meth:`send` also accepts an iterable of strings, bytestrings, or
|
||||
bytes-like objects. In that case the message is fragmented. Each item
|
||||
is treated as a message fragment and sent in its own frame. All items
|
||||
must be of the same type, or else :meth:`send` will raise a
|
||||
:exc:`TypeError` and the connection will be closed.
|
||||
:meth:`send` rejects dict-like objects because this is often an error.
|
||||
If you wish to send the keys of a dict-like object as fragments, call
|
||||
its :meth:`~dict.keys` method and pass the result to :meth:`send`.
|
||||
:raises TypeError: for unsupported inputs
|
||||
"""
|
||||
async with self.conn_mutex:
|
||||
|
||||
if self.connection.state in (CLOSED, CLOSING):
|
||||
raise WebsocketClosed(
|
||||
"Cannot write to websocket interface after it is closed."
|
||||
)
|
||||
if (not self.data_finished_fut) or self.data_finished_fut.done():
|
||||
raise ServerError(
|
||||
"Cannot write to websocket interface after it is finished."
|
||||
)
|
||||
|
||||
# Unfragmented message -- this case must be handled first because
|
||||
# strings and bytes-like objects are iterable.
|
||||
|
||||
if isinstance(message, str):
|
||||
self.connection.send_text(message.encode("utf-8"))
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
|
||||
elif isinstance(message, (bytes, bytearray, memoryview)):
|
||||
self.connection.send_binary(message)
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
|
||||
elif isinstance(message, Mapping):
|
||||
# Catch a common mistake -- passing a dict to send().
|
||||
raise TypeError("data is a dict-like object")
|
||||
|
||||
elif isinstance(message, Iterable):
|
||||
# Fragmented message -- regular iterator.
|
||||
raise NotImplementedError(
|
||||
"Fragmented websocket messages are not supported."
|
||||
)
|
||||
else:
|
||||
raise TypeError("Websocket data must be bytes, str.")
|
||||
|
||||
async def ping(self, data: Optional[Data] = None) -> asyncio.Future:
|
||||
"""
|
||||
Send a ping.
|
||||
Return an :class:`~asyncio.Future` that will be resolved when the
|
||||
corresponding pong is received. You can ignore it if you don't intend
|
||||
to wait.
|
||||
A ping may serve as a keepalive or as a check that the remote endpoint
|
||||
received all messages up to this point::
|
||||
await pong_event = ws.ping()
|
||||
await pong_event # only if you want to wait for the pong
|
||||
By default, the ping contains four random bytes. This payload may be
|
||||
overridden with the optional ``data`` argument which must be a string
|
||||
(which will be encoded to UTF-8) or a bytes-like object.
|
||||
"""
|
||||
async with self.conn_mutex:
|
||||
if self.connection.state in (CLOSED, CLOSING):
|
||||
raise WebsocketClosed(
|
||||
"Cannot send a ping when the websocket interface "
|
||||
"is closed."
|
||||
)
|
||||
if (not self.io_proto) or (not self.io_proto.loop):
|
||||
raise ServerError(
|
||||
"Cannot send a ping when the websocket has no I/O "
|
||||
"protocol attached."
|
||||
)
|
||||
if data is not None:
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
elif isinstance(data, (bytearray, memoryview)):
|
||||
data = bytes(data)
|
||||
|
||||
# Protect against duplicates if a payload is explicitly set.
|
||||
if data in self.pings:
|
||||
raise ValueError(
|
||||
"already waiting for a pong with the same data"
|
||||
)
|
||||
|
||||
# Generate a unique random payload otherwise.
|
||||
while data is None or data in self.pings:
|
||||
data = struct.pack("!I", random.getrandbits(32))
|
||||
|
||||
self.pings[data] = self.io_proto.loop.create_future()
|
||||
|
||||
self.connection.send_ping(data)
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
|
||||
return asyncio.shield(self.pings[data])
|
||||
|
||||
async def pong(self, data: Data = b"") -> None:
|
||||
"""
|
||||
Send a pong.
|
||||
An unsolicited pong may serve as a unidirectional heartbeat.
|
||||
The payload may be set with the optional ``data`` argument which must
|
||||
be a string (which will be encoded to UTF-8) or a bytes-like object.
|
||||
"""
|
||||
async with self.conn_mutex:
|
||||
if self.connection.state in (CLOSED, CLOSING):
|
||||
# Cannot send pong after transport is shutting down
|
||||
return
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
elif isinstance(data, (bytearray, memoryview)):
|
||||
data = bytes(data)
|
||||
self.connection.send_pong(data)
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
|
||||
async def send_data(self, data_to_send):
|
||||
for data in data_to_send:
|
||||
if data:
|
||||
await self.io_proto.send(data)
|
||||
else:
|
||||
# Send an EOF - We don't actually send it,
|
||||
# just trigger to autoclose the connection
|
||||
if (
|
||||
self.auto_closer_task
|
||||
and not self.auto_closer_task.done()
|
||||
and self.data_finished_fut
|
||||
and not self.data_finished_fut.done()
|
||||
):
|
||||
# Auto-close the connection
|
||||
self.data_finished_fut.set_result(None)
|
||||
else:
|
||||
# This will fail the connection appropriately
|
||||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||
|
||||
async def async_data_received(self, data_to_send, events_to_process):
|
||||
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||
# receiving data can generate data to send (eg, pong for a ping)
|
||||
# send connection.data_to_send()
|
||||
await self.send_data(data_to_send)
|
||||
if len(events_to_process) > 0:
|
||||
await self.process_events(events_to_process)
|
||||
|
||||
def data_received(self, data):
|
||||
self.connection.receive_data(data)
|
||||
data_to_send = self.connection.data_to_send()
|
||||
events_to_process = self.connection.events_received()
|
||||
if len(data_to_send) > 0 or len(events_to_process) > 0:
|
||||
asyncio.create_task(
|
||||
self.async_data_received(data_to_send, events_to_process)
|
||||
)
|
||||
|
||||
async def async_eof_received(self, data_to_send, events_to_process):
|
||||
# receiving EOF can generate data to send
|
||||
# send connection.data_to_send()
|
||||
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||
await self.send_data(data_to_send)
|
||||
if len(events_to_process) > 0:
|
||||
await self.process_events(events_to_process)
|
||||
if self.recv_cancel:
|
||||
self.recv_cancel.cancel()
|
||||
if (
|
||||
self.auto_closer_task
|
||||
and not self.auto_closer_task.done()
|
||||
and self.data_finished_fut
|
||||
and not self.data_finished_fut.done()
|
||||
):
|
||||
# Auto-close the connection
|
||||
self.data_finished_fut.set_result(None)
|
||||
# Cancel the running handler if its waiting
|
||||
else:
|
||||
# This will fail the connection appropriately
|
||||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
|
||||
self.connection.receive_eof()
|
||||
data_to_send = self.connection.data_to_send()
|
||||
events_to_process = self.connection.events_received()
|
||||
asyncio.create_task(
|
||||
self.async_eof_received(data_to_send, events_to_process)
|
||||
)
|
||||
return False
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""
|
||||
The WebSocket Connection is Closed.
|
||||
"""
|
||||
if not self.connection.state == CLOSED:
|
||||
# signal to the websocket connection handler
|
||||
# we've lost the connection
|
||||
self.connection.fail(code=1006)
|
||||
self.connection.state = CLOSED
|
||||
|
||||
self.abort_pings()
|
||||
if self.connection_lost_waiter:
|
||||
self.connection_lost_waiter.set_result(None)
|
||||
@@ -10,13 +10,39 @@ from sanic_routing.exceptions import NotFound # type: ignore
|
||||
from sanic_routing.utils import path_to_parts # type: ignore
|
||||
|
||||
from sanic.exceptions import InvalidSignal
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.models.handler_types import SignalHandler
|
||||
|
||||
|
||||
RESERVED_NAMESPACES = (
|
||||
"server",
|
||||
"http",
|
||||
)
|
||||
RESERVED_NAMESPACES = {
|
||||
"server": (
|
||||
# "server.main.start",
|
||||
# "server.main.stop",
|
||||
"server.init.before",
|
||||
"server.init.after",
|
||||
"server.shutdown.before",
|
||||
"server.shutdown.after",
|
||||
),
|
||||
"http": (
|
||||
"http.lifecycle.begin",
|
||||
"http.lifecycle.complete",
|
||||
"http.lifecycle.exception",
|
||||
"http.lifecycle.handle",
|
||||
"http.lifecycle.read_body",
|
||||
"http.lifecycle.read_head",
|
||||
"http.lifecycle.request",
|
||||
"http.lifecycle.response",
|
||||
"http.routing.after",
|
||||
"http.routing.before",
|
||||
"http.lifecycle.send",
|
||||
"http.middleware.after",
|
||||
"http.middleware.before",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _blank():
|
||||
...
|
||||
|
||||
|
||||
class Signal(Route):
|
||||
@@ -59,8 +85,13 @@ class SignalRouter(BaseRouter):
|
||||
terms.append(extra)
|
||||
raise NotFound(message % tuple(terms))
|
||||
|
||||
# Regex routes evaluate and can extract params directly. They are set
|
||||
# on param_basket["__params__"]
|
||||
params = param_basket["__params__"]
|
||||
if not params:
|
||||
# If param_basket["__params__"] does not exist, we might have
|
||||
# param_basket["__matches__"], which are indexed based matches
|
||||
# on path segments. They should already be cast types.
|
||||
params = {
|
||||
param.name: param_basket["__matches__"][idx]
|
||||
for idx, param in group.params.items()
|
||||
@@ -73,8 +104,18 @@ class SignalRouter(BaseRouter):
|
||||
event: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
condition: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
fail_not_found: bool = True,
|
||||
reverse: bool = False,
|
||||
) -> Any:
|
||||
try:
|
||||
group, handlers, params = self.get(event, condition=condition)
|
||||
except NotFound as e:
|
||||
if fail_not_found:
|
||||
raise e
|
||||
else:
|
||||
if self.ctx.app.debug:
|
||||
error_logger.warning(str(e))
|
||||
return None
|
||||
|
||||
events = [signal.ctx.event for signal in group]
|
||||
for signal_event in events:
|
||||
@@ -82,12 +123,19 @@ class SignalRouter(BaseRouter):
|
||||
if context:
|
||||
params.update(context)
|
||||
|
||||
if not reverse:
|
||||
handlers = handlers[::-1]
|
||||
try:
|
||||
for handler in handlers:
|
||||
if condition is None or condition == handler.__requirements__:
|
||||
maybe_coroutine = handler(**params)
|
||||
if isawaitable(maybe_coroutine):
|
||||
await maybe_coroutine
|
||||
retval = await maybe_coroutine
|
||||
if retval:
|
||||
return retval
|
||||
elif maybe_coroutine:
|
||||
return maybe_coroutine
|
||||
return None
|
||||
finally:
|
||||
for signal_event in events:
|
||||
signal_event.clear()
|
||||
@@ -98,14 +146,23 @@ class SignalRouter(BaseRouter):
|
||||
*,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
condition: Optional[Dict[str, str]] = None,
|
||||
) -> asyncio.Task:
|
||||
task = self.ctx.loop.create_task(
|
||||
self._dispatch(
|
||||
fail_not_found: bool = True,
|
||||
inline: bool = False,
|
||||
reverse: bool = False,
|
||||
) -> Union[asyncio.Task, Any]:
|
||||
dispatch = self._dispatch(
|
||||
event,
|
||||
context=context,
|
||||
condition=condition,
|
||||
fail_not_found=fail_not_found and inline,
|
||||
reverse=reverse,
|
||||
)
|
||||
)
|
||||
logger.debug(f"Dispatching signal: {event}")
|
||||
|
||||
if inline:
|
||||
return await dispatch
|
||||
|
||||
task = asyncio.get_running_loop().create_task(dispatch)
|
||||
await asyncio.sleep(0)
|
||||
return task
|
||||
|
||||
@@ -131,7 +188,9 @@ class SignalRouter(BaseRouter):
|
||||
append=True,
|
||||
) # type: ignore
|
||||
|
||||
def finalize(self, do_compile: bool = True):
|
||||
def finalize(self, do_compile: bool = True, do_optimize: bool = False):
|
||||
self.add(_blank, "sanic.__signal__.__init__")
|
||||
|
||||
try:
|
||||
self.ctx.loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
@@ -140,7 +199,7 @@ class SignalRouter(BaseRouter):
|
||||
for signal in self.routes:
|
||||
signal.ctx.event = asyncio.Event()
|
||||
|
||||
return super().finalize(do_compile=do_compile)
|
||||
return super().finalize(do_compile=do_compile, do_optimize=do_optimize)
|
||||
|
||||
def _build_event_parts(self, event: str) -> Tuple[str, str, str]:
|
||||
parts = path_to_parts(event, self.delimiter)
|
||||
@@ -151,7 +210,11 @@ class SignalRouter(BaseRouter):
|
||||
):
|
||||
raise InvalidSignal("Invalid signal event: %s" % event)
|
||||
|
||||
if parts[0] in RESERVED_NAMESPACES:
|
||||
if (
|
||||
parts[0] in RESERVED_NAMESPACES
|
||||
and event not in RESERVED_NAMESPACES[parts[0]]
|
||||
and not (parts[2].startswith("<") and parts[2].endswith(">"))
|
||||
):
|
||||
raise InvalidSignal(
|
||||
"Cannot declare reserved signal event: %s" % event
|
||||
)
|
||||
|
||||
8
sanic/touchup/__init__.py
Normal file
8
sanic/touchup/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .meta import TouchUpMeta
|
||||
from .service import TouchUp
|
||||
|
||||
|
||||
__all__ = (
|
||||
"TouchUp",
|
||||
"TouchUpMeta",
|
||||
)
|
||||
22
sanic/touchup/meta.py
Normal file
22
sanic/touchup/meta.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sanic.exceptions import SanicException
|
||||
|
||||
from .service import TouchUp
|
||||
|
||||
|
||||
class TouchUpMeta(type):
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
gen_class = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
|
||||
methods = attrs.get("__touchup__")
|
||||
attrs["__touched__"] = False
|
||||
if methods:
|
||||
|
||||
for method in methods:
|
||||
if method not in attrs:
|
||||
raise SanicException(
|
||||
"Cannot perform touchup on non-existent method: "
|
||||
f"{name}.{method}"
|
||||
)
|
||||
TouchUp.register(gen_class, method)
|
||||
|
||||
return gen_class
|
||||
5
sanic/touchup/schemes/__init__.py
Normal file
5
sanic/touchup/schemes/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .base import BaseScheme
|
||||
from .ode import OptionalDispatchEvent # noqa
|
||||
|
||||
|
||||
__all__ = ("BaseScheme",)
|
||||
20
sanic/touchup/schemes/base.py
Normal file
20
sanic/touchup/schemes/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Set, Type
|
||||
|
||||
|
||||
class BaseScheme(ABC):
|
||||
ident: str
|
||||
_registry: Set[Type] = set()
|
||||
|
||||
def __init__(self, app) -> None:
|
||||
self.app = app
|
||||
|
||||
@abstractmethod
|
||||
def run(self, method, module_globals) -> None:
|
||||
...
|
||||
|
||||
def __init_subclass__(cls):
|
||||
BaseScheme._registry.add(cls)
|
||||
|
||||
def __call__(self, method, module_globals):
|
||||
return self.run(method, module_globals)
|
||||
67
sanic/touchup/schemes/ode.py
Normal file
67
sanic/touchup/schemes/ode.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse
|
||||
from inspect import getsource
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
from sanic.log import logger
|
||||
|
||||
from .base import BaseScheme
|
||||
|
||||
|
||||
class OptionalDispatchEvent(BaseScheme):
|
||||
ident = "ODE"
|
||||
|
||||
def __init__(self, app) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
self._registered_events = [
|
||||
signal.path for signal in app.signal_router.routes
|
||||
]
|
||||
|
||||
def run(self, method, module_globals):
|
||||
raw_source = getsource(method)
|
||||
src = dedent(raw_source)
|
||||
tree = parse(src)
|
||||
node = RemoveDispatch(self._registered_events).visit(tree)
|
||||
compiled_src = compile(node, method.__name__, "exec")
|
||||
exec_locals: Dict[str, Any] = {}
|
||||
exec(compiled_src, module_globals, exec_locals) # nosec
|
||||
|
||||
return exec_locals[method.__name__]
|
||||
|
||||
|
||||
class RemoveDispatch(NodeTransformer):
|
||||
def __init__(self, registered_events) -> None:
|
||||
self._registered_events = registered_events
|
||||
|
||||
def visit_Expr(self, node: Expr) -> Any:
|
||||
call = node.value
|
||||
if isinstance(call, Await):
|
||||
call = call.value
|
||||
|
||||
func = getattr(call, "func", None)
|
||||
args = getattr(call, "args", None)
|
||||
if not func or not args:
|
||||
return node
|
||||
|
||||
if isinstance(func, Attribute) and func.attr == "dispatch":
|
||||
event = args[0]
|
||||
if hasattr(event, "s"):
|
||||
event_name = getattr(event, "value", event.s)
|
||||
if self._not_registered(event_name):
|
||||
logger.debug(f"Disabling event: {event_name}")
|
||||
return None
|
||||
return node
|
||||
|
||||
def _not_registered(self, event_name):
|
||||
dynamic = []
|
||||
for event in self._registered_events:
|
||||
if event.endswith(">"):
|
||||
namespace_concern, _ = event.rsplit(".", 1)
|
||||
dynamic.append(namespace_concern)
|
||||
|
||||
namespace_concern, _ = event_name.rsplit(".", 1)
|
||||
return (
|
||||
event_name not in self._registered_events
|
||||
and namespace_concern not in dynamic
|
||||
)
|
||||
33
sanic/touchup/service.py
Normal file
33
sanic/touchup/service.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from inspect import getmembers, getmodule
|
||||
from typing import Set, Tuple, Type
|
||||
|
||||
from .schemes import BaseScheme
|
||||
|
||||
|
||||
class TouchUp:
|
||||
_registry: Set[Tuple[Type, str]] = set()
|
||||
|
||||
@classmethod
|
||||
def run(cls, app):
|
||||
for target, method_name in cls._registry:
|
||||
method = getattr(target, method_name)
|
||||
|
||||
if app.test_mode:
|
||||
placeholder = f"_{method_name}"
|
||||
if hasattr(target, placeholder):
|
||||
method = getattr(target, placeholder)
|
||||
else:
|
||||
setattr(target, placeholder, method)
|
||||
|
||||
module = getmodule(target)
|
||||
module_globals = dict(getmembers(module))
|
||||
|
||||
for scheme in BaseScheme._registry:
|
||||
modified = scheme(app)(method, module_globals)
|
||||
setattr(target, method_name, modified)
|
||||
|
||||
target.__touched__ = True
|
||||
|
||||
@classmethod
|
||||
def register(cls, target, method_name):
|
||||
cls._registry.add((target, method_name))
|
||||
@@ -13,6 +13,7 @@ from warnings import warn
|
||||
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.models.handler_types import RouteHandler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -86,7 +87,7 @@ class HTTPMethodView:
|
||||
return handler(request, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def as_view(cls, *class_args, **class_kwargs):
|
||||
def as_view(cls, *class_args: Any, **class_kwargs: Any) -> RouteHandler:
|
||||
"""Return view function for use with the routing system, that
|
||||
dispatches request to appropriate handler method.
|
||||
"""
|
||||
@@ -100,7 +101,7 @@ class HTTPMethodView:
|
||||
for decorator in cls.decorators:
|
||||
view = decorator(view)
|
||||
|
||||
view.view_class = cls
|
||||
view.view_class = cls # type: ignore
|
||||
view.__doc__ = cls.__doc__
|
||||
view.__module__ = cls.__module__
|
||||
view.__name__ = cls.__name__
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from httptools import HttpParserUpgrade # type: ignore
|
||||
from websockets import ( # type: ignore
|
||||
ConnectionClosed,
|
||||
InvalidHandshake,
|
||||
WebSocketCommonProtocol,
|
||||
)
|
||||
|
||||
# Despite the "legacy" namespace, the primary maintainer of websockets
|
||||
# committed to maintaining backwards-compatibility until 2026 and will
|
||||
# consider extending it if sanic continues depending on this module.
|
||||
from websockets.legacy import handshake
|
||||
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.server import HttpProtocol
|
||||
|
||||
|
||||
__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"]
|
||||
|
||||
ASIMessage = MutableMapping[str, Any]
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
websocket_timeout=10,
|
||||
websocket_max_size=None,
|
||||
websocket_max_queue=None,
|
||||
websocket_read_limit=2 ** 16,
|
||||
websocket_write_limit=2 ** 16,
|
||||
websocket_ping_interval=20,
|
||||
websocket_ping_timeout=20,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
# self.app = None
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_max_size = websocket_max_size
|
||||
self.websocket_max_queue = websocket_max_queue
|
||||
self.websocket_read_limit = websocket_read_limit
|
||||
self.websocket_write_limit = websocket_write_limit
|
||||
self.websocket_ping_interval = websocket_ping_interval
|
||||
self.websocket_ping_timeout = websocket_ping_timeout
|
||||
|
||||
# timeouts make no sense for websocket routes
|
||||
def request_timeout_callback(self):
|
||||
if self.websocket is None:
|
||||
super().request_timeout_callback()
|
||||
|
||||
def response_timeout_callback(self):
|
||||
if self.websocket is None:
|
||||
super().response_timeout_callback()
|
||||
|
||||
def keep_alive_timeout_callback(self):
|
||||
if self.websocket is None:
|
||||
super().keep_alive_timeout_callback()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self.websocket is not None:
|
||||
self.websocket.connection_lost(exc)
|
||||
super().connection_lost(exc)
|
||||
|
||||
def data_received(self, data):
|
||||
if self.websocket is not None:
|
||||
# pass the data to the websocket protocol
|
||||
self.websocket.data_received(data)
|
||||
else:
|
||||
try:
|
||||
super().data_received(data)
|
||||
except HttpParserUpgrade:
|
||||
# this is okay, it just indicates we've got an upgrade request
|
||||
pass
|
||||
|
||||
def write_response(self, response):
|
||||
if self.websocket is not None:
|
||||
# websocket requests do not write a response
|
||||
self.transport.close()
|
||||
else:
|
||||
super().write_response(response)
|
||||
|
||||
async def websocket_handshake(self, request, subprotocols=None):
|
||||
# let the websockets package do the handshake with the client
|
||||
headers = {}
|
||||
|
||||
try:
|
||||
key = handshake.check_request(request.headers)
|
||||
handshake.build_response(headers, key)
|
||||
except InvalidHandshake:
|
||||
raise InvalidUsage("Invalid websocket request")
|
||||
|
||||
subprotocol = None
|
||||
if subprotocols and "Sec-Websocket-Protocol" in request.headers:
|
||||
# select a subprotocol
|
||||
client_subprotocols = [
|
||||
p.strip()
|
||||
for p in request.headers["Sec-Websocket-Protocol"].split(",")
|
||||
]
|
||||
for p in client_subprotocols:
|
||||
if p in subprotocols:
|
||||
subprotocol = p
|
||||
headers["Sec-Websocket-Protocol"] = subprotocol
|
||||
break
|
||||
|
||||
# write the 101 response back to the client
|
||||
rv = b"HTTP/1.1 101 Switching Protocols\r\n"
|
||||
for k, v in headers.items():
|
||||
rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n"
|
||||
rv += b"\r\n"
|
||||
request.transport.write(rv)
|
||||
|
||||
# hook up the websocket protocol
|
||||
self.websocket = WebSocketCommonProtocol(
|
||||
close_timeout=self.websocket_timeout,
|
||||
max_size=self.websocket_max_size,
|
||||
max_queue=self.websocket_max_queue,
|
||||
read_limit=self.websocket_read_limit,
|
||||
write_limit=self.websocket_write_limit,
|
||||
ping_interval=self.websocket_ping_interval,
|
||||
ping_timeout=self.websocket_ping_timeout,
|
||||
)
|
||||
# we use WebSocketCommonProtocol because we don't want the handshake
|
||||
# logic from WebSocketServerProtocol; however, we must tell it that
|
||||
# we're running on the server side
|
||||
self.websocket.is_client = False
|
||||
self.websocket.side = "server"
|
||||
self.websocket.subprotocol = subprotocol
|
||||
self.websocket.connection_made(request.transport)
|
||||
self.websocket.connection_open()
|
||||
return self.websocket
|
||||
|
||||
|
||||
class WebSocketConnection:
|
||||
|
||||
# TODO
|
||||
# - Implement ping/pong
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send: Callable[[ASIMessage], Awaitable[None]],
|
||||
receive: Callable[[], Awaitable[ASIMessage]],
|
||||
subprotocols: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self._send = send
|
||||
self._receive = receive
|
||||
self._subprotocols = subprotocols or []
|
||||
|
||||
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
||||
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
|
||||
|
||||
if isinstance(data, bytes):
|
||||
message.update({"bytes": data})
|
||||
else:
|
||||
message.update({"text": str(data)})
|
||||
|
||||
await self._send(message)
|
||||
|
||||
async def recv(self, *args, **kwargs) -> Optional[str]:
|
||||
message = await self._receive()
|
||||
|
||||
if message["type"] == "websocket.receive":
|
||||
return message["text"]
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
receive = recv
|
||||
|
||||
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
|
||||
subprotocol = None
|
||||
if subprotocols:
|
||||
for subp in subprotocols:
|
||||
if subp in self.subprotocols:
|
||||
subprotocol = subp
|
||||
break
|
||||
|
||||
await self._send(
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": subprotocol,
|
||||
}
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def subprotocols(self):
|
||||
return self._subprotocols
|
||||
|
||||
@subprotocols.setter
|
||||
def subprotocols(self, subprotocols: Optional[List[str]] = None):
|
||||
self._subprotocols = subprotocols or []
|
||||
@@ -8,8 +8,8 @@ import traceback
|
||||
from gunicorn.workers import base # type: ignore
|
||||
|
||||
from sanic.log import logger
|
||||
from sanic.server import HttpProtocol, Signal, serve, trigger_events
|
||||
from sanic.websocket import WebSocketProtocol
|
||||
from sanic.server import HttpProtocol, Signal, serve
|
||||
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
|
||||
|
||||
|
||||
try:
|
||||
@@ -68,10 +68,10 @@ class GunicornWorker(base.Worker):
|
||||
)
|
||||
self._server_settings["signal"] = self.signal
|
||||
self._server_settings.pop("sock")
|
||||
trigger_events(
|
||||
self._server_settings.get("before_start", []), self.loop
|
||||
self._await(self.app.callable._startup())
|
||||
self._await(
|
||||
self.app.callable._server_event("init", "before", loop=self.loop)
|
||||
)
|
||||
self._server_settings["before_start"] = ()
|
||||
|
||||
main_start = self._server_settings.pop("main_start", None)
|
||||
main_stop = self._server_settings.pop("main_stop", None)
|
||||
@@ -82,24 +82,29 @@ class GunicornWorker(base.Worker):
|
||||
"with GunicornWorker"
|
||||
)
|
||||
|
||||
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
|
||||
try:
|
||||
self.loop.run_until_complete(self._runner)
|
||||
self._await(self._run())
|
||||
self.app.callable.is_running = True
|
||||
trigger_events(
|
||||
self._server_settings.get("after_start", []), self.loop
|
||||
self._await(
|
||||
self.app.callable._server_event(
|
||||
"init", "after", loop=self.loop
|
||||
)
|
||||
)
|
||||
self.loop.run_until_complete(self._check_alive())
|
||||
trigger_events(
|
||||
self._server_settings.get("before_stop", []), self.loop
|
||||
self._await(
|
||||
self.app.callable._server_event(
|
||||
"shutdown", "before", loop=self.loop
|
||||
)
|
||||
)
|
||||
self.loop.run_until_complete(self.close())
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
try:
|
||||
trigger_events(
|
||||
self._server_settings.get("after_stop", []), self.loop
|
||||
self._await(
|
||||
self.app.callable._server_event(
|
||||
"shutdown", "after", loop=self.loop
|
||||
)
|
||||
)
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
@@ -137,14 +142,11 @@ class GunicornWorker(base.Worker):
|
||||
|
||||
# Force close non-idle connection after waiting for
|
||||
# graceful_shutdown_timeout
|
||||
coros = []
|
||||
for conn in self.connections:
|
||||
if hasattr(conn, "websocket") and conn.websocket:
|
||||
coros.append(conn.websocket.close_connection())
|
||||
conn.websocket.fail_connection(code=1001)
|
||||
else:
|
||||
conn.close()
|
||||
_shutdown = asyncio.gather(*coros, loop=self.loop)
|
||||
await _shutdown
|
||||
conn.abort()
|
||||
|
||||
async def _run(self):
|
||||
for sock in self.sockets:
|
||||
@@ -238,3 +240,7 @@ class GunicornWorker(base.Worker):
|
||||
self.exit_code = 1
|
||||
self.cfg.worker_abort(self)
|
||||
sys.exit(1)
|
||||
|
||||
def _await(self, coro):
|
||||
fut = asyncio.ensure_future(coro, loop=self.loop)
|
||||
self.loop.run_until_complete(fut)
|
||||
|
||||
35
setup.py
35
setup.py
@@ -81,60 +81,63 @@ env_dependency = (
|
||||
)
|
||||
ujson = "ujson>=1.35" + env_dependency
|
||||
uvloop = "uvloop>=0.5.3" + env_dependency
|
||||
|
||||
types_ujson = "types-ujson" + env_dependency
|
||||
requirements = [
|
||||
"sanic-routing==0.7.0",
|
||||
"sanic-routing~=0.7",
|
||||
"httptools>=0.0.10",
|
||||
uvloop,
|
||||
ujson,
|
||||
"aiofiles>=0.6.0",
|
||||
"websockets>=9.0",
|
||||
"websockets>=10.0",
|
||||
"multidict>=5.0,<6.0",
|
||||
]
|
||||
|
||||
tests_require = [
|
||||
"sanic-testing>=0.6.0",
|
||||
"sanic-testing>=0.7.0",
|
||||
"pytest==5.2.1",
|
||||
"multidict>=5.0,<6.0",
|
||||
"coverage==5.3",
|
||||
"gunicorn==20.0.4",
|
||||
"pytest-cov",
|
||||
"beautifulsoup4",
|
||||
uvloop,
|
||||
ujson,
|
||||
"pytest-sanic",
|
||||
"pytest-sugar",
|
||||
"pytest-benchmark",
|
||||
"chardet==3.*",
|
||||
"flake8",
|
||||
"black",
|
||||
"isort>=5.0.0",
|
||||
"bandit",
|
||||
"mypy>=0.901",
|
||||
"docutils",
|
||||
"pygments",
|
||||
"uvicorn<0.15.0",
|
||||
types_ujson,
|
||||
]
|
||||
|
||||
docs_require = [
|
||||
"sphinx>=2.1.2",
|
||||
"sphinx_rtd_theme",
|
||||
"recommonmark>=0.5.0",
|
||||
"sphinx_rtd_theme>=0.4.3",
|
||||
"docutils",
|
||||
"pygments",
|
||||
"m2r2",
|
||||
]
|
||||
|
||||
dev_require = tests_require + [
|
||||
"aiofiles",
|
||||
"tox",
|
||||
"black",
|
||||
"flake8",
|
||||
"bandit",
|
||||
"towncrier",
|
||||
]
|
||||
|
||||
all_require = dev_require + docs_require
|
||||
all_require = list(set(dev_require + docs_require))
|
||||
|
||||
if strtobool(os.environ.get("SANIC_NO_UJSON", "no")):
|
||||
print("Installing without uJSON")
|
||||
requirements.remove(ujson)
|
||||
tests_require.remove(ujson)
|
||||
tests_require.remove(types_ujson)
|
||||
|
||||
# 'nt' means windows OS
|
||||
if strtobool(os.environ.get("SANIC_NO_UVLOOP", "no")):
|
||||
print("Installing without uvLoop")
|
||||
requirements.remove(uvloop)
|
||||
tests_require.remove(uvloop)
|
||||
|
||||
extras_require = {
|
||||
"test": tests_require,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
@@ -9,10 +11,12 @@ from typing import Tuple
|
||||
import pytest
|
||||
|
||||
from sanic_routing.exceptions import RouteExists
|
||||
from sanic_testing.testing import PORT
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.router import Router
|
||||
from sanic.touchup.service import TouchUp
|
||||
|
||||
|
||||
slugify = re.compile(r"[^a-zA-Z0-9_\-]")
|
||||
@@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]:
|
||||
collect_ignore = ["test_worker.py"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caplog(caplog):
|
||||
yield caplog
|
||||
|
||||
|
||||
async def _handler(request):
|
||||
"""
|
||||
Dummy placeholder method used for route resolver when creating a new
|
||||
@@ -41,33 +40,32 @@ async def _handler(request):
|
||||
|
||||
|
||||
TYPE_TO_GENERATOR_MAP = {
|
||||
"string": lambda: "".join(
|
||||
"str": lambda: "".join(
|
||||
[random.choice(string.ascii_lowercase) for _ in range(4)]
|
||||
),
|
||||
"int": lambda: random.choice(range(1000000)),
|
||||
"number": lambda: random.random(),
|
||||
"float": lambda: random.random(),
|
||||
"alpha": lambda: "".join(
|
||||
[random.choice(string.ascii_lowercase) for _ in range(4)]
|
||||
),
|
||||
"uuid": lambda: str(uuid.uuid1()),
|
||||
}
|
||||
|
||||
CACHE = {}
|
||||
|
||||
|
||||
class RouteStringGenerator:
|
||||
|
||||
ROUTE_COUNT_PER_DEPTH = 100
|
||||
HTTP_METHODS = HTTP_METHODS
|
||||
ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"]
|
||||
ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"]
|
||||
|
||||
def generate_random_direct_route(self, max_route_depth=4):
|
||||
routes = []
|
||||
for depth in range(1, max_route_depth + 1):
|
||||
for _ in range(self.ROUTE_COUNT_PER_DEPTH):
|
||||
route = "/".join(
|
||||
[
|
||||
TYPE_TO_GENERATOR_MAP.get("string")()
|
||||
for _ in range(depth)
|
||||
]
|
||||
[TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)]
|
||||
)
|
||||
route = route.replace(".", "", -1)
|
||||
route_detail = (random.choice(self.HTTP_METHODS), route)
|
||||
@@ -83,7 +81,7 @@ class RouteStringGenerator:
|
||||
new_route_part = "/".join(
|
||||
[
|
||||
"<{}:{}>".format(
|
||||
TYPE_TO_GENERATOR_MAP.get("string")(),
|
||||
TYPE_TO_GENERATOR_MAP.get("str")(),
|
||||
random.choice(self.ROUTE_PARAM_TYPES),
|
||||
)
|
||||
for _ in range(max_route_depth - current_length)
|
||||
@@ -98,7 +96,7 @@ class RouteStringGenerator:
|
||||
def generate_url_for_template(template):
|
||||
url = template
|
||||
for pattern, param_type in re.findall(
|
||||
re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"),
|
||||
re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"),
|
||||
template,
|
||||
):
|
||||
value = TYPE_TO_GENERATOR_MAP.get(param_type)()
|
||||
@@ -111,6 +109,7 @@ def sanic_router(app):
|
||||
# noinspection PyProtectedMember
|
||||
def _setup(route_details: tuple) -> Tuple[Router, tuple]:
|
||||
router = Router()
|
||||
router.ctx.app = app
|
||||
added_router = []
|
||||
for method, route in route_details:
|
||||
try:
|
||||
@@ -141,5 +140,33 @@ def url_param_generator():
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def app(request):
|
||||
if not CACHE:
|
||||
for target, method_name in TouchUp._registry:
|
||||
CACHE[method_name] = getattr(target, method_name)
|
||||
app = Sanic(slugify.sub("-", request.node.name))
|
||||
return app
|
||||
yield app
|
||||
for target, method_name in TouchUp._registry:
|
||||
setattr(target, method_name, CACHE[method_name])
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def run_startup(caplog):
|
||||
def run(app):
|
||||
nonlocal caplog
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server = app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
loop._stopping = False
|
||||
|
||||
_server = loop.run_until_complete(server)
|
||||
|
||||
_server.close()
|
||||
loop.run_until_complete(_server.wait_closed())
|
||||
app.stop()
|
||||
|
||||
return caplog.record_tuples
|
||||
|
||||
return run
|
||||
|
||||
@@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable):
|
||||
@patch("sanic.app.WebSocketProtocol")
|
||||
def test_app_websocket_parameters(websocket_protocol_mock, app):
|
||||
app.config.WEBSOCKET_MAX_SIZE = 44
|
||||
app.config.WEBSOCKET_MAX_QUEUE = 45
|
||||
app.config.WEBSOCKET_READ_LIMIT = 46
|
||||
app.config.WEBSOCKET_WRITE_LIMIT = 47
|
||||
app.config.WEBSOCKET_PING_TIMEOUT = 48
|
||||
app.config.WEBSOCKET_PING_INTERVAL = 50
|
||||
|
||||
@@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
|
||||
websocket_protocol_call_args = websocket_protocol_mock.call_args
|
||||
ws_kwargs = websocket_protocol_call_args[1]
|
||||
assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE
|
||||
assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE
|
||||
assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT
|
||||
assert (
|
||||
ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT
|
||||
)
|
||||
assert (
|
||||
ws_kwargs["websocket_ping_timeout"]
|
||||
== app.config.WEBSOCKET_PING_TIMEOUT
|
||||
@@ -396,7 +388,7 @@ def test_app_set_attribute_warning(app):
|
||||
assert len(record) == 1
|
||||
assert record[0].message.args[0] == (
|
||||
"Setting variables on Sanic instances is deprecated "
|
||||
"and will be removed in version 21.9. You should change your "
|
||||
"and will be removed in version 21.12. You should change your "
|
||||
"Sanic instance to use instance.ctx.foo instead."
|
||||
)
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import uvicorn
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.asgi import MockTransport
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable
|
||||
from sanic.request import Request
|
||||
from sanic.response import json, text
|
||||
from sanic.websocket import WebSocketConnection
|
||||
from sanic.server.websockets.connection import WebSocketConnection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -346,3 +346,33 @@ async def test_content_type(app):
|
||||
|
||||
_, response = await app.asgi_client.get("/custom")
|
||||
assert response.headers.get("content-type") == "somethingelse"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_handle_exception(app):
|
||||
@app.get("/error-prone")
|
||||
def _request(request):
|
||||
raise ServiceUnavailable(message="Service unavailable")
|
||||
|
||||
_, response = await app.asgi_client.get("/wrong-path")
|
||||
assert response.status_code == 404
|
||||
|
||||
_, response = await app.asgi_client.get("/error-prone")
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_exception_suppressed_by_middleware(app):
|
||||
@app.get("/error-prone")
|
||||
def _request(request):
|
||||
raise ServiceUnavailable(message="Service unavailable")
|
||||
|
||||
@app.on_request
|
||||
def forbidden(request):
|
||||
raise Forbidden(message="forbidden")
|
||||
|
||||
_, response = await app.asgi_client.get("/wrong-path")
|
||||
assert response.status_code == 403
|
||||
|
||||
_, response = await app.asgi_client.get("/error-prone")
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -20,4 +20,4 @@ def test_bad_request_response(app):
|
||||
|
||||
app.run(host="127.0.0.1", port=42101, debug=False)
|
||||
assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n"
|
||||
assert b"Bad Request" in lines[-1]
|
||||
assert b"Bad Request" in lines[-2]
|
||||
|
||||
70
tests/test_blueprint_copy.py
Normal file
70
tests/test_blueprint_copy.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from sanic import Blueprint, Sanic, blueprints, response
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
def test_bp_copy(app: Sanic):
|
||||
bp1 = Blueprint("test_bp1", version=1)
|
||||
bp1.ctx.test = 1
|
||||
assert hasattr(bp1.ctx, "test")
|
||||
|
||||
@bp1.route("/page")
|
||||
def handle_request(request):
|
||||
return text("Hello world!")
|
||||
|
||||
bp2 = bp1.copy(name="test_bp2", version=2)
|
||||
assert id(bp1) != id(bp2)
|
||||
assert bp1._apps == bp2._apps == set()
|
||||
assert not hasattr(bp2.ctx, "test")
|
||||
assert len(bp2._future_exceptions) == len(bp1._future_exceptions)
|
||||
assert len(bp2._future_listeners) == len(bp1._future_listeners)
|
||||
assert len(bp2._future_middleware) == len(bp1._future_middleware)
|
||||
assert len(bp2._future_routes) == len(bp1._future_routes)
|
||||
assert len(bp2._future_signals) == len(bp1._future_signals)
|
||||
|
||||
app.blueprint(bp1)
|
||||
app.blueprint(bp2)
|
||||
|
||||
bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True)
|
||||
assert id(bp1) != id(bp3)
|
||||
assert bp1._apps == bp3._apps and bp3._apps
|
||||
assert not hasattr(bp3.ctx, "test")
|
||||
|
||||
bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True)
|
||||
assert id(bp1) != id(bp4)
|
||||
assert bp4.ctx.test == 1
|
||||
|
||||
bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False)
|
||||
assert id(bp1) != id(bp5)
|
||||
assert not bp5._apps
|
||||
assert bp1._apps != set()
|
||||
|
||||
app.blueprint(bp5)
|
||||
|
||||
bp6 = bp1.copy(
|
||||
name="test_bp6",
|
||||
version=6,
|
||||
with_registration=True,
|
||||
version_prefix="/version",
|
||||
)
|
||||
assert bp6._apps
|
||||
assert bp6.version_prefix == "/version"
|
||||
|
||||
_, response = app.test_client.get("/v1/page")
|
||||
assert "Hello world!" in response.text
|
||||
|
||||
_, response = app.test_client.get("/v2/page")
|
||||
assert "Hello world!" in response.text
|
||||
|
||||
_, response = app.test_client.get("/v3/page")
|
||||
assert "Hello world!" in response.text
|
||||
|
||||
_, response = app.test_client.get("/v4/page")
|
||||
assert "Hello world!" in response.text
|
||||
|
||||
_, response = app.test_client.get("/v5/page")
|
||||
assert "Hello world!" in response.text
|
||||
|
||||
_, response = app.test_client.get("/version6/page")
|
||||
assert "Hello world!" in response.text
|
||||
@@ -3,6 +3,12 @@ from pytest import raises
|
||||
from sanic.app import Sanic
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.exceptions import (
|
||||
Forbidden,
|
||||
InvalidUsage,
|
||||
SanicException,
|
||||
ServerError,
|
||||
)
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, text
|
||||
|
||||
@@ -96,16 +102,28 @@ def test_bp_group(app: Sanic):
|
||||
def blueprint_1_default_route(request):
|
||||
return text("BP1_OK")
|
||||
|
||||
@blueprint_1.route("/invalid")
|
||||
def blueprint_1_error(request: Request):
|
||||
raise InvalidUsage("Invalid")
|
||||
|
||||
@blueprint_2.route("/")
|
||||
def blueprint_2_default_route(request):
|
||||
return text("BP2_OK")
|
||||
|
||||
@blueprint_2.route("/error")
|
||||
def blueprint_2_error(request: Request):
|
||||
raise ServerError("Error")
|
||||
|
||||
blueprint_group_1 = Blueprint.group(
|
||||
blueprint_1, blueprint_2, url_prefix="/bp"
|
||||
)
|
||||
|
||||
blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3")
|
||||
|
||||
@blueprint_group_1.exception(InvalidUsage)
|
||||
def handle_group_exception(request, exception):
|
||||
return text("BP1_ERR_OK")
|
||||
|
||||
@blueprint_group_1.middleware("request")
|
||||
def blueprint_group_1_middleware(request):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
@@ -116,19 +134,47 @@ def test_bp_group(app: Sanic):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
MIDDLEWARE_INVOKE_COUNTER["request"] += 1
|
||||
|
||||
@blueprint_group_1.on_request
|
||||
def blueprint_group_1_convenience_1(request):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
MIDDLEWARE_INVOKE_COUNTER["request"] += 1
|
||||
|
||||
@blueprint_group_1.on_request()
|
||||
def blueprint_group_1_convenience_2(request):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
MIDDLEWARE_INVOKE_COUNTER["request"] += 1
|
||||
|
||||
@blueprint_3.route("/")
|
||||
def blueprint_3_default_route(request):
|
||||
return text("BP3_OK")
|
||||
|
||||
@blueprint_3.route("/forbidden")
|
||||
def blueprint_3_forbidden(request: Request):
|
||||
raise Forbidden("Forbidden")
|
||||
|
||||
blueprint_group_2 = Blueprint.group(
|
||||
blueprint_group_1, blueprint_3, url_prefix="/api"
|
||||
)
|
||||
|
||||
@blueprint_group_2.exception(SanicException)
|
||||
def handle_non_handled_exception(request, exception):
|
||||
return text("BP2_ERR_OK")
|
||||
|
||||
@blueprint_group_2.middleware("response")
|
||||
def blueprint_group_2_middleware(request, response):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
MIDDLEWARE_INVOKE_COUNTER["response"] += 1
|
||||
|
||||
@blueprint_group_2.on_response
|
||||
def blueprint_group_2_middleware_convenience_1(request, response):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
MIDDLEWARE_INVOKE_COUNTER["response"] += 1
|
||||
|
||||
@blueprint_group_2.on_response()
|
||||
def blueprint_group_2_middleware_convenience_2(request, response):
|
||||
global MIDDLEWARE_INVOKE_COUNTER
|
||||
MIDDLEWARE_INVOKE_COUNTER["response"] += 1
|
||||
|
||||
app.blueprint(blueprint_group_2)
|
||||
|
||||
@app.route("/")
|
||||
@@ -141,14 +187,23 @@ def test_bp_group(app: Sanic):
|
||||
_, response = app.test_client.get("/api/bp/bp1")
|
||||
assert response.text == "BP1_OK"
|
||||
|
||||
_, response = app.test_client.get("/api/bp/bp1/invalid")
|
||||
assert response.text == "BP1_ERR_OK"
|
||||
|
||||
_, response = app.test_client.get("/api/bp/bp2")
|
||||
assert response.text == "BP2_OK"
|
||||
|
||||
_, response = app.test_client.get("/api/bp/bp2/error")
|
||||
assert response.text == "BP2_ERR_OK"
|
||||
|
||||
_, response = app.test_client.get("/api/bp3")
|
||||
assert response.text == "BP3_OK"
|
||||
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
|
||||
_, response = app.test_client.get("/api/bp3/forbidden")
|
||||
assert response.text == "BP2_ERR_OK"
|
||||
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 18
|
||||
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 16
|
||||
|
||||
|
||||
def test_bp_group_list_operations(app: Sanic):
|
||||
|
||||
@@ -83,7 +83,6 @@ def test_versioned_routes_get(app, method):
|
||||
return text("OK")
|
||||
|
||||
else:
|
||||
print(func)
|
||||
raise Exception(f"{func} is not callable")
|
||||
|
||||
app.blueprint(bp)
|
||||
@@ -477,6 +476,58 @@ def test_bp_exception_handler(app):
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_bp_exception_handler_applied(app):
|
||||
class Error(Exception):
|
||||
pass
|
||||
|
||||
handled = Blueprint("handled")
|
||||
nothandled = Blueprint("nothandled")
|
||||
|
||||
@handled.exception(Error)
|
||||
def handle_error(req, e):
|
||||
return text("handled {}".format(e))
|
||||
|
||||
@handled.route("/ok")
|
||||
def ok(request):
|
||||
raise Error("uh oh")
|
||||
|
||||
@nothandled.route("/notok")
|
||||
def notok(request):
|
||||
raise Error("uh oh")
|
||||
|
||||
app.blueprint(handled)
|
||||
app.blueprint(nothandled)
|
||||
|
||||
_, response = app.test_client.get("/ok")
|
||||
assert response.status == 200
|
||||
assert response.text == "handled uh oh"
|
||||
|
||||
_, response = app.test_client.get("/notok")
|
||||
assert response.status == 500
|
||||
|
||||
|
||||
def test_bp_exception_handler_not_applied(app):
|
||||
class Error(Exception):
|
||||
pass
|
||||
|
||||
handled = Blueprint("handled")
|
||||
nothandled = Blueprint("nothandled")
|
||||
|
||||
@handled.exception(Error)
|
||||
def handle_error(req, e):
|
||||
return text("handled {}".format(e))
|
||||
|
||||
@nothandled.route("/notok")
|
||||
def notok(request):
|
||||
raise Error("uh oh")
|
||||
|
||||
app.blueprint(handled)
|
||||
app.blueprint(nothandled)
|
||||
|
||||
_, response = app.test_client.get("/notok")
|
||||
assert response.status == 500
|
||||
|
||||
|
||||
def test_bp_listeners(app):
|
||||
app.route("/")(lambda x: x)
|
||||
blueprint = Blueprint("test_middleware")
|
||||
@@ -1034,6 +1085,6 @@ def test_bp_set_attribute_warning():
|
||||
assert len(record) == 1
|
||||
assert record[0].message.args[0] == (
|
||||
"Setting variables on Blueprint instances is deprecated "
|
||||
"and will be removed in version 21.9. You should change your "
|
||||
"and will be removed in version 21.12. You should change your "
|
||||
"Blueprint instance to use instance.ctx.foo instead."
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ def test_debug(cmd):
|
||||
out, err, exitcode = capture(command)
|
||||
lines = out.split(b"\n")
|
||||
|
||||
app_info = lines[9]
|
||||
app_info = lines[26]
|
||||
info = json.loads(app_info)
|
||||
|
||||
assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO
|
||||
@@ -103,7 +103,7 @@ def test_auto_reload(cmd):
|
||||
out, err, exitcode = capture(command)
|
||||
lines = out.split(b"\n")
|
||||
|
||||
app_info = lines[9]
|
||||
app_info = lines[26]
|
||||
info = json.loads(app_info)
|
||||
|
||||
assert info["debug"] is False
|
||||
@@ -118,7 +118,7 @@ def test_access_logs(cmd, expected):
|
||||
out, err, exitcode = capture(command)
|
||||
lines = out.split(b"\n")
|
||||
|
||||
app_info = lines[9]
|
||||
app_info = lines[26]
|
||||
info = json.loads(app_info)
|
||||
|
||||
assert info["access_log"] is expected
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from crypt import methods
|
||||
|
||||
from sanic import text
|
||||
from sanic import Sanic, text
|
||||
from sanic.constants import HTTP_METHODS, HTTPMethod
|
||||
|
||||
|
||||
@@ -14,7 +12,7 @@ def test_string_compat():
|
||||
assert HTTPMethod.GET.upper() == "GET"
|
||||
|
||||
|
||||
def test_use_in_routes(app):
|
||||
def test_use_in_routes(app: Sanic):
|
||||
@app.route("/", methods=[HTTPMethod.GET, HTTPMethod.POST])
|
||||
def handler(_):
|
||||
return text("It works")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from queue import Queue
|
||||
from threading import Event
|
||||
|
||||
from sanic.response import text
|
||||
@@ -13,8 +12,6 @@ def test_create_task(app):
|
||||
await asyncio.sleep(0.05)
|
||||
e.set()
|
||||
|
||||
app.add_task(coro)
|
||||
|
||||
@app.route("/early")
|
||||
def not_set(request):
|
||||
return text(str(e.is_set()))
|
||||
@@ -24,24 +21,30 @@ def test_create_task(app):
|
||||
await asyncio.sleep(0.1)
|
||||
return text(str(e.is_set()))
|
||||
|
||||
app.add_task(coro)
|
||||
|
||||
request, response = app.test_client.get("/early")
|
||||
assert response.body == b"False"
|
||||
|
||||
app.signal_router.reset()
|
||||
app.add_task(coro)
|
||||
request, response = app.test_client.get("/late")
|
||||
assert response.body == b"True"
|
||||
|
||||
|
||||
def test_create_task_with_app_arg(app):
|
||||
q = Queue()
|
||||
@app.after_server_start
|
||||
async def setup_q(app, _):
|
||||
app.ctx.q = asyncio.Queue()
|
||||
|
||||
@app.route("/")
|
||||
def not_set(request):
|
||||
return "hello"
|
||||
async def not_set(request):
|
||||
return text(await request.app.ctx.q.get())
|
||||
|
||||
async def coro(app):
|
||||
q.put(app.name)
|
||||
await app.ctx.q.put(app.name)
|
||||
|
||||
app.add_task(coro)
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
assert q.get() == "test_create_task_with_app_arg"
|
||||
_, response = app.test_client.get("/")
|
||||
assert response.text == "test_create_task_with_app_arg"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.errorpages import exception_response
|
||||
from sanic.exceptions import NotFound
|
||||
from sanic.errorpages import HTMLRenderer, exception_response
|
||||
from sanic.exceptions import NotFound, SanicException
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse
|
||||
from sanic.response import HTTPResponse, html, json, text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -20,7 +20,7 @@ def app():
|
||||
|
||||
@pytest.fixture
|
||||
def fake_request(app):
|
||||
return Request(b"/foobar", {}, "1.1", "GET", None, app)
|
||||
return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -47,7 +47,13 @@ def test_should_return_html_valid_setting(
|
||||
try:
|
||||
raise exception("bad stuff")
|
||||
except Exception as e:
|
||||
response = exception_response(fake_request, e, True)
|
||||
response = exception_response(
|
||||
fake_request,
|
||||
e,
|
||||
True,
|
||||
base=HTMLRenderer,
|
||||
fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT,
|
||||
)
|
||||
|
||||
assert isinstance(response, HTTPResponse)
|
||||
assert response.status == status
|
||||
@@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app):
|
||||
app.config.FALLBACK_ERROR_FORMAT = "auto"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error", headers={"content-type": "application/json"}
|
||||
"/error", headers={"content-type": "application/json", "accept": "*/*"}
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error", headers={"content-type": "text/plain"}
|
||||
"/error", headers={"content-type": "foo/bar", "accept": "*/*"}
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
|
||||
def test_route_error_format_set_on_auto(app):
|
||||
@app.get("/text")
|
||||
def text_response(request):
|
||||
return text(request.route.ctx.error_format)
|
||||
|
||||
@app.get("/json")
|
||||
def json_response(request):
|
||||
return json({"format": request.route.ctx.error_format})
|
||||
|
||||
@app.get("/html")
|
||||
def html_response(request):
|
||||
return html(request.route.ctx.error_format)
|
||||
|
||||
_, response = app.test_client.get("/text")
|
||||
assert response.text == "text"
|
||||
|
||||
_, response = app.test_client.get("/json")
|
||||
assert response.json["format"] == "json"
|
||||
|
||||
_, response = app.test_client.get("/html")
|
||||
assert response.text == "html"
|
||||
|
||||
|
||||
def test_route_error_response_from_auto_route(app):
|
||||
@app.get("/text")
|
||||
def text_response(request):
|
||||
raise Exception("oops")
|
||||
return text("Never gonna see this")
|
||||
|
||||
@app.get("/json")
|
||||
def json_response(request):
|
||||
raise Exception("oops")
|
||||
return json({"message": "Never gonna see this"})
|
||||
|
||||
@app.get("/html")
|
||||
def html_response(request):
|
||||
raise Exception("oops")
|
||||
return html("<h1>Never gonna see this</h1>")
|
||||
|
||||
_, response = app.test_client.get("/text")
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
_, response = app.test_client.get("/json")
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get("/html")
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
|
||||
def test_route_error_response_from_explicit_format(app):
|
||||
@app.get("/text", error_format="json")
|
||||
def text_response(request):
|
||||
raise Exception("oops")
|
||||
return text("Never gonna see this")
|
||||
|
||||
@app.get("/json", error_format="text")
|
||||
def json_response(request):
|
||||
raise Exception("oops")
|
||||
return json({"message": "Never gonna see this"})
|
||||
|
||||
_, response = app.test_client.get("/text")
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get("/json")
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
|
||||
def test_unknown_fallback_format(app):
|
||||
with pytest.raises(SanicException, match="Unknown format: bad"):
|
||||
app.config.FALLBACK_ERROR_FORMAT = "bad"
|
||||
|
||||
|
||||
def test_route_error_format_unknown(app):
|
||||
with pytest.raises(SanicException, match="Unknown format: bad"):
|
||||
|
||||
@app.get("/text", error_format="bad")
|
||||
def handler(request):
|
||||
...
|
||||
|
||||
|
||||
def test_fallback_with_content_type_mismatch_accept(app):
|
||||
app.config.FALLBACK_ERROR_FORMAT = "auto"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error",
|
||||
headers={"content-type": "application/json", "accept": "text/plain"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error",
|
||||
headers={"content-type": "text/plain", "accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
app.router.reset()
|
||||
|
||||
@app.route("/alt1")
|
||||
@app.route("/alt2", error_format="text")
|
||||
@app.route("/alt3", error_format="html")
|
||||
def handler(_):
|
||||
raise Exception("problem here")
|
||||
# Yes, we know this return value is unreachable. This is on purpose.
|
||||
return json({})
|
||||
|
||||
app.router.finalize()
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/alt1",
|
||||
headers={"accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
_, response = app.test_client.get(
|
||||
"/alt1",
|
||||
headers={"accept": "foo/bar,*/*"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/alt2",
|
||||
headers={"accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
_, response = app.test_client.get(
|
||||
"/alt2",
|
||||
headers={"accept": "foo/bar,*/*"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/alt3",
|
||||
headers={"accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"accept,content_type,expected",
|
||||
(
|
||||
(None, None, "text/plain; charset=utf-8"),
|
||||
("foo/bar", None, "text/html; charset=utf-8"),
|
||||
("application/json", None, "application/json"),
|
||||
("application/json,text/plain", None, "application/json"),
|
||||
("text/plain,application/json", None, "application/json"),
|
||||
("text/plain,foo/bar", None, "text/plain; charset=utf-8"),
|
||||
# Following test is valid after v22.3
|
||||
# ("text/plain,text/html", None, "text/plain; charset=utf-8"),
|
||||
("*/*", "foo/bar", "text/html; charset=utf-8"),
|
||||
("*/*", "application/json", "application/json"),
|
||||
),
|
||||
)
|
||||
def test_combinations_for_auto(fake_request, accept, content_type, expected):
|
||||
if accept:
|
||||
fake_request.headers["accept"] = accept
|
||||
else:
|
||||
del fake_request.headers["accept"]
|
||||
|
||||
if content_type:
|
||||
fake_request.headers["content-type"] = content_type
|
||||
|
||||
try:
|
||||
raise Exception("bad stuff")
|
||||
except Exception as e:
|
||||
response = exception_response(
|
||||
fake_request,
|
||||
e,
|
||||
True,
|
||||
base=HTMLRenderer,
|
||||
fallback="auto",
|
||||
)
|
||||
|
||||
assert response.content_type == expected
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.version import version as websockets_version
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.exceptions import (
|
||||
@@ -232,3 +234,41 @@ def test_sanic_exception(exception_app):
|
||||
request, response = exception_app.test_client.get("/old_abort")
|
||||
assert response.status == 500
|
||||
assert len(w) == 1 and "deprecated" in w[0].message.args[0]
|
||||
|
||||
|
||||
def test_custom_exception_default_message(exception_app):
|
||||
class TeaError(SanicException):
|
||||
message = "Tempest in a teapot"
|
||||
status_code = 418
|
||||
|
||||
exception_app.router.reset()
|
||||
|
||||
@exception_app.get("/tempest")
|
||||
def tempest(_):
|
||||
raise TeaError
|
||||
|
||||
_, response = exception_app.test_client.get("/tempest", debug=True)
|
||||
assert response.status == 418
|
||||
assert b"Tempest in a teapot" in response.body
|
||||
|
||||
|
||||
def test_exception_in_ws_logged(caplog):
|
||||
app = Sanic(__file__)
|
||||
|
||||
@app.websocket("/feed")
|
||||
async def feed(request, ws):
|
||||
raise Exception("...")
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.test_client.websocket("/feed")
|
||||
# Websockets v10.0 and above output an additional
|
||||
# INFO message when a ws connection is accepted
|
||||
ws_version_parts = websockets_version.split(".")
|
||||
ws_major = int(ws_version_parts[0])
|
||||
record_index = 2 if ws_major >= 10 else 1
|
||||
assert caplog.record_tuples[record_index][0] == "sanic.error"
|
||||
assert caplog.record_tuples[record_index][1] == logging.ERROR
|
||||
assert (
|
||||
"Exception occurred while handling uri:"
|
||||
in caplog.record_tuples[record_index][2]
|
||||
)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
@@ -8,9 +11,6 @@ from sanic.handlers import ErrorHandler
|
||||
from sanic.response import stream, text
|
||||
|
||||
|
||||
exception_handler_app = Sanic("test_exception_handler")
|
||||
|
||||
|
||||
async def sample_streaming_fn(response):
|
||||
await response.write("foo,")
|
||||
await asyncio.sleep(0.001)
|
||||
@@ -21,36 +21,35 @@ class ErrorWithRequestCtx(ServerError):
|
||||
pass
|
||||
|
||||
|
||||
@exception_handler_app.route("/1")
|
||||
@pytest.fixture
|
||||
def exception_handler_app():
|
||||
exception_handler_app = Sanic("test_exception_handler")
|
||||
|
||||
@exception_handler_app.route("/1", error_format="html")
|
||||
def handler_1(request):
|
||||
raise InvalidUsage("OK")
|
||||
|
||||
|
||||
@exception_handler_app.route("/2")
|
||||
@exception_handler_app.route("/2", error_format="html")
|
||||
def handler_2(request):
|
||||
raise ServerError("OK")
|
||||
|
||||
|
||||
@exception_handler_app.route("/3")
|
||||
@exception_handler_app.route("/3", error_format="html")
|
||||
def handler_3(request):
|
||||
raise NotFound("OK")
|
||||
|
||||
|
||||
@exception_handler_app.route("/4")
|
||||
@exception_handler_app.route("/4", error_format="html")
|
||||
def handler_4(request):
|
||||
foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception
|
||||
foo = bar # noqa -- F821
|
||||
return text(foo)
|
||||
|
||||
|
||||
@exception_handler_app.route("/5")
|
||||
@exception_handler_app.route("/5", error_format="html")
|
||||
def handler_5(request):
|
||||
class CustomServerError(ServerError):
|
||||
pass
|
||||
|
||||
raise CustomServerError("Custom server error")
|
||||
|
||||
|
||||
@exception_handler_app.route("/6/<arg:int>")
|
||||
@exception_handler_app.route("/6/<arg:int>", error_format="html")
|
||||
def handler_6(request, arg):
|
||||
try:
|
||||
foo = 1 / arg
|
||||
@@ -58,28 +57,23 @@ def handler_6(request, arg):
|
||||
raise e from ValueError(f"{arg}")
|
||||
return text(foo)
|
||||
|
||||
|
||||
@exception_handler_app.route("/7")
|
||||
@exception_handler_app.route("/7", error_format="html")
|
||||
def handler_7(request):
|
||||
raise Forbidden("go away!")
|
||||
|
||||
|
||||
@exception_handler_app.route("/8")
|
||||
@exception_handler_app.route("/8", error_format="html")
|
||||
def handler_8(request):
|
||||
|
||||
raise ErrorWithRequestCtx("OK")
|
||||
|
||||
|
||||
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
|
||||
def handler_exception_with_ctx(request, exception):
|
||||
return text(request.ctx.middleware_ran)
|
||||
|
||||
|
||||
@exception_handler_app.exception(ServerError)
|
||||
def handler_exception(request, exception):
|
||||
return text("OK")
|
||||
|
||||
|
||||
@exception_handler_app.exception(Forbidden)
|
||||
async def async_handler_exception(request, exception):
|
||||
return stream(
|
||||
@@ -87,47 +81,47 @@ async def async_handler_exception(request, exception):
|
||||
content_type="text/csv",
|
||||
)
|
||||
|
||||
|
||||
@exception_handler_app.middleware
|
||||
async def some_request_middleware(request):
|
||||
request.ctx.middleware_ran = "Done."
|
||||
|
||||
return exception_handler_app
|
||||
|
||||
def test_invalid_usage_exception_handler():
|
||||
|
||||
def test_invalid_usage_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/1")
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
def test_server_error_exception_handler():
|
||||
def test_server_error_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/2")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
def test_not_found_exception_handler():
|
||||
def test_not_found_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/3")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_text_exception__handler():
|
||||
def test_text_exception__handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/random")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
||||
|
||||
def test_async_exception_handler():
|
||||
def test_async_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/7")
|
||||
assert response.status == 200
|
||||
assert response.text == "foo,bar"
|
||||
|
||||
|
||||
def test_html_traceback_output_in_debug_mode():
|
||||
def test_html_traceback_output_in_debug_mode(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/4", debug=True)
|
||||
assert response.status == 500
|
||||
soup = BeautifulSoup(response.body, "html.parser")
|
||||
html = str(soup)
|
||||
|
||||
assert "response = handler(request, **kwargs)" in html
|
||||
assert "handler_4" in html
|
||||
assert "foo = bar" in html
|
||||
|
||||
@@ -137,12 +131,12 @@ def test_html_traceback_output_in_debug_mode():
|
||||
) == summary_text
|
||||
|
||||
|
||||
def test_inherited_exception_handler():
|
||||
def test_inherited_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/5")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_chained_exception_handler():
|
||||
def test_chained_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get(
|
||||
"/6/0", debug=True
|
||||
)
|
||||
@@ -151,11 +145,9 @@ def test_chained_exception_handler():
|
||||
soup = BeautifulSoup(response.body, "html.parser")
|
||||
html = str(soup)
|
||||
|
||||
assert "response = handler(request, **kwargs)" in html
|
||||
assert "handler_6" in html
|
||||
assert "foo = 1 / arg" in html
|
||||
assert "ValueError" in html
|
||||
assert "The above exception was the direct cause" in html
|
||||
|
||||
summary_text = " ".join(soup.select(".summary")[0].text.split())
|
||||
assert (
|
||||
@@ -163,7 +155,7 @@ def test_chained_exception_handler():
|
||||
) == summary_text
|
||||
|
||||
|
||||
def test_exception_handler_lookup():
|
||||
def test_exception_handler_lookup(exception_handler_app):
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
@@ -186,26 +178,52 @@ def test_exception_handler_lookup():
|
||||
class ModuleNotFoundError(ImportError):
|
||||
pass
|
||||
|
||||
handler = ErrorHandler()
|
||||
handler = ErrorHandler("auto")
|
||||
handler.add(ImportError, import_error_handler)
|
||||
handler.add(CustomError, custom_error_handler)
|
||||
handler.add(ServerError, server_error_handler)
|
||||
|
||||
assert handler.lookup(ImportError()) == import_error_handler
|
||||
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
|
||||
assert handler.lookup(CustomError()) == custom_error_handler
|
||||
assert handler.lookup(ServerError("Error")) == server_error_handler
|
||||
assert handler.lookup(CustomServerError("Error")) == server_error_handler
|
||||
assert handler.lookup(ImportError(), None) == import_error_handler
|
||||
assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler
|
||||
assert handler.lookup(CustomError(), None) == custom_error_handler
|
||||
assert handler.lookup(ServerError("Error"), None) == server_error_handler
|
||||
assert (
|
||||
handler.lookup(CustomServerError("Error"), None)
|
||||
== server_error_handler
|
||||
)
|
||||
|
||||
# once again to ensure there is no caching bug
|
||||
assert handler.lookup(ImportError()) == import_error_handler
|
||||
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
|
||||
assert handler.lookup(CustomError()) == custom_error_handler
|
||||
assert handler.lookup(ServerError("Error")) == server_error_handler
|
||||
assert handler.lookup(CustomServerError("Error")) == server_error_handler
|
||||
assert handler.lookup(ImportError(), None) == import_error_handler
|
||||
assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler
|
||||
assert handler.lookup(CustomError(), None) == custom_error_handler
|
||||
assert handler.lookup(ServerError("Error"), None) == server_error_handler
|
||||
assert (
|
||||
handler.lookup(CustomServerError("Error"), None)
|
||||
== server_error_handler
|
||||
)
|
||||
|
||||
|
||||
def test_exception_handler_processed_request_middleware():
|
||||
def test_exception_handler_processed_request_middleware(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/8")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
||||
|
||||
def test_single_arg_exception_handler_notice(exception_handler_app, caplog):
|
||||
class CustomErrorHandler(ErrorHandler):
|
||||
def lookup(self, exception):
|
||||
return super().lookup(exception, None)
|
||||
|
||||
exception_handler_app.error_handler = CustomErrorHandler()
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
_, response = exception_handler_app.test_client.get("/1")
|
||||
|
||||
assert caplog.records[0].message == (
|
||||
"You are using a deprecated error handler. The lookup method should "
|
||||
"accept two positional parameters: (exception, route_name: "
|
||||
"Optional[str]). Until you upgrade your ErrorHandler.lookup, "
|
||||
"Blueprint specific exceptions will not work properly. Beginning in "
|
||||
"v22.3, the legacy style lookup method will not work at all."
|
||||
)
|
||||
assert response.status == 400
|
||||
|
||||
46
tests/test_graceful_shutdown.py
Normal file
46
tests/test_graceful_shutdown.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
from collections import Counter
|
||||
from multiprocessing import Process
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
PORT = 42101
|
||||
|
||||
|
||||
def test_no_exceptions_when_cancel_pending_request(app, caplog):
|
||||
app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1
|
||||
|
||||
@app.get("/")
|
||||
async def handler(request):
|
||||
await asyncio.sleep(5)
|
||||
|
||||
@app.after_server_start
|
||||
def shutdown(app, _):
|
||||
time.sleep(0.2)
|
||||
app.stop()
|
||||
|
||||
def ping():
|
||||
time.sleep(0.1)
|
||||
response = httpx.get("http://127.0.0.1:8000")
|
||||
print(response.status_code)
|
||||
|
||||
p = Process(target=ping)
|
||||
p.start()
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run()
|
||||
|
||||
p.kill()
|
||||
|
||||
counter = Counter([r[1] for r in caplog.record_tuples])
|
||||
|
||||
assert counter[logging.INFO] == 5
|
||||
assert logging.ERROR not in counter
|
||||
assert (
|
||||
caplog.record_tuples[3][2]
|
||||
== "Request: GET http://127.0.0.1:8000/ stopped. Transport is closed."
|
||||
)
|
||||
39
tests/test_handler_annotations.py
Normal file
39
tests/test_handler_annotations.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import json
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"idx,path,expectation",
|
||||
(
|
||||
(0, "/abc", "str"),
|
||||
(1, "/123", "int"),
|
||||
(2, "/123.5", "float"),
|
||||
(3, "/8af729fe-2b94-4a95-a168-c07068568429", "UUID"),
|
||||
),
|
||||
)
|
||||
def test_annotated_handlers(app, idx, path, expectation):
|
||||
def build_response(num, foo):
|
||||
return json({"num": num, "type": type(foo).__name__})
|
||||
|
||||
@app.get("/<foo>")
|
||||
def handler0(_, foo: str):
|
||||
return build_response(0, foo)
|
||||
|
||||
@app.get("/<foo>")
|
||||
def handler1(_, foo: int):
|
||||
return build_response(1, foo)
|
||||
|
||||
@app.get("/<foo>")
|
||||
def handler2(_, foo: float):
|
||||
return build_response(2, foo)
|
||||
|
||||
@app.get("/<foo>")
|
||||
def handler3(_, foo: UUID):
|
||||
return build_response(3, foo)
|
||||
|
||||
_, response = app.test_client.get(path)
|
||||
assert response.json["num"] == idx
|
||||
assert response.json["type"] == expectation
|
||||
@@ -3,8 +3,9 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
|
||||
from sanic import headers, text
|
||||
from sanic.exceptions import PayloadTooLarge
|
||||
from sanic.exceptions import InvalidHeader, PayloadTooLarge
|
||||
from sanic.http import Http
|
||||
from sanic.request import Request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -182,3 +183,187 @@ def test_request_line(app):
|
||||
)
|
||||
|
||||
assert request.request_line == b"GET / HTTP/1.1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw",
|
||||
(
|
||||
"show/first, show/second",
|
||||
"show/*, show/first",
|
||||
"*/*, show/first",
|
||||
"*/*, show/*",
|
||||
"other/*; q=0.1, show/*; q=0.2",
|
||||
"show/first; q=0.5, show/second; q=0.5",
|
||||
"show/first; foo=bar, show/second; foo=bar",
|
||||
"show/second, show/first; foo=bar",
|
||||
"show/second; q=0.5, show/first; foo=bar; q=0.5",
|
||||
"show/second; q=0.5, show/first; q=1.0",
|
||||
"show/first, show/second; q=1.0",
|
||||
),
|
||||
)
|
||||
def test_parse_accept_ordered_okay(raw):
|
||||
ordered = headers.parse_accept(raw)
|
||||
expected_subtype = (
|
||||
"*" if all(q.subtype.is_wildcard for q in ordered) else "first"
|
||||
)
|
||||
assert ordered[0].type_ == "show"
|
||||
assert ordered[0].subtype == expected_subtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw",
|
||||
(
|
||||
"missing",
|
||||
"missing/",
|
||||
"/missing",
|
||||
),
|
||||
)
|
||||
def test_bad_accept(raw):
|
||||
with pytest.raises(InvalidHeader):
|
||||
headers.parse_accept(raw)
|
||||
|
||||
|
||||
def test_empty_accept():
|
||||
assert headers.parse_accept("") == []
|
||||
|
||||
|
||||
def test_wildcard_accept_set_ok():
|
||||
accept = headers.parse_accept("*/*")[0]
|
||||
assert accept.type_.is_wildcard
|
||||
assert accept.subtype.is_wildcard
|
||||
|
||||
accept = headers.parse_accept("foo/bar")[0]
|
||||
assert not accept.type_.is_wildcard
|
||||
assert not accept.subtype.is_wildcard
|
||||
|
||||
|
||||
def test_accept_parsed_against_str():
|
||||
accept = headers.Accept.parse("foo/bar")
|
||||
assert accept > "foo/bar; q=0.1"
|
||||
|
||||
|
||||
def test_media_type_equality():
|
||||
assert headers.MediaType("foo") == headers.MediaType("foo") == "foo"
|
||||
assert headers.MediaType("foo") == headers.MediaType("*") == "*"
|
||||
assert headers.MediaType("foo") != headers.MediaType("bar")
|
||||
assert headers.MediaType("foo") != "bar"
|
||||
|
||||
|
||||
def test_media_type_matching():
|
||||
assert headers.MediaType("foo").match(headers.MediaType("foo"))
|
||||
assert headers.MediaType("foo").match("foo")
|
||||
|
||||
assert not headers.MediaType("foo").match(headers.MediaType("*"))
|
||||
assert not headers.MediaType("foo").match("*")
|
||||
|
||||
assert not headers.MediaType("foo").match(headers.MediaType("bar"))
|
||||
assert not headers.MediaType("foo").match("bar")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,other,outcome,allow_type,allow_subtype",
|
||||
(
|
||||
# ALLOW BOTH
|
||||
("foo/bar", "foo/bar", True, True, True),
|
||||
("foo/bar", headers.Accept.parse("foo/bar"), True, True, True),
|
||||
("foo/bar", "foo/*", True, True, True),
|
||||
("foo/bar", headers.Accept.parse("foo/*"), True, True, True),
|
||||
("foo/bar", "*/*", True, True, True),
|
||||
("foo/bar", headers.Accept.parse("*/*"), True, True, True),
|
||||
("foo/*", "foo/bar", True, True, True),
|
||||
("foo/*", headers.Accept.parse("foo/bar"), True, True, True),
|
||||
("foo/*", "foo/*", True, True, True),
|
||||
("foo/*", headers.Accept.parse("foo/*"), True, True, True),
|
||||
("foo/*", "*/*", True, True, True),
|
||||
("foo/*", headers.Accept.parse("*/*"), True, True, True),
|
||||
("*/*", "foo/bar", True, True, True),
|
||||
("*/*", headers.Accept.parse("foo/bar"), True, True, True),
|
||||
("*/*", "foo/*", True, True, True),
|
||||
("*/*", headers.Accept.parse("foo/*"), True, True, True),
|
||||
("*/*", "*/*", True, True, True),
|
||||
("*/*", headers.Accept.parse("*/*"), True, True, True),
|
||||
# ALLOW TYPE
|
||||
("foo/bar", "foo/bar", True, True, False),
|
||||
("foo/bar", headers.Accept.parse("foo/bar"), True, True, False),
|
||||
("foo/bar", "foo/*", False, True, False),
|
||||
("foo/bar", headers.Accept.parse("foo/*"), False, True, False),
|
||||
("foo/bar", "*/*", False, True, False),
|
||||
("foo/bar", headers.Accept.parse("*/*"), False, True, False),
|
||||
("foo/*", "foo/bar", False, True, False),
|
||||
("foo/*", headers.Accept.parse("foo/bar"), False, True, False),
|
||||
("foo/*", "foo/*", False, True, False),
|
||||
("foo/*", headers.Accept.parse("foo/*"), False, True, False),
|
||||
("foo/*", "*/*", False, True, False),
|
||||
("foo/*", headers.Accept.parse("*/*"), False, True, False),
|
||||
("*/*", "foo/bar", False, True, False),
|
||||
("*/*", headers.Accept.parse("foo/bar"), False, True, False),
|
||||
("*/*", "foo/*", False, True, False),
|
||||
("*/*", headers.Accept.parse("foo/*"), False, True, False),
|
||||
("*/*", "*/*", False, True, False),
|
||||
("*/*", headers.Accept.parse("*/*"), False, True, False),
|
||||
# ALLOW SUBTYPE
|
||||
("foo/bar", "foo/bar", True, False, True),
|
||||
("foo/bar", headers.Accept.parse("foo/bar"), True, False, True),
|
||||
("foo/bar", "foo/*", True, False, True),
|
||||
("foo/bar", headers.Accept.parse("foo/*"), True, False, True),
|
||||
("foo/bar", "*/*", False, False, True),
|
||||
("foo/bar", headers.Accept.parse("*/*"), False, False, True),
|
||||
("foo/*", "foo/bar", True, False, True),
|
||||
("foo/*", headers.Accept.parse("foo/bar"), True, False, True),
|
||||
("foo/*", "foo/*", True, False, True),
|
||||
("foo/*", headers.Accept.parse("foo/*"), True, False, True),
|
||||
("foo/*", "*/*", False, False, True),
|
||||
("foo/*", headers.Accept.parse("*/*"), False, False, True),
|
||||
("*/*", "foo/bar", False, False, True),
|
||||
("*/*", headers.Accept.parse("foo/bar"), False, False, True),
|
||||
("*/*", "foo/*", False, False, True),
|
||||
("*/*", headers.Accept.parse("foo/*"), False, False, True),
|
||||
("*/*", "*/*", False, False, True),
|
||||
("*/*", headers.Accept.parse("*/*"), False, False, True),
|
||||
),
|
||||
)
|
||||
def test_accept_matching(value, other, outcome, allow_type, allow_subtype):
|
||||
assert (
|
||||
headers.Accept.parse(value).match(
|
||||
other,
|
||||
allow_type_wildcard=allow_type,
|
||||
allow_subtype_wildcard=allow_subtype,
|
||||
)
|
||||
is outcome
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*"))
|
||||
def test_value_in_accept(value):
|
||||
acceptable = headers.parse_accept(value)
|
||||
assert "foo/bar" in acceptable
|
||||
assert "foo/*" in acceptable
|
||||
assert "*/*" in acceptable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ("foo/bar", "foo/*"))
|
||||
def test_value_not_in_accept(value):
|
||||
acceptable = headers.parse_accept(value)
|
||||
assert "no/match" not in acceptable
|
||||
assert "no/*" not in acceptable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"header,expected",
|
||||
(
|
||||
(
|
||||
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501
|
||||
[
|
||||
"text/html",
|
||||
"application/xhtml+xml",
|
||||
"image/avif",
|
||||
"image/webp",
|
||||
"application/xml;q=0.9",
|
||||
"*/*;q=0.8",
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_browser_headers(header, expected):
|
||||
request = Request(b"/", {"accept": header}, "1.1", "GET", None, None)
|
||||
assert request.accept == expected
|
||||
|
||||
137
tests/test_http.py
Normal file
137
tests/test_http.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import asyncio
|
||||
import json as stdjson
|
||||
|
||||
from collections import namedtuple
|
||||
from textwrap import dedent
|
||||
from typing import AnyStr
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic_testing.reusable import ReusableClient
|
||||
|
||||
from sanic import json, text
|
||||
from sanic.app import Sanic
|
||||
|
||||
|
||||
PORT = 1234
|
||||
|
||||
|
||||
class RawClient:
|
||||
CRLF = b"\r\n"
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
async def connect(self):
|
||||
self.reader, self.writer = await asyncio.open_connection(
|
||||
self.host, self.port
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
|
||||
async def send(self, message: AnyStr):
|
||||
if isinstance(message, str):
|
||||
msg = self._clean(message).encode("utf-8")
|
||||
else:
|
||||
msg = message
|
||||
await self._send(msg)
|
||||
|
||||
async def _send(self, message: bytes):
|
||||
if not self.writer:
|
||||
raise Exception("No open write stream")
|
||||
self.writer.write(message)
|
||||
|
||||
async def recv(self, nbytes: int = -1) -> bytes:
|
||||
if not self.reader:
|
||||
raise Exception("No open read stream")
|
||||
return await self.reader.read(nbytes)
|
||||
|
||||
def _clean(self, message: str) -> str:
|
||||
return (
|
||||
dedent(message)
|
||||
.lstrip("\n")
|
||||
.replace("\n", self.CRLF.decode("utf-8"))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app(app: Sanic):
|
||||
app.config.KEEP_ALIVE_TIMEOUT = 1
|
||||
|
||||
@app.get("/")
|
||||
async def base_handler(request):
|
||||
return text("111122223333444455556666777788889999")
|
||||
|
||||
@app.post("/upload", stream=True)
|
||||
async def upload_handler(request):
|
||||
data = [part.decode("utf-8") async for part in request.stream]
|
||||
return json(data)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(test_app):
|
||||
client = ReusableClient(test_app, port=PORT)
|
||||
client.run()
|
||||
yield client
|
||||
client.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(runner):
|
||||
client = namedtuple("Client", ("raw", "send", "recv"))
|
||||
|
||||
raw = RawClient(runner.host, runner.port)
|
||||
runner._run(raw.connect())
|
||||
|
||||
def send(msg):
|
||||
nonlocal runner
|
||||
nonlocal raw
|
||||
runner._run(raw.send(msg))
|
||||
|
||||
def recv(**kwargs):
|
||||
nonlocal runner
|
||||
nonlocal raw
|
||||
method = raw.recv_until if "until" in kwargs else raw.recv
|
||||
return runner._run(method(**kwargs))
|
||||
|
||||
yield client(raw, send, recv)
|
||||
|
||||
runner._run(raw.close())
|
||||
|
||||
|
||||
def test_full_message(client):
|
||||
client.send(
|
||||
"""
|
||||
GET / HTTP/1.1
|
||||
host: localhost:7777
|
||||
|
||||
"""
|
||||
)
|
||||
response = client.recv()
|
||||
assert len(response) == 140
|
||||
assert b"200 OK" in response
|
||||
|
||||
|
||||
def test_transfer_chunked(client):
|
||||
client.send(
|
||||
"""
|
||||
POST /upload HTTP/1.1
|
||||
transfer-encoding: chunked
|
||||
|
||||
"""
|
||||
)
|
||||
client.send(b"3\r\nfoo\r\n")
|
||||
client.send(b"3\r\nbar\r\n")
|
||||
client.send(b"0\r\n\r\n")
|
||||
response = client.recv()
|
||||
_, body = response.rsplit(b"\r\n\r\n", 1)
|
||||
data = stdjson.loads(body)
|
||||
|
||||
assert data == ["foo", "bar"]
|
||||
@@ -2,16 +2,13 @@ import asyncio
|
||||
import platform
|
||||
|
||||
from asyncio import sleep as aio_sleep
|
||||
from json import JSONDecodeError
|
||||
from os import environ
|
||||
|
||||
import httpcore
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from sanic_testing.testing import HOST, SanicTestClient
|
||||
from sanic_testing.reusable import ReusableClient
|
||||
|
||||
from sanic import Sanic, server
|
||||
from sanic import Sanic
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.response import text
|
||||
|
||||
@@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
|
||||
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
|
||||
|
||||
|
||||
class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool):
|
||||
last_reused_connection = None
|
||||
|
||||
async def _get_connection_from_pool(self, *args, **kwargs):
|
||||
conn = await super()._get_connection_from_pool(*args, **kwargs)
|
||||
self.__class__.last_reused_connection = conn
|
||||
return conn
|
||||
|
||||
|
||||
class ResusableSanicSession(httpx.AsyncClient):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
transport = ReusableSanicConnectionPool()
|
||||
super().__init__(transport=transport, *args, **kwargs)
|
||||
|
||||
|
||||
class ReuseableSanicTestClient(SanicTestClient):
|
||||
def __init__(self, app, loop=None):
|
||||
super().__init__(app)
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._server = None
|
||||
self._tcp_connector = None
|
||||
self._session = None
|
||||
|
||||
def get_new_session(self):
|
||||
return ResusableSanicSession()
|
||||
|
||||
# Copied from SanicTestClient, but with some changes to reuse the
|
||||
# same loop for the same app.
|
||||
def _sanic_endpoint_test(
|
||||
self,
|
||||
method="get",
|
||||
uri="/",
|
||||
gather_request=True,
|
||||
debug=False,
|
||||
server_kwargs=None,
|
||||
*request_args,
|
||||
**request_kwargs,
|
||||
):
|
||||
loop = self._loop
|
||||
results = [None, None]
|
||||
exceptions = []
|
||||
server_kwargs = server_kwargs or {"return_asyncio_server": True}
|
||||
if gather_request:
|
||||
|
||||
def _collect_request(request):
|
||||
if results[0] is None:
|
||||
results[0] = request
|
||||
|
||||
self.app.request_middleware.appendleft(_collect_request)
|
||||
|
||||
if uri.startswith(
|
||||
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
|
||||
):
|
||||
url = uri
|
||||
else:
|
||||
uri = uri if uri.startswith("/") else f"/{uri}"
|
||||
scheme = "http"
|
||||
url = f"{scheme}://{HOST}:{PORT}{uri}"
|
||||
|
||||
@self.app.listener("after_server_start")
|
||||
async def _collect_response(loop):
|
||||
try:
|
||||
response = await self._local_request(
|
||||
method, url, *request_args, **request_kwargs
|
||||
)
|
||||
results[-1] = response
|
||||
except Exception as e2:
|
||||
exceptions.append(e2)
|
||||
|
||||
if self._server is not None:
|
||||
_server = self._server
|
||||
else:
|
||||
_server_co = self.app.create_server(
|
||||
host=HOST, debug=debug, port=PORT, **server_kwargs
|
||||
)
|
||||
|
||||
server.trigger_events(
|
||||
self.app.listeners["before_server_start"], loop
|
||||
)
|
||||
|
||||
try:
|
||||
loop._stopping = False
|
||||
_server = loop.run_until_complete(_server_co)
|
||||
except Exception as e1:
|
||||
raise e1
|
||||
self._server = _server
|
||||
server.trigger_events(self.app.listeners["after_server_start"], loop)
|
||||
self.app.listeners["after_server_start"].pop()
|
||||
|
||||
if exceptions:
|
||||
raise ValueError(f"Exception during request: {exceptions}")
|
||||
|
||||
if gather_request:
|
||||
self.app.request_middleware.pop()
|
||||
try:
|
||||
request, response = results
|
||||
return request, response
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Request and response object expected, got ({results})"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return results[-1]
|
||||
except Exception:
|
||||
raise ValueError(f"Request object expected, got ({results})")
|
||||
|
||||
def kill_server(self):
|
||||
try:
|
||||
if self._server:
|
||||
self._server.close()
|
||||
self._loop.run_until_complete(self._server.wait_closed())
|
||||
self._server = None
|
||||
|
||||
if self._session:
|
||||
self._loop.run_until_complete(self._session.aclose())
|
||||
self._session = None
|
||||
|
||||
except Exception as e3:
|
||||
raise e3
|
||||
|
||||
# Copied from SanicTestClient, but with some changes to reuse the
|
||||
# same TCPConnection and the sane ClientSession more than once.
|
||||
# Note, you cannot use the same session if you are in a _different_
|
||||
# loop, so the changes above are required too.
|
||||
async def _local_request(self, method, url, *args, **kwargs):
|
||||
raw_cookies = kwargs.pop("raw_cookies", None)
|
||||
request_keepalive = kwargs.pop(
|
||||
"request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"]
|
||||
)
|
||||
if not self._session:
|
||||
self._session = self.get_new_session()
|
||||
try:
|
||||
response = await getattr(self._session, method.lower())(
|
||||
url, timeout=request_keepalive, *args, **kwargs
|
||||
)
|
||||
except NameError:
|
||||
raise Exception(response.status_code)
|
||||
|
||||
try:
|
||||
response.json = response.json()
|
||||
except (JSONDecodeError, UnicodeDecodeError):
|
||||
response.json = None
|
||||
|
||||
response.body = await response.aread()
|
||||
response.status = response.status_code
|
||||
response.content_type = response.headers.get("content-type")
|
||||
|
||||
if raw_cookies:
|
||||
response.raw_cookies = {}
|
||||
for cookie in response.cookies:
|
||||
response.raw_cookies[cookie.name] = cookie
|
||||
|
||||
return response
|
||||
|
||||
|
||||
keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse")
|
||||
keep_alive_app_client_timeout = Sanic("test_ka_client_timeout")
|
||||
keep_alive_app_server_timeout = Sanic("test_ka_server_timeout")
|
||||
@@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse():
|
||||
"""If the server keep-alive timeout and client keep-alive timeout are
|
||||
both longer than the delay, the client _and_ server will successfully
|
||||
reuse the existing connection."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop)
|
||||
client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT)
|
||||
with client:
|
||||
headers = {"Connection": "keep-alive"}
|
||||
request, response = client.get("/1", headers=headers)
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
assert request.protocol.state["requests_count"] == 1
|
||||
|
||||
loop.run_until_complete(aio_sleep(1))
|
||||
|
||||
request, response = client.get("/1")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
assert ReusableSanicConnectionPool.last_reused_connection
|
||||
finally:
|
||||
client.kill_server()
|
||||
assert request.protocol.state["requests_count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse():
|
||||
def test_keep_alive_client_timeout():
|
||||
"""If the server keep-alive timeout is longer than the client
|
||||
keep-alive timeout, client will try to create a new connection here."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop)
|
||||
client = ReusableClient(
|
||||
keep_alive_app_client_timeout, loop=loop, port=PORT
|
||||
)
|
||||
with client:
|
||||
headers = {"Connection": "keep-alive"}
|
||||
_, response = client.get("/1", headers=headers, request_keepalive=1)
|
||||
request, response = client.get("/1", headers=headers, timeout=1)
|
||||
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
assert request.protocol.state["requests_count"] == 1
|
||||
|
||||
loop.run_until_complete(aio_sleep(2))
|
||||
_, response = client.get("/1", request_keepalive=1)
|
||||
|
||||
assert ReusableSanicConnectionPool.last_reused_connection is None
|
||||
finally:
|
||||
client.kill_server()
|
||||
request, response = client.get("/1", timeout=1)
|
||||
assert request.protocol.state["requests_count"] == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -277,22 +117,23 @@ def test_keep_alive_server_timeout():
|
||||
keep-alive timeout, the client will either a 'Connection reset' error
|
||||
_or_ a new connection. Depending on how the event-loop handles the
|
||||
broken server connection."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop)
|
||||
client = ReusableClient(
|
||||
keep_alive_app_server_timeout, loop=loop, port=PORT
|
||||
)
|
||||
with client:
|
||||
headers = {"Connection": "keep-alive"}
|
||||
_, response = client.get("/1", headers=headers, request_keepalive=60)
|
||||
request, response = client.get("/1", headers=headers, timeout=60)
|
||||
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
assert request.protocol.state["requests_count"] == 1
|
||||
|
||||
loop.run_until_complete(aio_sleep(3))
|
||||
_, response = client.get("/1", request_keepalive=60)
|
||||
request, response = client.get("/1", timeout=60)
|
||||
|
||||
assert ReusableSanicConnectionPool.last_reused_connection is None
|
||||
finally:
|
||||
client.kill_server()
|
||||
assert request.protocol.state["requests_count"] == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -300,10 +141,10 @@ def test_keep_alive_server_timeout():
|
||||
reason="Not testable with current client",
|
||||
)
|
||||
def test_keep_alive_connection_context():
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
client = ReuseableSanicTestClient(keep_alive_app_context, loop)
|
||||
client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT)
|
||||
with client:
|
||||
headers = {"Connection": "keep-alive"}
|
||||
request1, _ = client.post("/ctx", headers=headers)
|
||||
|
||||
@@ -315,5 +156,4 @@ def test_keep_alive_connection_context():
|
||||
assert (
|
||||
request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello"
|
||||
)
|
||||
finally:
|
||||
client.kill_server()
|
||||
assert request2.protocol.state["requests_count"] == 2
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
|
||||
from importlib import reload
|
||||
from io import StringIO
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -51,7 +52,7 @@ def test_log(app):
|
||||
|
||||
def test_logging_defaults():
|
||||
# reset_logging()
|
||||
app = Sanic("test_logging")
|
||||
Sanic("test_logging")
|
||||
|
||||
for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]:
|
||||
assert (
|
||||
@@ -87,7 +88,7 @@ def test_logging_pass_customer_logconfig():
|
||||
"format"
|
||||
] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s"
|
||||
|
||||
app = Sanic("test_logging", log_config=modified_config)
|
||||
Sanic("test_logging", log_config=modified_config)
|
||||
|
||||
for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]:
|
||||
assert fmt._fmt == modified_config["formatters"]["generic"]["format"]
|
||||
@@ -115,7 +116,9 @@ def test_log_connection_lost(app, debug, monkeypatch):
|
||||
stream = StringIO()
|
||||
error = logging.getLogger("sanic.error")
|
||||
error.addHandler(logging.StreamHandler(stream))
|
||||
monkeypatch.setattr(sanic.server, "error_logger", error)
|
||||
monkeypatch.setattr(
|
||||
sanic.server.protocols.http_protocol, "error_logger", error
|
||||
)
|
||||
|
||||
@app.route("/conn_lost")
|
||||
async def conn_lost(request):
|
||||
@@ -208,6 +211,56 @@ def test_logging_modified_root_logger_config():
|
||||
modified_config = LOGGING_CONFIG_DEFAULTS
|
||||
modified_config["loggers"]["sanic.root"]["level"] = "DEBUG"
|
||||
|
||||
app = Sanic("test_logging", log_config=modified_config)
|
||||
Sanic("test_logging", log_config=modified_config)
|
||||
|
||||
assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG
|
||||
|
||||
|
||||
def test_access_log_client_ip_remote_addr(monkeypatch):
|
||||
access = Mock()
|
||||
monkeypatch.setattr(sanic.http, "access_logger", access)
|
||||
|
||||
app = Sanic("test_logging")
|
||||
app.config.PROXIES_COUNT = 2
|
||||
|
||||
@app.route("/")
|
||||
async def handler(request):
|
||||
return text(request.remote_addr)
|
||||
|
||||
headers = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"}
|
||||
|
||||
request, response = app.test_client.get("/", headers=headers)
|
||||
|
||||
assert request.remote_addr == "1.1.1.1"
|
||||
access.info.assert_called_with(
|
||||
"",
|
||||
extra={
|
||||
"status": 200,
|
||||
"byte": len(response.content),
|
||||
"host": f"{request.remote_addr}:{request.port}",
|
||||
"request": f"GET {request.scheme}://{request.host}/",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_access_log_client_ip_reqip(monkeypatch):
|
||||
access = Mock()
|
||||
monkeypatch.setattr(sanic.http, "access_logger", access)
|
||||
|
||||
app = Sanic("test_logging")
|
||||
|
||||
@app.route("/")
|
||||
async def handler(request):
|
||||
return text(request.ip)
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
|
||||
access.info.assert_called_with(
|
||||
"",
|
||||
extra={
|
||||
"status": 200,
|
||||
"byte": len(response.content),
|
||||
"host": f"{request.ip}:{request.port}",
|
||||
"request": f"GET {request.scheme}://{request.host}/",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -6,85 +6,37 @@ from sanic_testing.testing import PORT
|
||||
from sanic.config import BASE_LOGO
|
||||
|
||||
|
||||
def test_logo_base(app, caplog):
|
||||
server = app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop._stopping = False
|
||||
def test_logo_base(app, run_startup):
|
||||
logs = run_startup(app)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
_server = loop.run_until_complete(server)
|
||||
|
||||
_server.close()
|
||||
loop.run_until_complete(_server.wait_closed())
|
||||
app.stop()
|
||||
|
||||
assert caplog.record_tuples[0][1] == logging.DEBUG
|
||||
assert caplog.record_tuples[0][2] == BASE_LOGO
|
||||
assert logs[0][1] == logging.DEBUG
|
||||
assert logs[0][2] == BASE_LOGO
|
||||
|
||||
|
||||
def test_logo_false(app, caplog):
|
||||
def test_logo_false(app, caplog, run_startup):
|
||||
app.config.LOGO = False
|
||||
|
||||
server = app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop._stopping = False
|
||||
logs = run_startup(app)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
_server = loop.run_until_complete(server)
|
||||
|
||||
_server.close()
|
||||
loop.run_until_complete(_server.wait_closed())
|
||||
app.stop()
|
||||
|
||||
banner, port = caplog.record_tuples[0][2].rsplit(":", 1)
|
||||
assert caplog.record_tuples[0][1] == logging.INFO
|
||||
banner, port = logs[0][2].rsplit(":", 1)
|
||||
assert logs[0][1] == logging.INFO
|
||||
assert banner == "Goin' Fast @ http://127.0.0.1"
|
||||
assert int(port) > 0
|
||||
|
||||
|
||||
def test_logo_true(app, caplog):
|
||||
def test_logo_true(app, run_startup):
|
||||
app.config.LOGO = True
|
||||
|
||||
server = app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop._stopping = False
|
||||
logs = run_startup(app)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
_server = loop.run_until_complete(server)
|
||||
|
||||
_server.close()
|
||||
loop.run_until_complete(_server.wait_closed())
|
||||
app.stop()
|
||||
|
||||
assert caplog.record_tuples[0][1] == logging.DEBUG
|
||||
assert caplog.record_tuples[0][2] == BASE_LOGO
|
||||
assert logs[0][1] == logging.DEBUG
|
||||
assert logs[0][2] == BASE_LOGO
|
||||
|
||||
|
||||
def test_logo_custom(app, caplog):
|
||||
def test_logo_custom(app, run_startup):
|
||||
app.config.LOGO = "My Custom Logo"
|
||||
|
||||
server = app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop._stopping = False
|
||||
logs = run_startup(app)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
_server = loop.run_until_complete(server)
|
||||
|
||||
_server.close()
|
||||
loop.run_until_complete(_server.wait_closed())
|
||||
app.stop()
|
||||
|
||||
assert caplog.record_tuples[0][1] == logging.DEBUG
|
||||
assert caplog.record_tuples[0][2] == "My Custom Logo"
|
||||
assert logs[0][1] == logging.DEBUG
|
||||
assert logs[0][2] == "My Custom Logo"
|
||||
|
||||
@@ -5,7 +5,7 @@ from itertools import count
|
||||
|
||||
from sanic.exceptions import NotFound
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, text
|
||||
from sanic.response import HTTPResponse, json, text
|
||||
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
@@ -37,14 +37,19 @@ def test_middleware_request_as_convenience(app):
|
||||
async def handler1(request):
|
||||
results.append(request)
|
||||
|
||||
@app.route("/")
|
||||
@app.on_request()
|
||||
async def handler2(request):
|
||||
results.append(request)
|
||||
|
||||
@app.route("/")
|
||||
async def handler3(request):
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
|
||||
assert response.text == "OK"
|
||||
assert type(results[0]) is Request
|
||||
assert type(results[1]) is Request
|
||||
|
||||
|
||||
def test_middleware_response(app):
|
||||
@@ -79,7 +84,12 @@ def test_middleware_response_as_convenience(app):
|
||||
results.append(request)
|
||||
|
||||
@app.on_response
|
||||
async def process_response(request, response):
|
||||
async def process_response_1(request, response):
|
||||
results.append(request)
|
||||
results.append(response)
|
||||
|
||||
@app.on_response()
|
||||
async def process_response_2(request, response):
|
||||
results.append(request)
|
||||
results.append(response)
|
||||
|
||||
@@ -93,6 +103,8 @@ def test_middleware_response_as_convenience(app):
|
||||
assert type(results[0]) is Request
|
||||
assert type(results[1]) is Request
|
||||
assert isinstance(results[2], HTTPResponse)
|
||||
assert type(results[3]) is Request
|
||||
assert isinstance(results[4], HTTPResponse)
|
||||
|
||||
|
||||
def test_middleware_response_as_convenience_called(app):
|
||||
@@ -271,3 +283,17 @@ def test_request_middleware_executes_once(app):
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
assert next(i) == 3
|
||||
|
||||
|
||||
def test_middleware_added_response(app):
|
||||
@app.on_response
|
||||
def display(_, response):
|
||||
response["foo"] = "bar"
|
||||
return json(response)
|
||||
|
||||
@app.get("/")
|
||||
async def handler(request):
|
||||
return {}
|
||||
|
||||
_, response = app.test_client.get("/")
|
||||
assert response.json["foo"] == "bar"
|
||||
|
||||
105
tests/test_pipelining.py
Normal file
105
tests/test_pipelining.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from httpx import AsyncByteStream
|
||||
from sanic_testing.reusable import ReusableClient
|
||||
|
||||
from sanic.response import json, text
|
||||
|
||||
|
||||
def test_no_body_requests(app):
|
||||
@app.get("/")
|
||||
async def handler(request):
|
||||
return json(
|
||||
{
|
||||
"request_id": str(request.id),
|
||||
"connection_id": id(request.conn_info),
|
||||
}
|
||||
)
|
||||
|
||||
client = ReusableClient(app, port=1234)
|
||||
|
||||
with client:
|
||||
_, response1 = client.get("/")
|
||||
_, response2 = client.get("/")
|
||||
|
||||
assert response1.status == response2.status == 200
|
||||
assert response1.json["request_id"] != response2.json["request_id"]
|
||||
assert response1.json["connection_id"] == response2.json["connection_id"]
|
||||
|
||||
|
||||
def test_json_body_requests(app):
|
||||
@app.post("/")
|
||||
async def handler(request):
|
||||
return json(
|
||||
{
|
||||
"request_id": str(request.id),
|
||||
"connection_id": id(request.conn_info),
|
||||
"foo": request.json.get("foo"),
|
||||
}
|
||||
)
|
||||
|
||||
client = ReusableClient(app, port=1234)
|
||||
|
||||
with client:
|
||||
_, response1 = client.post("/", json={"foo": True})
|
||||
_, response2 = client.post("/", json={"foo": True})
|
||||
|
||||
assert response1.status == response2.status == 200
|
||||
assert response1.json["foo"] is response2.json["foo"] is True
|
||||
assert response1.json["request_id"] != response2.json["request_id"]
|
||||
assert response1.json["connection_id"] == response2.json["connection_id"]
|
||||
|
||||
|
||||
def test_streaming_body_requests(app):
|
||||
@app.post("/", stream=True)
|
||||
async def handler(request):
|
||||
data = [part.decode("utf-8") async for part in request.stream]
|
||||
return json(
|
||||
{
|
||||
"request_id": str(request.id),
|
||||
"connection_id": id(request.conn_info),
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
data = ["hello", "world"]
|
||||
|
||||
class Data(AsyncByteStream):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
async def __aiter__(self):
|
||||
for value in self.data:
|
||||
yield value.encode("utf-8")
|
||||
|
||||
client = ReusableClient(app, port=1234)
|
||||
|
||||
with client:
|
||||
_, response1 = client.post("/", data=Data(data))
|
||||
_, response2 = client.post("/", data=Data(data))
|
||||
|
||||
assert response1.status == response2.status == 200
|
||||
assert response1.json["data"] == response2.json["data"] == data
|
||||
assert response1.json["request_id"] != response2.json["request_id"]
|
||||
assert response1.json["connection_id"] == response2.json["connection_id"]
|
||||
|
||||
|
||||
def test_bad_headers(app):
|
||||
@app.get("/")
|
||||
async def handler(request):
|
||||
return text("")
|
||||
|
||||
@app.on_response
|
||||
async def reqid(request, response):
|
||||
response.headers["x-request-id"] = request.id
|
||||
|
||||
client = ReusableClient(app, port=1234)
|
||||
bad_headers = {"bad": "bad" * 5_000}
|
||||
|
||||
with client:
|
||||
_, response1 = client.get("/")
|
||||
_, response2 = client.get("/", headers=bad_headers)
|
||||
|
||||
assert response1.status == 200
|
||||
assert response2.status == 413
|
||||
assert (
|
||||
response1.headers["x-request-id"] != response2.headers["x-request-id"]
|
||||
)
|
||||
@@ -140,3 +140,39 @@ def test_ipv6_address_is_not_wrapped(app):
|
||||
assert resp.json["client"] == "[::1]"
|
||||
assert resp.json["client_ip"] == "::1"
|
||||
assert request.ip == "::1"
|
||||
|
||||
|
||||
def test_request_accept():
|
||||
app = Sanic("req-generator")
|
||||
|
||||
@app.get("/")
|
||||
async def get(request):
|
||||
return response.empty()
|
||||
|
||||
request, _ = app.test_client.get(
|
||||
"/",
|
||||
headers={
|
||||
"Accept": "text/*, text/plain, text/plain;format=flowed, */*"
|
||||
},
|
||||
)
|
||||
assert request.accept == [
|
||||
"text/plain;format=flowed",
|
||||
"text/plain",
|
||||
"text/*",
|
||||
"*/*",
|
||||
]
|
||||
|
||||
request, _ = app.test_client.get(
|
||||
"/",
|
||||
headers={
|
||||
"Accept": (
|
||||
"text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"
|
||||
)
|
||||
},
|
||||
)
|
||||
assert request.accept == [
|
||||
"text/html",
|
||||
"text/x-c",
|
||||
"text/x-dvi; q=0.8",
|
||||
"text/plain; q=0.5",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
|
||||
import httpcore
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from sanic_testing.testing import SanicTestClient
|
||||
|
||||
@@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient):
|
||||
return DelayableSanicSession(request_delay=self._request_delay)
|
||||
|
||||
|
||||
request_timeout_default_app = Sanic("test_request_timeout_default")
|
||||
request_no_timeout_app = Sanic("test_request_no_timeout")
|
||||
request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6
|
||||
request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6
|
||||
@pytest.fixture
|
||||
def request_no_timeout_app():
|
||||
app = Sanic("test_request_no_timeout")
|
||||
app.config.REQUEST_TIMEOUT = 0.6
|
||||
|
||||
|
||||
@request_timeout_default_app.route("/1")
|
||||
async def handler1(request):
|
||||
return text("OK")
|
||||
|
||||
|
||||
@request_no_timeout_app.route("/1")
|
||||
@app.route("/1")
|
||||
async def handler2(request):
|
||||
return text("OK")
|
||||
|
||||
return app
|
||||
|
||||
@request_timeout_default_app.websocket("/ws1")
|
||||
|
||||
@pytest.fixture
|
||||
def request_timeout_default_app():
|
||||
app = Sanic("test_request_timeout_default")
|
||||
app.config.REQUEST_TIMEOUT = 0.6
|
||||
|
||||
@app.route("/1")
|
||||
async def handler1(request):
|
||||
return text("OK")
|
||||
|
||||
@app.websocket("/ws1")
|
||||
async def ws_handler1(request, ws):
|
||||
await ws.send("OK")
|
||||
|
||||
return app
|
||||
|
||||
def test_default_server_error_request_timeout():
|
||||
|
||||
def test_default_server_error_request_timeout(request_timeout_default_app):
|
||||
client = DelayableSanicTestClient(request_timeout_default_app, 2)
|
||||
request, response = client.get("/1")
|
||||
_, response = client.get("/1")
|
||||
assert response.status == 408
|
||||
assert "Request Timeout" in response.text
|
||||
|
||||
|
||||
def test_default_server_error_request_dont_timeout():
|
||||
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
|
||||
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
|
||||
request, response = client.get("/1")
|
||||
_, response = client.get("/1")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
def test_default_server_error_websocket_request_timeout():
|
||||
def test_default_server_error_websocket_request_timeout(
|
||||
request_timeout_default_app,
|
||||
):
|
||||
|
||||
headers = {
|
||||
"Upgrade": "websocket",
|
||||
@@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout():
|
||||
}
|
||||
|
||||
client = DelayableSanicTestClient(request_timeout_default_app, 2)
|
||||
request, response = client.get("/ws1", headers=headers)
|
||||
_, response = client.get("/ws1", headers=headers)
|
||||
|
||||
assert response.status == 408
|
||||
assert "Request Timeout" in response.text
|
||||
|
||||
@@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("url", ["/ws", "ws"])
|
||||
async def test_websocket_route_asgi(app, url):
|
||||
ev = asyncio.Event()
|
||||
@app.after_server_start
|
||||
async def setup_ev(app, _):
|
||||
app.ctx.ev = asyncio.Event()
|
||||
|
||||
@app.websocket(url)
|
||||
async def handler(request, ws):
|
||||
ev.set()
|
||||
request.app.ctx.ev.set()
|
||||
|
||||
request, response = await app.asgi_client.websocket(url)
|
||||
assert ev.is_set()
|
||||
@app.get("/ev")
|
||||
async def check(request):
|
||||
return json({"set": request.app.ctx.ev.is_set()})
|
||||
|
||||
_, response = await app.asgi_client.websocket(url)
|
||||
_, response = await app.asgi_client.get("/")
|
||||
assert response.json["set"]
|
||||
|
||||
|
||||
def test_websocket_route_with_subprotocols(app):
|
||||
@pytest.mark.parametrize(
|
||||
"subprotocols,expected",
|
||||
(
|
||||
(["one"], "one"),
|
||||
(["three", "one"], "one"),
|
||||
(["tree"], None),
|
||||
(None, None),
|
||||
),
|
||||
)
|
||||
def test_websocket_route_with_subprotocols(app, subprotocols, expected):
|
||||
results = []
|
||||
|
||||
@app.websocket("/ws", subprotocols=["foo", "bar"])
|
||||
@app.websocket("/ws", subprotocols=["zero", "one", "two", "three"])
|
||||
async def handler(request, ws):
|
||||
results.append(ws.subprotocol)
|
||||
nonlocal results
|
||||
results = ws.subprotocol
|
||||
assert ws.subprotocol is not None
|
||||
|
||||
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"])
|
||||
assert response.opened is True
|
||||
assert results == ["bar"]
|
||||
|
||||
_, response = SanicTestClient(app).websocket(
|
||||
"/ws", subprotocols=["bar", "foo"]
|
||||
"/ws", subprotocols=subprotocols
|
||||
)
|
||||
assert response.opened is True
|
||||
assert results == ["bar", "bar"]
|
||||
|
||||
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"])
|
||||
assert response.opened is True
|
||||
assert results == ["bar", "bar", None]
|
||||
|
||||
_, response = SanicTestClient(app).websocket("/ws")
|
||||
assert response.opened is True
|
||||
assert results == ["bar", "bar", None, None]
|
||||
assert results == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strict_slashes", [True, False, None])
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from sanic_testing.testing import HOST, PORT
|
||||
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.exceptions import InvalidUsage, SanicException
|
||||
|
||||
|
||||
AVAILABLE_LISTENERS = [
|
||||
@@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app):
|
||||
async def init_db(app, loop):
|
||||
app.db = MySanicDb()
|
||||
|
||||
await app.create_server(debug=True, return_asyncio_server=True, port=PORT)
|
||||
srv = await app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
await srv.startup()
|
||||
await srv.before_start()
|
||||
|
||||
assert hasattr(app, "db")
|
||||
assert isinstance(app.db, MySanicDb)
|
||||
@@ -157,14 +161,15 @@ def test_create_server_trigger_events(app):
|
||||
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
|
||||
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
|
||||
server = loop.run_until_complete(serv_task)
|
||||
server.after_start()
|
||||
loop.run_until_complete(server.startup())
|
||||
loop.run_until_complete(server.after_start())
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt as e:
|
||||
except KeyboardInterrupt:
|
||||
loop.stop()
|
||||
finally:
|
||||
# Run the on_stop function if provided
|
||||
server.before_stop()
|
||||
loop.run_until_complete(server.before_stop())
|
||||
|
||||
# Wait for server to close
|
||||
close_task = server.close()
|
||||
@@ -174,5 +179,19 @@ def test_create_server_trigger_events(app):
|
||||
signal.stopped = True
|
||||
for connection in server.connections:
|
||||
connection.close_if_idle()
|
||||
server.after_stop()
|
||||
loop.run_until_complete(server.after_stop())
|
||||
assert flag1 and flag2 and flag3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_startup_raises_exception(app):
|
||||
@app.listener("before_server_start")
|
||||
async def init_db(app, loop):
|
||||
...
|
||||
|
||||
srv = await app.create_server(
|
||||
debug=True, return_asyncio_server=True, port=PORT
|
||||
)
|
||||
|
||||
with pytest.raises(SanicException):
|
||||
await srv.before_start()
|
||||
|
||||
@@ -95,7 +95,7 @@ def test_windows_workaround():
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
await asyncio.sleep(0.2)
|
||||
assert app.is_stopping
|
||||
assert app.stay_active_task.result() == None
|
||||
assert app.stay_active_task.result() is None
|
||||
# Second Ctrl+C should raise
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
@@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app):
|
||||
|
||||
app.signal_router.finalize()
|
||||
|
||||
assert len(app.signal_router.routes) == 3
|
||||
await app.dispatch("foo.bar.baz")
|
||||
assert counter == 2
|
||||
|
||||
@@ -331,7 +332,8 @@ def test_event_on_bp_not_registered():
|
||||
"event,expected",
|
||||
(
|
||||
("foo.bar.baz", True),
|
||||
("server.init.before", False),
|
||||
("server.init.before", True),
|
||||
("server.init.somethingelse", False),
|
||||
("http.request.start", False),
|
||||
("sanic.notice.anything", True),
|
||||
),
|
||||
|
||||
@@ -461,6 +461,22 @@ def test_nested_dir(app, static_file_directory):
|
||||
assert response.text == "foo\n"
|
||||
|
||||
|
||||
def test_handle_is_a_directory_error(app, static_file_directory):
|
||||
error_text = "Is a directory. Access denied"
|
||||
app.static("/static", static_file_directory)
|
||||
|
||||
@app.exception(Exception)
|
||||
async def handleStaticDirError(request, exception):
|
||||
if isinstance(exception, IsADirectoryError):
|
||||
return text(error_text, status=403)
|
||||
raise exception
|
||||
|
||||
request, response = app.test_client.get("/static/")
|
||||
|
||||
assert response.status == 403
|
||||
assert response.text == error_text
|
||||
|
||||
|
||||
def test_stack_trace_on_not_found(app, static_file_directory, caplog):
|
||||
app.static("/static", static_file_directory)
|
||||
|
||||
@@ -471,7 +487,7 @@ def test_stack_trace_on_not_found(app, static_file_directory, caplog):
|
||||
|
||||
assert response.status == 404
|
||||
assert counter[logging.INFO] == 5
|
||||
assert counter[logging.ERROR] == 1
|
||||
assert counter[logging.ERROR] == 0
|
||||
|
||||
|
||||
def test_no_stack_trace_on_not_found(app, static_file_directory, caplog):
|
||||
@@ -507,3 +523,56 @@ def test_multiple_statics(app, static_file_directory):
|
||||
assert response.body == get_file_content(
|
||||
static_file_directory, "python.png"
|
||||
)
|
||||
|
||||
|
||||
def test_resource_type_default(app, static_file_directory):
|
||||
app.static("/static", static_file_directory)
|
||||
app.static("/file", get_file_path(static_file_directory, "test.file"))
|
||||
|
||||
_, response = app.test_client.get("/static")
|
||||
assert response.status == 404
|
||||
|
||||
_, response = app.test_client.get("/file")
|
||||
assert response.status == 200
|
||||
assert response.body == get_file_content(
|
||||
static_file_directory, "test.file"
|
||||
)
|
||||
|
||||
|
||||
def test_resource_type_file(app, static_file_directory):
|
||||
app.static(
|
||||
"/file",
|
||||
get_file_path(static_file_directory, "test.file"),
|
||||
resource_type="file",
|
||||
)
|
||||
|
||||
_, response = app.test_client.get("/file")
|
||||
assert response.status == 200
|
||||
assert response.body == get_file_content(
|
||||
static_file_directory, "test.file"
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
app.static("/static", static_file_directory, resource_type="file")
|
||||
|
||||
|
||||
def test_resource_type_dir(app, static_file_directory):
|
||||
app.static("/static", static_file_directory, resource_type="dir")
|
||||
|
||||
_, response = app.test_client.get("/static/test.file")
|
||||
assert response.status == 200
|
||||
assert response.body == get_file_content(
|
||||
static_file_directory, "test.file"
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
app.static(
|
||||
"/file",
|
||||
get_file_path(static_file_directory, "test.file"),
|
||||
resource_type="dir",
|
||||
)
|
||||
|
||||
|
||||
def test_resource_type_unknown(app, static_file_directory, caplog):
|
||||
with pytest.raises(ValueError):
|
||||
app.static("/static", static_file_directory, resource_type="unknown")
|
||||
|
||||
21
tests/test_touchup.py
Normal file
21
tests/test_touchup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import logging
|
||||
|
||||
from sanic.signals import RESERVED_NAMESPACES
|
||||
from sanic.touchup import TouchUp
|
||||
|
||||
|
||||
def test_touchup_methods(app):
|
||||
assert len(TouchUp._registry) == 9
|
||||
|
||||
|
||||
async def test_ode_removes_dispatch_events(app, caplog):
|
||||
with caplog.at_level(logging.DEBUG, logger="sanic.root"):
|
||||
await app._startup()
|
||||
logs = caplog.record_tuples
|
||||
|
||||
for signal in RESERVED_NAMESPACES["http"]:
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
f"Disabling event: {signal}",
|
||||
) in logs
|
||||
@@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app):
|
||||
)
|
||||
|
||||
|
||||
def test_websocket_bp_route_name(app):
|
||||
@pytest.mark.parametrize(
|
||||
"name,expected",
|
||||
(
|
||||
("test_route", "/bp/route"),
|
||||
("test_route2", "/bp/route2"),
|
||||
("foobar_3", "/bp/route3"),
|
||||
),
|
||||
)
|
||||
def test_websocket_bp_route_name(app, name, expected):
|
||||
"""Tests that blueprint websocket route is named."""
|
||||
event = asyncio.Event()
|
||||
bp = Blueprint("test_bp", url_prefix="/bp")
|
||||
@@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app):
|
||||
uri = app.url_for("test_bp.main")
|
||||
assert uri == "/bp/main"
|
||||
|
||||
uri = app.url_for("test_bp.test_route")
|
||||
assert uri == "/bp/route"
|
||||
uri = app.url_for(f"test_bp.{name}")
|
||||
assert uri == expected
|
||||
request, response = SanicTestClient(app).websocket(uri)
|
||||
assert response.opened is True
|
||||
assert event.is_set()
|
||||
|
||||
event.clear()
|
||||
uri = app.url_for("test_bp.test_route2")
|
||||
assert uri == "/bp/route2"
|
||||
request, response = SanicTestClient(app).websocket(uri)
|
||||
assert response.opened is True
|
||||
assert event.is_set()
|
||||
|
||||
uri = app.url_for("test_bp.foobar_3")
|
||||
assert uri == "/bp/route3"
|
||||
|
||||
|
||||
# TODO: add test with a route with multiple hosts
|
||||
# TODO: add test with a route with _host in url_for
|
||||
|
||||
@@ -175,7 +175,7 @@ def test_worker_close(worker):
|
||||
worker.wsgi = mock.Mock()
|
||||
conn = mock.Mock()
|
||||
conn.websocket = mock.Mock()
|
||||
conn.websocket.close_connection = mock.Mock(wraps=_a_noop)
|
||||
conn.websocket.fail_connection = mock.Mock(wraps=_a_noop)
|
||||
worker.connections = set([conn])
|
||||
worker.log = mock.Mock()
|
||||
worker.loop = loop
|
||||
@@ -190,5 +190,5 @@ def test_worker_close(worker):
|
||||
loop.run_until_complete(_close)
|
||||
|
||||
assert worker.signal.stopped
|
||||
assert conn.websocket.close_connection.called
|
||||
assert conn.websocket.fail_connection.called
|
||||
assert len(worker.servers) == 0
|
||||
|
||||
55
tox.ini
55
tox.ini
@@ -2,53 +2,28 @@
|
||||
envlist = py37, py38, py39, pyNightly, pypy37, {py37,py38,py39,pyNightly,pypy37}-no-ext, lint, check, security, docs, type-checking
|
||||
|
||||
[testenv]
|
||||
usedevelop = True
|
||||
usedevelop = true
|
||||
setenv =
|
||||
{py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UJSON=1
|
||||
{py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UVLOOP=1
|
||||
deps =
|
||||
sanic-testing>=0.6.0
|
||||
coverage==5.3
|
||||
pytest==5.2.1
|
||||
pytest-cov
|
||||
pytest-sanic
|
||||
pytest-sugar
|
||||
pytest-benchmark
|
||||
chardet==3.*
|
||||
beautifulsoup4
|
||||
gunicorn==20.0.4
|
||||
uvicorn
|
||||
websockets>=9.0
|
||||
extras = test
|
||||
commands =
|
||||
pytest {posargs:tests --cov sanic}
|
||||
- coverage combine --append
|
||||
coverage report -m
|
||||
coverage report -m -i
|
||||
coverage html -i
|
||||
|
||||
[testenv:lint]
|
||||
deps =
|
||||
flake8
|
||||
black
|
||||
isort>=5.0.0
|
||||
bandit
|
||||
|
||||
commands =
|
||||
flake8 sanic
|
||||
black --config ./.black.toml --check --verbose sanic/
|
||||
isort --check-only sanic --profile=black
|
||||
|
||||
[testenv:type-checking]
|
||||
deps =
|
||||
mypy>=0.901
|
||||
types-ujson
|
||||
|
||||
commands =
|
||||
mypy sanic
|
||||
|
||||
[testenv:check]
|
||||
deps =
|
||||
docutils
|
||||
pygments
|
||||
commands =
|
||||
python setup.py check -r -s
|
||||
|
||||
@@ -60,8 +35,6 @@ markers =
|
||||
asyncio
|
||||
|
||||
[testenv:security]
|
||||
deps =
|
||||
bandit
|
||||
|
||||
commands =
|
||||
bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py
|
||||
@@ -69,30 +42,10 @@ commands =
|
||||
[testenv:docs]
|
||||
platform = linux|linux2|darwin
|
||||
whitelist_externals = make
|
||||
deps =
|
||||
sphinx>=2.1.2
|
||||
sphinx_rtd_theme>=0.4.3
|
||||
recommonmark>=0.5.0
|
||||
docutils
|
||||
pygments
|
||||
gunicorn==20.0.4
|
||||
extras = docs
|
||||
commands =
|
||||
make docs-test
|
||||
|
||||
[testenv:coverage]
|
||||
usedevelop = True
|
||||
deps =
|
||||
sanic-testing>=0.6.0
|
||||
coverage==5.3
|
||||
pytest==5.2.1
|
||||
pytest-cov
|
||||
pytest-sanic
|
||||
pytest-sugar
|
||||
pytest-benchmark
|
||||
chardet==3.*
|
||||
beautifulsoup4
|
||||
gunicorn==20.0.4
|
||||
uvicorn
|
||||
websockets>=9.0
|
||||
commands =
|
||||
pytest tests --cov=./sanic --cov-report=xml
|
||||
|
||||
Reference in New Issue
Block a user