Merge pull request #1475 from tomchristie/asgi-refactor-attempt
ASGI refactoring attempt
This commit is contained in:
commit
891f99d71d
|
@ -5,3 +5,11 @@ omit = site-packages, sanic/utils.py, sanic/__main__.py
|
|||
|
||||
[html]
|
||||
directory = coverage
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
no cov
|
||||
no qa
|
||||
noqa
|
||||
NOQA
|
||||
pragma: no cover
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
# Deploying
|
||||
|
||||
Deploying Sanic is made simple by the inbuilt webserver. After defining an
|
||||
instance of `sanic.Sanic`, we can call the `run` method with the following
|
||||
Deploying Sanic is very simple using one of three options: the inbuilt webserver,
|
||||
an [ASGI webserver](https://asgi.readthedocs.io/en/latest/implementations.html), or `gunicorn`.
|
||||
It is also very common to place Sanic behind a reverse proxy, like `nginx`.
|
||||
|
||||
## Running via Sanic webserver
|
||||
|
||||
After defining an instance of `sanic.Sanic`, we can call the `run` method with the following
|
||||
keyword arguments:
|
||||
|
||||
- `host` *(default `"127.0.0.1"`)*: Address to host the server on.
|
||||
|
@ -17,7 +22,13 @@ keyword arguments:
|
|||
[asyncio.protocol](https://docs.python.org/3/library/asyncio-protocol.html#protocol-classes).
|
||||
- `access_log` *(default `True`)*: Enables log on handling requests (significantly slows server).
|
||||
|
||||
## Workers
|
||||
```python
|
||||
app.run(host='0.0.0.0', port=1337, access_log=False)
|
||||
```
|
||||
|
||||
In the above example, we decided to turn off the access log in order to increase performance.
|
||||
|
||||
### Workers
|
||||
|
||||
By default, Sanic listens in the main process using only one CPU core. To crank
|
||||
up the juice, just specify the number of workers in the `run` arguments.
|
||||
|
@ -29,9 +40,9 @@ app.run(host='0.0.0.0', port=1337, workers=4)
|
|||
Sanic will automatically spin up multiple processes and route traffic between
|
||||
them. We recommend as many workers as you have available cores.
|
||||
|
||||
## Running via command
|
||||
### Running via command
|
||||
|
||||
If you like using command line arguments, you can launch a Sanic server by
|
||||
If you like using command line arguments, you can launch a Sanic webserver by
|
||||
executing the module. For example, if you initialized Sanic as `app` in a file
|
||||
named `server.py`, you could run the server like so:
|
||||
|
||||
|
@ -46,6 +57,33 @@ if __name__ == '__main__':
|
|||
app.run(host='0.0.0.0', port=1337, workers=4)
|
||||
```
|
||||
|
||||
## Running via ASGI
|
||||
|
||||
Sanic is also ASGI-compliant. This means you can use your preferred ASGI webserver
|
||||
to run Sanic. The three main implementations of ASGI are
|
||||
[Daphne](http://github.com/django/daphne), [Uvicorn](https://www.uvicorn.org/),
|
||||
and [Hypercorn](https://pgjones.gitlab.io/hypercorn/index.html).
|
||||
|
||||
Follow their documentation for the proper way to run them, but it should look
|
||||
something like:
|
||||
|
||||
```
|
||||
daphne myapp:app
|
||||
uvicorn myapp:app
|
||||
hypercorn myapp:app
|
||||
```
|
||||
|
||||
A couple things to note when using ASGI:
|
||||
|
||||
1. When using the Sanic webserver, websockets will run using the [`websockets`](https://websockets.readthedocs.io/) package. In ASGI mode, there is no need for this package since websockets are managed in the ASGI server.
|
||||
1. The ASGI [lifespan protocol](https://asgi.readthedocs.io/en/latest/specs/lifespan.html) supports
|
||||
only two server events: startup and shutdown. Sanic has four: before startup, after startup,
|
||||
before shutdown, and after shutdown. Therefore, in ASGI mode, the startup and shutdown events will
|
||||
run consecutively and not actually around the server process beginning and ending (since that
|
||||
is now controlled by the ASGI server). Therefore, it is best to use `after_server_start` and
|
||||
`before_server_stop`.
|
||||
1. ASGI mode is still in "beta" as of Sanic v19.6.
|
||||
|
||||
## Running via Gunicorn
|
||||
|
||||
[Gunicorn](http://gunicorn.org/) ‘Green Unicorn’ is a WSGI HTTP Server for UNIX.
|
||||
|
@ -64,7 +102,9 @@ of the memory leak.
|
|||
|
||||
See the [Gunicorn Docs](http://docs.gunicorn.org/en/latest/settings.html#max-requests) for more information.
|
||||
|
||||
## Running behind a reverse proxy
|
||||
## Other deployment considerations
|
||||
|
||||
### Running behind a reverse proxy
|
||||
|
||||
Sanic can be used with a reverse proxy (e.g. nginx). There's a simple example of nginx configuration:
|
||||
|
||||
|
@ -84,7 +124,7 @@ server {
|
|||
|
||||
If you want to get real client ip, you should configure `X-Real-IP` and `X-Forwarded-For` HTTP headers and set `app.config.PROXIES_COUNT` to `1`; see the configuration page for more information.
|
||||
|
||||
## Disable debug logging
|
||||
### Disable debug logging for performance
|
||||
|
||||
To improve the performance add `debug=False` and `access_log=False` in the `run` arguments.
|
||||
|
||||
|
@ -104,9 +144,10 @@ Or you can rewrite app config directly
|
|||
app.config.ACCESS_LOG = False
|
||||
```
|
||||
|
||||
## Asynchronous support
|
||||
This is suitable if you *need* to share the sanic process with other applications, in particular the `loop`.
|
||||
However be advised that this method does not support using multiple processes, and is not the preferred way
|
||||
### Asynchronous support and sharing the loop
|
||||
|
||||
This is suitable if you *need* to share the Sanic process with other applications, in particular the `loop`.
|
||||
However, be advised that this method does not support using multiple processes, and is not the preferred way
|
||||
to run the app in general.
|
||||
|
||||
Here is an incomplete example (please see `run_async.py` in examples for something more practical):
|
||||
|
@ -116,4 +157,4 @@ server = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True
|
|||
loop = asyncio.get_event_loop()
|
||||
task = asyncio.ensure_future(server)
|
||||
loop.run_forever()
|
||||
```
|
||||
```
|
88
examples/run_asgi.py
Normal file
88
examples/run_asgi.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
"""
|
||||
1. Create a simple Sanic app
|
||||
0. Run with an ASGI server:
|
||||
$ uvicorn run_asgi:app
|
||||
or
|
||||
$ hypercorn run_asgi:app
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from sanic import Sanic, response
|
||||
|
||||
|
||||
app = Sanic(__name__)
|
||||
|
||||
|
||||
@app.route("/text")
|
||||
def handler_text(request):
|
||||
return response.text("Hello")
|
||||
|
||||
|
||||
@app.route("/json")
|
||||
def handler_json(request):
|
||||
return response.json({"foo": "bar"})
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def handler_ws(request, ws):
|
||||
name = "<someone>"
|
||||
while True:
|
||||
data = f"Hello {name}"
|
||||
await ws.send(data)
|
||||
name = await ws.recv()
|
||||
|
||||
if not name:
|
||||
break
|
||||
|
||||
|
||||
@app.route("/file")
|
||||
async def handler_file(request):
|
||||
return await response.file(Path("../") / "setup.py")
|
||||
|
||||
|
||||
@app.route("/file_stream")
|
||||
async def handler_file_stream(request):
|
||||
return await response.file_stream(
|
||||
Path("../") / "setup.py", chunk_size=1024
|
||||
)
|
||||
|
||||
|
||||
@app.route("/stream", stream=True)
|
||||
async def handler_stream(request):
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
body = body.decode("utf-8").replace("1", "A")
|
||||
# await response.write(body)
|
||||
return response.stream(body)
|
||||
|
||||
|
||||
@app.listener("before_server_start")
|
||||
async def listener_before_server_start(*args, **kwargs):
|
||||
print("before_server_start")
|
||||
|
||||
|
||||
@app.listener("after_server_start")
|
||||
async def listener_after_server_start(*args, **kwargs):
|
||||
print("after_server_start")
|
||||
|
||||
|
||||
@app.listener("before_server_stop")
|
||||
async def listener_before_server_stop(*args, **kwargs):
|
||||
print("before_server_stop")
|
||||
|
||||
|
||||
@app.listener("after_server_stop")
|
||||
async def listener_after_server_stop(*args, **kwargs):
|
||||
print("after_server_stop")
|
||||
|
||||
|
||||
@app.middleware("request")
|
||||
async def print_on_request(request):
|
||||
print("print_on_request")
|
||||
|
||||
|
||||
@app.middleware("response")
|
||||
async def print_on_response(request, response):
|
||||
print("print_on_response")
|
62
sanic/app.py
62
sanic/app.py
|
@ -15,6 +15,7 @@ from typing import Any, Optional, Type, Union
|
|||
from urllib.parse import urlencode, urlunparse
|
||||
|
||||
from sanic import reloader_helpers
|
||||
from sanic.asgi import ASGIApp
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.config import BASE_LOGO, Config
|
||||
from sanic.constants import HTTP_METHODS
|
||||
|
@ -25,7 +26,7 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse
|
|||
from sanic.router import Router
|
||||
from sanic.server import HttpProtocol, Signal, serve, serve_multiple
|
||||
from sanic.static import register as static_register
|
||||
from sanic.testing import SanicTestClient
|
||||
from sanic.testing import SanicASGITestClient, SanicTestClient
|
||||
from sanic.views import CompositionView
|
||||
from sanic.websocket import ConnectionClosed, WebSocketProtocol
|
||||
|
||||
|
@ -53,6 +54,7 @@ class Sanic:
|
|||
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
|
||||
|
||||
self.name = name
|
||||
self.asgi = False
|
||||
self.router = router or Router()
|
||||
self.request_class = request_class
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
|
@ -80,7 +82,7 @@ class Sanic:
|
|||
|
||||
Only supported when using the `app.run` method.
|
||||
"""
|
||||
if not self.is_running:
|
||||
if not self.is_running and self.asgi is False:
|
||||
raise SanicException(
|
||||
"Loop can only be retrieved after the app has started "
|
||||
"running. Not supported with `create_server` function"
|
||||
|
@ -469,13 +471,23 @@ class Sanic:
|
|||
getattr(handler, "__blueprintname__", "")
|
||||
+ handler.__name__
|
||||
)
|
||||
try:
|
||||
protocol = request.transport.get_protocol()
|
||||
except AttributeError:
|
||||
# On Python3.5 the Transport classes in asyncio do not
|
||||
# have a get_protocol() method as in uvloop
|
||||
protocol = request.transport._protocol
|
||||
ws = await protocol.websocket_handshake(request, subprotocols)
|
||||
|
||||
pass
|
||||
|
||||
if self.asgi:
|
||||
ws = request.transport.get_websocket_connection()
|
||||
else:
|
||||
try:
|
||||
protocol = request.transport.get_protocol()
|
||||
except AttributeError:
|
||||
# On Python3.5 the Transport classes in asyncio do not
|
||||
# have a get_protocol() method as in uvloop
|
||||
protocol = request.transport._protocol
|
||||
protocol.app = self
|
||||
|
||||
ws = await protocol.websocket_handshake(
|
||||
request, subprotocols
|
||||
)
|
||||
|
||||
# schedule the application handler
|
||||
# its future is kept in self.websocket_tasks in case it
|
||||
|
@ -983,8 +995,16 @@ class Sanic:
|
|||
raise CancelledError()
|
||||
|
||||
# pass the response to the correct callback
|
||||
if isinstance(response, StreamingHTTPResponse):
|
||||
await stream_callback(response)
|
||||
if write_callback is None or isinstance(
|
||||
response, StreamingHTTPResponse
|
||||
):
|
||||
if stream_callback:
|
||||
await stream_callback(response)
|
||||
else:
|
||||
# Should only end here IF it is an ASGI websocket.
|
||||
# TODO:
|
||||
# - Add exception handling
|
||||
pass
|
||||
else:
|
||||
write_callback(response)
|
||||
|
||||
|
@ -996,6 +1016,10 @@ class Sanic:
|
|||
def test_client(self):
|
||||
return SanicTestClient(self)
|
||||
|
||||
@property
|
||||
def asgi_client(self):
|
||||
return SanicASGITestClient(self)
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Execution
|
||||
# -------------------------------------------------------------------- #
|
||||
|
@ -1122,10 +1146,6 @@ class Sanic:
|
|||
"""This kills the Sanic"""
|
||||
get_event_loop().stop()
|
||||
|
||||
def __call__(self):
|
||||
"""gunicorn compatibility"""
|
||||
return self
|
||||
|
||||
async def create_server(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
|
@ -1367,3 +1387,15 @@ class Sanic:
|
|||
def _build_endpoint_name(self, *parts):
|
||||
parts = [self.name, *parts]
|
||||
return ".".join(parts)
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# ASGI
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""To be ASGI compliant, our instance must be a callable that accepts
|
||||
three arguments: scope, receive, send. See the ASGI reference for more
|
||||
details: https://asgi.readthedocs.io/en/latest/"""
|
||||
self.asgi = True
|
||||
asgi_app = await ASGIApp.create(self, scope, receive, send)
|
||||
await asgi_app()
|
||||
|
|
350
sanic/asgi.py
Normal file
350
sanic/asgi.py
Normal file
|
@ -0,0 +1,350 @@
|
|||
import asyncio
|
||||
import warnings
|
||||
|
||||
from http.cookies import SimpleCookie
|
||||
from inspect import isawaitable
|
||||
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from sanic.exceptions import InvalidUsage, ServerError
|
||||
from sanic.log import logger
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.server import StreamBuffer
|
||||
from sanic.websocket import WebSocketConnection
|
||||
|
||||
|
||||
ASGIScope = MutableMapping[str, Any]
|
||||
ASGIMessage = MutableMapping[str, Any]
|
||||
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
|
||||
ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
|
||||
|
||||
|
||||
class MockProtocol:
|
||||
def __init__(self, transport: "MockTransport", loop):
|
||||
self.transport = transport
|
||||
self._not_paused = asyncio.Event(loop=loop)
|
||||
self._not_paused.set()
|
||||
self._complete = asyncio.Event(loop=loop)
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
self._not_paused.clear()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
self._not_paused.set()
|
||||
|
||||
async def complete(self) -> None:
|
||||
self._not_paused.set()
|
||||
await self.transport.send(
|
||||
{"type": "http.response.body", "body": b"", "more_body": False}
|
||||
)
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
return self._complete.is_set()
|
||||
|
||||
async def push_data(self, data: bytes) -> None:
|
||||
if not self.is_complete:
|
||||
await self.transport.send(
|
||||
{"type": "http.response.body", "body": data, "more_body": True}
|
||||
)
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self._not_paused.wait()
|
||||
|
||||
|
||||
class MockTransport:
|
||||
def __init__(
|
||||
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
||||
) -> None:
|
||||
self.scope = scope
|
||||
self._receive = receive
|
||||
self._send = send
|
||||
self._protocol = None
|
||||
self.loop = None
|
||||
|
||||
def get_protocol(self) -> MockProtocol:
|
||||
if not self._protocol:
|
||||
self._protocol = MockProtocol(self, self.loop)
|
||||
return self._protocol
|
||||
|
||||
def get_extra_info(self, info: str) -> Union[str, bool]:
|
||||
if info == "peername":
|
||||
return self.scope.get("server")
|
||||
elif info == "sslcontext":
|
||||
return self.scope.get("scheme") in ["https", "wss"]
|
||||
|
||||
def get_websocket_connection(self) -> WebSocketConnection:
|
||||
try:
|
||||
return self._websocket_connection
|
||||
except AttributeError:
|
||||
raise InvalidUsage("Improper websocket connection.")
|
||||
|
||||
def create_websocket_connection(
|
||||
self, send: ASGISend, receive: ASGIReceive
|
||||
) -> WebSocketConnection:
|
||||
self._websocket_connection = WebSocketConnection(send, receive)
|
||||
return self._websocket_connection
|
||||
|
||||
def add_task(self) -> None: # noqa
|
||||
raise NotImplementedError
|
||||
|
||||
async def send(self, data) -> None:
|
||||
# TODO:
|
||||
# - Validation on data and that it is formatted properly and is valid
|
||||
await self._send(data)
|
||||
|
||||
async def receive(self) -> ASGIMessage:
|
||||
return await self._receive()
|
||||
|
||||
|
||||
class Lifespan:
|
||||
def __init__(self, asgi_app: "ASGIApp") -> None:
|
||||
self.asgi_app = asgi_app
|
||||
|
||||
if "before_server_start" in self.asgi_app.sanic_app.listeners:
|
||||
warnings.warn(
|
||||
'You have set a listener for "before_server_start" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as early as possible, but not before "
|
||||
"the ASGI server is started."
|
||||
)
|
||||
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
|
||||
warnings.warn(
|
||||
'You have set a listener for "after_server_stop" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as late as possible, but not after "
|
||||
"the ASGI server is stopped."
|
||||
)
|
||||
|
||||
# async def pre_startup(self) -> None:
|
||||
# for handler in self.asgi_app.sanic_app.listeners[
|
||||
# "before_server_start"
|
||||
# ]:
|
||||
# response = handler(
|
||||
# self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
# )
|
||||
# if isawaitable(response):
|
||||
# await response
|
||||
|
||||
async def startup(self) -> None:
|
||||
for handler in self.asgi_app.sanic_app.listeners[
|
||||
"before_server_start"
|
||||
]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
for handler in self.asgi_app.sanic_app.listeners["after_server_stop"]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
async def __call__(
|
||||
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
||||
) -> None:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await self.startup()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.shutdown":
|
||||
await self.shutdown()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
|
||||
class ASGIApp:
|
||||
def __init__(self) -> None:
|
||||
self.ws = None
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
||||
) -> "ASGIApp":
|
||||
instance = cls()
|
||||
instance.sanic_app = sanic_app
|
||||
instance.transport = MockTransport(scope, receive, send)
|
||||
instance.transport.add_task = sanic_app.loop.create_task
|
||||
instance.transport.loop = sanic_app.loop
|
||||
|
||||
headers = CIMultiDict(
|
||||
[
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in scope.get("headers", [])
|
||||
]
|
||||
)
|
||||
instance.do_stream = (
|
||||
True if headers.get("expect") == "100-continue" else False
|
||||
)
|
||||
instance.lifespan = Lifespan(instance)
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await instance.lifespan(scope, receive, send)
|
||||
else:
|
||||
url_bytes = scope.get("root_path", "") + quote(scope["path"])
|
||||
url_bytes = url_bytes.encode("latin-1")
|
||||
url_bytes += b"?" + scope["query_string"]
|
||||
|
||||
if scope["type"] == "http":
|
||||
version = scope["http_version"]
|
||||
method = scope["method"]
|
||||
elif scope["type"] == "websocket":
|
||||
version = "1.1"
|
||||
method = "GET"
|
||||
|
||||
instance.ws = instance.transport.create_websocket_connection(
|
||||
send, receive
|
||||
)
|
||||
await instance.ws.accept()
|
||||
else:
|
||||
pass
|
||||
# TODO:
|
||||
# - close connection
|
||||
|
||||
instance.request = Request(
|
||||
url_bytes,
|
||||
headers,
|
||||
version,
|
||||
method,
|
||||
instance.transport,
|
||||
sanic_app,
|
||||
)
|
||||
|
||||
if sanic_app.is_request_stream:
|
||||
is_stream_handler = sanic_app.router.is_stream_handler(
|
||||
instance.request
|
||||
)
|
||||
if is_stream_handler:
|
||||
instance.request.stream = StreamBuffer(
|
||||
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
|
||||
)
|
||||
instance.do_stream = True
|
||||
|
||||
return instance
|
||||
|
||||
async def read_body(self) -> bytes:
|
||||
"""
|
||||
Read and return the entire body from an incoming ASGI message.
|
||||
"""
|
||||
body = b""
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await self.transport.receive()
|
||||
body += message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
return body
|
||||
|
||||
async def stream_body(self) -> None:
|
||||
"""
|
||||
Read and stream the body in chunks from an incoming ASGI message.
|
||||
"""
|
||||
more_body = True
|
||||
|
||||
while more_body:
|
||||
message = await self.transport.receive()
|
||||
chunk = message.get("body", b"")
|
||||
await self.request.stream.put(chunk)
|
||||
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
await self.request.stream.put(None)
|
||||
|
||||
async def __call__(self) -> None:
|
||||
"""
|
||||
Handle the incoming request.
|
||||
"""
|
||||
if not self.do_stream:
|
||||
self.request.body = await self.read_body()
|
||||
else:
|
||||
self.sanic_app.loop.create_task(self.stream_body())
|
||||
|
||||
handler = self.sanic_app.handle_request
|
||||
callback = None if self.ws else self.stream_callback
|
||||
await handler(self.request, None, callback)
|
||||
|
||||
async def stream_callback(self, response: HTTPResponse) -> None:
|
||||
"""
|
||||
Write the response.
|
||||
"""
|
||||
|
||||
try:
|
||||
headers = [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
]
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
"Invalid response object for url %s, "
|
||||
"Expected Type: HTTPResponse, Actual Type: %s",
|
||||
self.request.url,
|
||||
type(response),
|
||||
)
|
||||
exception = ServerError("Invalid response type")
|
||||
response = self.sanic_app.error_handler.response(
|
||||
self.request, exception
|
||||
)
|
||||
headers = [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
if name not in (b"Set-Cookie",)
|
||||
]
|
||||
|
||||
if "content-length" not in response.headers and not isinstance(
|
||||
response, StreamingHTTPResponse
|
||||
):
|
||||
headers += [
|
||||
(b"content-length", str(len(response.body)).encode("latin-1"))
|
||||
]
|
||||
|
||||
if response.cookies:
|
||||
cookies = SimpleCookie()
|
||||
cookies.load(response.cookies)
|
||||
headers += [
|
||||
(b"set-cookie", cookie.encode("utf-8"))
|
||||
for name, cookie in response.cookies.items()
|
||||
]
|
||||
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": response.status,
|
||||
"headers": headers,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(response, StreamingHTTPResponse):
|
||||
response.protocol = self.transport.get_protocol()
|
||||
await response.stream()
|
||||
await response.protocol.complete()
|
||||
|
||||
else:
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": response.body,
|
||||
"more_body": False,
|
||||
}
|
||||
)
|
|
@ -87,9 +87,9 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
|||
data = self._encode_body(data)
|
||||
|
||||
if self.chunked:
|
||||
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
|
||||
await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
|
||||
else:
|
||||
self.protocol.push_data(data)
|
||||
await self.protocol.push_data(data)
|
||||
await self.protocol.drain()
|
||||
|
||||
async def stream(
|
||||
|
@ -105,11 +105,11 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
|||
keep_alive=keep_alive,
|
||||
keep_alive_timeout=keep_alive_timeout,
|
||||
)
|
||||
self.protocol.push_data(headers)
|
||||
await self.protocol.push_data(headers)
|
||||
await self.protocol.drain()
|
||||
await self.streaming_fn(self)
|
||||
if self.chunked:
|
||||
self.protocol.push_data(b"0\r\n\r\n")
|
||||
await self.protocol.push_data(b"0\r\n\r\n")
|
||||
# no need to await drain here after this write, because it is the
|
||||
# very last thing we write and nothing needs to wait for it.
|
||||
|
||||
|
|
|
@ -406,6 +406,7 @@ class Router:
|
|||
if not self.hosts:
|
||||
return self._get(request.path, request.method, "")
|
||||
# virtual hosts specified; try to match route to the host header
|
||||
|
||||
try:
|
||||
return self._get(
|
||||
request.path, request.method, request.headers.get("Host", "")
|
||||
|
|
|
@ -477,7 +477,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
async def drain(self):
|
||||
await self._not_paused.wait()
|
||||
|
||||
def push_data(self, data):
|
||||
async def push_data(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
async def stream_response(self, response):
|
||||
|
@ -728,6 +728,8 @@ def serve(
|
|||
if debug:
|
||||
loop.set_debug(debug)
|
||||
|
||||
app.asgi = False
|
||||
|
||||
connections = connections if connections is not None else set()
|
||||
server = partial(
|
||||
protocol,
|
||||
|
|
262
sanic/testing.py
262
sanic/testing.py
|
@ -1,14 +1,22 @@
|
|||
import asyncio
|
||||
import types
|
||||
import typing
|
||||
|
||||
from json import JSONDecodeError
|
||||
from socket import socket
|
||||
from urllib.parse import unquote, urlsplit
|
||||
|
||||
import httpcore
|
||||
import requests_async as requests
|
||||
import websockets
|
||||
|
||||
from sanic.asgi import ASGIApp
|
||||
from sanic.exceptions import MethodNotSupported
|
||||
from sanic.log import logger
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
ASGI_HOST = "mockserver"
|
||||
HOST = "127.0.0.1"
|
||||
PORT = 42101
|
||||
|
||||
|
@ -64,7 +72,7 @@ class SanicTestClient:
|
|||
debug=False,
|
||||
server_kwargs={"auto_reload": False},
|
||||
*request_args,
|
||||
**request_kwargs
|
||||
**request_kwargs,
|
||||
):
|
||||
results = [None, None]
|
||||
exceptions = []
|
||||
|
@ -128,7 +136,7 @@ class SanicTestClient:
|
|||
try:
|
||||
request, response = results
|
||||
return request, response
|
||||
except BaseException:
|
||||
except BaseException: # noqa
|
||||
raise ValueError(
|
||||
"Request and response object expected, got ({})".format(
|
||||
results
|
||||
|
@ -137,7 +145,7 @@ class SanicTestClient:
|
|||
else:
|
||||
try:
|
||||
return results[-1]
|
||||
except BaseException:
|
||||
except BaseException: # noqa
|
||||
raise ValueError(
|
||||
"Request object expected, got ({})".format(results)
|
||||
)
|
||||
|
@ -165,3 +173,251 @@ class SanicTestClient:
|
|||
|
||||
def websocket(self, *args, **kwargs):
|
||||
return self._sanic_endpoint_test("websocket", *args, **kwargs)
|
||||
|
||||
|
||||
class SanicASGIAdapter(requests.asgi.ASGIAdapter): # noqa
|
||||
async def send( # type: ignore
|
||||
self,
|
||||
request: requests.PreparedRequest,
|
||||
gather_return: bool = False,
|
||||
*args: typing.Any,
|
||||
**kwargs: typing.Any,
|
||||
) -> requests.Response:
|
||||
"""This method is taken MOSTLY verbatim from requests-asyn. The
|
||||
difference is the capturing of a response on the ASGI call and then
|
||||
returning it on the response object. This is implemented to achieve:
|
||||
|
||||
request, response = await app.asgi_client.get("/")
|
||||
|
||||
You can see the original code here:
|
||||
https://github.com/encode/requests-async/blob/614f40f77f19e6c6da8a212ae799107b0384dbf9/requests_async/asgi.py#L51""" # noqa
|
||||
scheme, netloc, path, query, fragment = urlsplit(
|
||||
request.url
|
||||
) # type: ignore
|
||||
|
||||
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
|
||||
|
||||
if ":" in netloc:
|
||||
host, port_string = netloc.split(":", 1)
|
||||
port = int(port_string)
|
||||
else:
|
||||
host = netloc
|
||||
port = default_port
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
elif port == default_port:
|
||||
headers = [(b"host", host.encode())]
|
||||
else:
|
||||
headers = [(b"host", (f"{host}:{port}").encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
(key.lower().encode(), value.encode())
|
||||
for key, value in request.headers.items()
|
||||
]
|
||||
|
||||
no_response = False
|
||||
if scheme in {"ws", "wss"}:
|
||||
subprotocol = request.headers.get("sec-websocket-protocol", None)
|
||||
if subprotocol is None:
|
||||
subprotocols = [] # type: typing.Sequence[str]
|
||||
else:
|
||||
subprotocols = [
|
||||
value.strip() for value in subprotocol.split(",")
|
||||
]
|
||||
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(path),
|
||||
"root_path": "",
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"server": [host, port],
|
||||
"subprotocols": subprotocols,
|
||||
}
|
||||
no_response = True
|
||||
|
||||
else:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"path": unquote(path),
|
||||
"root_path": "",
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": ["testclient", 50000],
|
||||
"server": [host, port],
|
||||
"extensions": {"http.response.template": {}},
|
||||
}
|
||||
|
||||
async def receive():
|
||||
nonlocal request_complete, response_complete
|
||||
|
||||
if request_complete:
|
||||
while not response_complete:
|
||||
await asyncio.sleep(0.0001)
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
body = request.body
|
||||
if isinstance(body, str):
|
||||
body_bytes = body.encode("utf-8") # type: bytes
|
||||
elif body is None:
|
||||
body_bytes = b""
|
||||
elif isinstance(body, types.GeneratorType):
|
||||
try:
|
||||
chunk = body.send(None)
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode("utf-8")
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
}
|
||||
except StopIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b""}
|
||||
else:
|
||||
body_bytes = body
|
||||
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body_bytes}
|
||||
|
||||
async def send(message) -> None:
|
||||
nonlocal raw_kwargs, response_started, response_complete, template, context # noqa
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert (
|
||||
not response_started
|
||||
), 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["status_code"] = message["status"]
|
||||
raw_kwargs["headers"] = message["headers"]
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
assert response_started, (
|
||||
'Received "http.response.body" '
|
||||
'without "http.response.start".'
|
||||
)
|
||||
assert (
|
||||
not response_complete
|
||||
), 'Received "http.response.body" after response completed.'
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
raw_kwargs["content"] += body
|
||||
if not more_body:
|
||||
response_complete = True
|
||||
elif message["type"] == "http.response.template":
|
||||
template = message["template"]
|
||||
context = message["context"]
|
||||
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete = False
|
||||
raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any]
|
||||
template = None
|
||||
context = None
|
||||
return_value = None
|
||||
|
||||
try:
|
||||
return_value = await self.app(scope, receive, send)
|
||||
except BaseException as exc:
|
||||
if not self.suppress_exceptions:
|
||||
raise exc from None
|
||||
|
||||
if no_response:
|
||||
response_started = True
|
||||
raw_kwargs = {"status_code": 204, "headers": []}
|
||||
|
||||
if not self.suppress_exceptions:
|
||||
assert response_started, "TestClient did not receive any response."
|
||||
elif not response_started:
|
||||
raw_kwargs = {"status_code": 500, "headers": []}
|
||||
|
||||
raw = httpcore.Response(**raw_kwargs)
|
||||
response = self.build_response(request, raw)
|
||||
if template is not None:
|
||||
response.template = template
|
||||
response.context = context
|
||||
|
||||
if gather_return:
|
||||
response.return_value = return_value
|
||||
return response
|
||||
|
||||
|
||||
class TestASGIApp(ASGIApp):
|
||||
async def __call__(self):
|
||||
await super().__call__()
|
||||
return self.request
|
||||
|
||||
|
||||
async def app_call_with_return(self, scope, receive, send):
|
||||
asgi_app = await TestASGIApp.create(self, scope, receive, send)
|
||||
return await asgi_app()
|
||||
|
||||
|
||||
class SanicASGITestClient(requests.ASGISession):
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
base_url: str = "http://{}".format(ASGI_HOST),
|
||||
suppress_exceptions: bool = False,
|
||||
) -> None:
|
||||
app.__class__.__call__ = app_call_with_return
|
||||
app.asgi = True
|
||||
super().__init__(app)
|
||||
|
||||
adapter = SanicASGIAdapter(
|
||||
app, suppress_exceptions=suppress_exceptions
|
||||
)
|
||||
self.mount("http://", adapter)
|
||||
self.mount("https://", adapter)
|
||||
self.mount("ws://", adapter)
|
||||
self.mount("wss://", adapter)
|
||||
self.headers.update({"user-agent": "testclient"})
|
||||
self.app = app
|
||||
self.base_url = base_url
|
||||
|
||||
async def request(self, method, url, gather_request=True, *args, **kwargs):
|
||||
|
||||
self.gather_request = gather_request
|
||||
response = await super().request(method, url, *args, **kwargs)
|
||||
response.status = response.status_code
|
||||
response.body = response.content
|
||||
response.content_type = response.headers.get("content-type")
|
||||
|
||||
if hasattr(response, "return_value"):
|
||||
request = response.return_value
|
||||
del response.return_value
|
||||
return request, response
|
||||
|
||||
return response
|
||||
|
||||
def merge_environment_settings(self, *args, **kwargs):
|
||||
settings = super().merge_environment_settings(*args, **kwargs)
|
||||
settings.update({"gather_return": self.gather_request})
|
||||
return settings
|
||||
|
||||
async def websocket(self, uri, subprotocols=None, *args, **kwargs):
|
||||
if uri.startswith(("ws:", "wss:")):
|
||||
url = uri
|
||||
else:
|
||||
uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri)
|
||||
url = "ws://testserver{uri}".format(uri=uri)
|
||||
|
||||
headers = kwargs.get("headers", {})
|
||||
headers.setdefault("connection", "upgrade")
|
||||
headers.setdefault("sec-websocket-key", "testserver==")
|
||||
headers.setdefault("sec-websocket-version", "13")
|
||||
if subprotocols is not None:
|
||||
headers.setdefault(
|
||||
"sec-websocket-protocol", ", ".join(subprotocols)
|
||||
)
|
||||
kwargs["headers"] = headers
|
||||
|
||||
return await self.request("websocket", url, **kwargs)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
|
||||
|
||||
from httptools import HttpParserUpgrade
|
||||
from websockets import ConnectionClosed # noqa
|
||||
from websockets import InvalidHandshake, WebSocketCommonProtocol, handshake
|
||||
|
@ -6,6 +8,9 @@ from sanic.exceptions import InvalidUsage
|
|||
from sanic.server import HttpProtocol
|
||||
|
||||
|
||||
ASIMessage = MutableMapping[str, Any]
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -19,6 +24,7 @@ class WebSocketProtocol(HttpProtocol):
|
|||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
# self.app = None
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_max_size = websocket_max_size
|
||||
self.websocket_max_queue = websocket_max_queue
|
||||
|
@ -103,3 +109,45 @@ class WebSocketProtocol(HttpProtocol):
|
|||
self.websocket.connection_made(request.transport)
|
||||
self.websocket.connection_open()
|
||||
return self.websocket
|
||||
|
||||
|
||||
class WebSocketConnection:
|
||||
|
||||
# TODO
|
||||
# - Implement ping/pong
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send: Callable[[ASIMessage], Awaitable[None]],
|
||||
receive: Callable[[], Awaitable[ASIMessage]],
|
||||
) -> None:
|
||||
self._send = send
|
||||
self._receive = receive
|
||||
|
||||
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
||||
message = {"type": "websocket.send"}
|
||||
|
||||
try:
|
||||
data.decode()
|
||||
except AttributeError:
|
||||
message.update({"text": str(data)})
|
||||
else:
|
||||
message.update({"bytes": data})
|
||||
|
||||
await self._send(message)
|
||||
|
||||
async def recv(self, *args, **kwargs) -> Optional[str]:
|
||||
message = await self._receive()
|
||||
|
||||
if message["type"] == "websocket.receive":
|
||||
return message["text"]
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
pass
|
||||
|
||||
receive = recv
|
||||
|
||||
async def accept(self) -> None:
|
||||
await self._send({"type": "websocket.accept", "subprotocol": ""})
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app):
|
|||
|
||||
def test_app_loop_not_running(app):
|
||||
with pytest.raises(SanicException) as excinfo:
|
||||
_ = app.loop
|
||||
app.loop
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Loop can only be retrieved after the app has started "
|
||||
|
|
203
tests/test_asgi.py
Normal file
203
tests/test_asgi.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
import asyncio
|
||||
|
||||
from collections import deque
|
||||
|
||||
import pytest
|
||||
import uvicorn
|
||||
|
||||
from sanic.asgi import MockTransport
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.websocket import WebSocketConnection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def message_stack():
|
||||
return deque()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def receive(message_stack):
|
||||
async def _receive():
|
||||
return message_stack.popleft()
|
||||
|
||||
return _receive
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def send(message_stack):
|
||||
async def _send(message):
|
||||
message_stack.append(message)
|
||||
|
||||
return _send
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transport(message_stack, receive, send):
|
||||
return MockTransport({}, receive, send)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
# @pytest.mark.asyncio
|
||||
def protocol(transport, loop):
|
||||
return transport.get_protocol()
|
||||
|
||||
|
||||
def test_listeners_triggered(app):
|
||||
before_server_start = False
|
||||
after_server_start = False
|
||||
before_server_stop = False
|
||||
after_server_stop = False
|
||||
|
||||
@app.listener("before_server_start")
|
||||
def do_before_server_start(*args, **kwargs):
|
||||
nonlocal before_server_start
|
||||
before_server_start = True
|
||||
|
||||
@app.listener("after_server_start")
|
||||
def do_after_server_start(*args, **kwargs):
|
||||
nonlocal after_server_start
|
||||
after_server_start = True
|
||||
|
||||
@app.listener("before_server_stop")
|
||||
def do_before_server_stop(*args, **kwargs):
|
||||
nonlocal before_server_stop
|
||||
before_server_stop = True
|
||||
|
||||
@app.listener("after_server_stop")
|
||||
def do_after_server_stop(*args, **kwargs):
|
||||
nonlocal after_server_stop
|
||||
after_server_stop = True
|
||||
|
||||
class CustomServer(uvicorn.Server):
|
||||
def install_signal_handlers(self):
|
||||
pass
|
||||
|
||||
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
|
||||
server = CustomServer(config=config)
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
server.run()
|
||||
|
||||
for task in asyncio.Task.all_tasks():
|
||||
task.cancel()
|
||||
|
||||
assert before_server_start
|
||||
assert after_server_start
|
||||
assert before_server_stop
|
||||
assert after_server_stop
|
||||
|
||||
|
||||
def test_listeners_triggered_async(app):
|
||||
before_server_start = False
|
||||
after_server_start = False
|
||||
before_server_stop = False
|
||||
after_server_stop = False
|
||||
|
||||
@app.listener("before_server_start")
|
||||
async def do_before_server_start(*args, **kwargs):
|
||||
nonlocal before_server_start
|
||||
before_server_start = True
|
||||
|
||||
@app.listener("after_server_start")
|
||||
async def do_after_server_start(*args, **kwargs):
|
||||
nonlocal after_server_start
|
||||
after_server_start = True
|
||||
|
||||
@app.listener("before_server_stop")
|
||||
async def do_before_server_stop(*args, **kwargs):
|
||||
nonlocal before_server_stop
|
||||
before_server_stop = True
|
||||
|
||||
@app.listener("after_server_stop")
|
||||
async def do_after_server_stop(*args, **kwargs):
|
||||
nonlocal after_server_stop
|
||||
after_server_stop = True
|
||||
|
||||
class CustomServer(uvicorn.Server):
|
||||
def install_signal_handlers(self):
|
||||
pass
|
||||
|
||||
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
|
||||
server = CustomServer(config=config)
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
server.run()
|
||||
|
||||
for task in asyncio.Task.all_tasks():
|
||||
task.cancel()
|
||||
|
||||
assert before_server_start
|
||||
assert after_server_start
|
||||
assert before_server_stop
|
||||
assert after_server_stop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mockprotocol_events(protocol):
|
||||
assert protocol._not_paused.is_set()
|
||||
protocol.pause_writing()
|
||||
assert not protocol._not_paused.is_set()
|
||||
protocol.resume_writing()
|
||||
assert protocol._not_paused.is_set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_protocol_push_data(protocol, message_stack):
|
||||
text = b"hello"
|
||||
|
||||
await protocol.push_data(text)
|
||||
await protocol.complete()
|
||||
|
||||
assert len(message_stack) == 2
|
||||
|
||||
message = message_stack.popleft()
|
||||
assert message["type"] == "http.response.body"
|
||||
assert message["more_body"]
|
||||
assert message["body"] == text
|
||||
|
||||
message = message_stack.popleft()
|
||||
assert message["type"] == "http.response.body"
|
||||
assert not message["more_body"]
|
||||
assert message["body"] == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_send(send, receive, message_stack):
|
||||
text_string = "hello"
|
||||
text_bytes = b"hello"
|
||||
|
||||
ws = WebSocketConnection(send, receive)
|
||||
await ws.send(text_string)
|
||||
await ws.send(text_bytes)
|
||||
|
||||
assert len(message_stack) == 2
|
||||
|
||||
message = message_stack.popleft()
|
||||
assert message["type"] == "websocket.send"
|
||||
assert message["text"] == text_string
|
||||
assert "bytes" not in message
|
||||
|
||||
message = message_stack.popleft()
|
||||
assert message["type"] == "websocket.send"
|
||||
assert message["bytes"] == text_bytes
|
||||
assert "text" not in message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_receive(send, receive, message_stack):
|
||||
msg = {"text": "hello", "type": "websocket.receive"}
|
||||
message_stack.append(msg)
|
||||
|
||||
ws = WebSocketConnection(send, receive)
|
||||
text = await ws.receive()
|
||||
|
||||
assert text == msg["text"]
|
||||
|
||||
|
||||
def test_improper_websocket_connection(transport, send, receive):
|
||||
with pytest.raises(InvalidUsage):
|
||||
transport.get_websocket_connection()
|
||||
|
||||
transport.create_websocket_connection(send, receive)
|
||||
connection = transport.get_websocket_connection()
|
||||
assert isinstance(connection, WebSocketConnection)
|
5
tests/test_asgi_client.py
Normal file
5
tests/test_asgi_client.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from sanic.testing import SanicASGITestClient
|
||||
|
||||
|
||||
def test_asgi_client_instantiation(app):
|
||||
assert isinstance(app.asgi_client, SanicASGITestClient)
|
|
@ -239,6 +239,7 @@ def test_config_access_log_passing_in_run(app):
|
|||
assert app.config.ACCESS_LOG == True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_access_log_passing_in_create_server(app):
|
||||
assert app.config.ACCESS_LOG == True
|
||||
|
||||
|
|
|
@ -27,6 +27,24 @@ def test_cookies(app):
|
|||
assert response_cookies["right_back"].value == "at you"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cookies_asgi(app):
|
||||
@app.route("/")
|
||||
def handler(request):
|
||||
response = text("Cookies are: {}".format(request.cookies["test"]))
|
||||
response.cookies["right_back"] = "at you"
|
||||
return response
|
||||
|
||||
request, response = await app.asgi_client.get(
|
||||
"/", cookies={"test": "working!"}
|
||||
)
|
||||
response_cookies = SimpleCookie()
|
||||
response_cookies.load(response.headers.get("set-cookie", {}))
|
||||
|
||||
assert response.text == "Cookies are: working!"
|
||||
assert response_cookies["right_back"].value == "at you"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)])
|
||||
def test_false_cookies_encoded(app, httponly, expected):
|
||||
@app.route("/")
|
||||
|
|
|
@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"])
|
||||
async def test_redirect_with_params(app, sanic_client, test_str):
|
||||
def test_redirect_with_params(app, test_str):
|
||||
use_in_uri = quote(test_str)
|
||||
|
||||
@app.route("/api/v1/test/<test>/")
|
||||
async def init_handler(request, test):
|
||||
assert test == test_str
|
||||
return redirect("/api/v2/test/{}/".format(quote(test)))
|
||||
return redirect("/api/v2/test/{}/".format(use_in_uri))
|
||||
|
||||
@app.route("/api/v2/test/<test>/")
|
||||
async def target_handler(request, test):
|
||||
assert test == test_str
|
||||
return text("OK")
|
||||
|
||||
test_cli = await sanic_client(app)
|
||||
|
||||
response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str)))
|
||||
_, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri))
|
||||
assert response.status == 200
|
||||
|
||||
txt = await response.text()
|
||||
assert txt == "OK"
|
||||
assert response.content == b"OK"
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.response import stream, text
|
||||
|
||||
|
||||
async def test_request_cancel_when_connection_lost(loop, app, sanic_client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_cancel_when_connection_lost(app):
|
||||
app.still_serving_cancelled_request = False
|
||||
|
||||
@app.get("/")
|
||||
|
@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, sanic_client):
|
|||
app.still_serving_cancelled_request = True
|
||||
return text("OK")
|
||||
|
||||
test_cli = await sanic_client(app)
|
||||
|
||||
# schedule client call
|
||||
task = loop.create_task(test_cli.get("/"))
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(app.asgi_client.get("/"))
|
||||
loop.call_later(0.01, task)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, sanic_client):
|
|||
assert app.still_serving_cancelled_request is False
|
||||
|
||||
|
||||
async def test_stream_request_cancel_when_conn_lost(loop, app, sanic_client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_request_cancel_when_conn_lost(app):
|
||||
app.still_serving_cancelled_request = False
|
||||
|
||||
@app.post("/post/<id>", stream=True)
|
||||
|
@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, sanic_client):
|
|||
|
||||
return stream(streaming)
|
||||
|
||||
test_cli = await sanic_client(app)
|
||||
|
||||
# schedule client call
|
||||
task = loop.create_task(test_cli.post("/post/1"))
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(app.asgi_client.post("/post/1"))
|
||||
loop.call_later(0.01, task)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.exceptions import HeaderExpectationFailed
|
||||
from sanic.request import StreamBuffer
|
||||
|
@ -42,13 +43,15 @@ def test_request_stream_method_view(app):
|
|||
assert response.text == data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("headers, expect_raise_exception", [
|
||||
({"EXPECT": "100-continue"}, False),
|
||||
({"EXPECT": "100-continue-extra"}, True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"headers, expect_raise_exception",
|
||||
[
|
||||
({"EXPECT": "100-continue"}, False),
|
||||
({"EXPECT": "100-continue-extra"}, True),
|
||||
],
|
||||
)
|
||||
def test_request_stream_100_continue(app, headers, expect_raise_exception):
|
||||
class SimpleView(HTTPMethodView):
|
||||
|
||||
@stream_decorator
|
||||
async def post(self, request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
|
@ -65,12 +68,18 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception):
|
|||
assert app.is_request_stream is True
|
||||
|
||||
if not expect_raise_exception:
|
||||
request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"})
|
||||
request, response = app.test_client.post(
|
||||
"/method_view", data=data, headers={"EXPECT": "100-continue"}
|
||||
)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
||||
else:
|
||||
with pytest.raises(ValueError) as e:
|
||||
app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"})
|
||||
app.test_client.post(
|
||||
"/method_view",
|
||||
data=data,
|
||||
headers={"EXPECT": "100-continue-extra"},
|
||||
)
|
||||
assert "Unknown Expect: 100-continue-extra" in str(e)
|
||||
|
||||
|
||||
|
@ -188,6 +197,121 @@ def test_request_stream_app(app):
|
|||
assert response.text == data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_stream_app_asgi(app):
|
||||
"""for self.is_request_stream = True and decorators"""
|
||||
|
||||
@app.get("/get")
|
||||
async def get(request):
|
||||
assert request.stream is None
|
||||
return text("GET")
|
||||
|
||||
@app.head("/head")
|
||||
async def head(request):
|
||||
assert request.stream is None
|
||||
return text("HEAD")
|
||||
|
||||
@app.delete("/delete")
|
||||
async def delete(request):
|
||||
assert request.stream is None
|
||||
return text("DELETE")
|
||||
|
||||
@app.options("/options")
|
||||
async def options(request):
|
||||
assert request.stream is None
|
||||
return text("OPTIONS")
|
||||
|
||||
@app.post("/_post/<id>")
|
||||
async def _post(request, id):
|
||||
assert request.stream is None
|
||||
return text("_POST")
|
||||
|
||||
@app.post("/post/<id>", stream=True)
|
||||
async def post(request, id):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
@app.put("/_put")
|
||||
async def _put(request):
|
||||
assert request.stream is None
|
||||
return text("_PUT")
|
||||
|
||||
@app.put("/put", stream=True)
|
||||
async def put(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
@app.patch("/_patch")
|
||||
async def _patch(request):
|
||||
assert request.stream is None
|
||||
return text("_PATCH")
|
||||
|
||||
@app.patch("/patch", stream=True)
|
||||
async def patch(request):
|
||||
assert isinstance(request.stream, StreamBuffer)
|
||||
result = ""
|
||||
while True:
|
||||
body = await request.stream.read()
|
||||
if body is None:
|
||||
break
|
||||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = await app.asgi_client.get("/get")
|
||||
assert response.status == 200
|
||||
assert response.text == "GET"
|
||||
|
||||
request, response = await app.asgi_client.head("/head")
|
||||
assert response.status == 200
|
||||
assert response.text == ""
|
||||
|
||||
request, response = await app.asgi_client.delete("/delete")
|
||||
assert response.status == 200
|
||||
assert response.text == "DELETE"
|
||||
|
||||
request, response = await app.asgi_client.options("/options")
|
||||
assert response.status == 200
|
||||
assert response.text == "OPTIONS"
|
||||
|
||||
request, response = await app.asgi_client.post("/_post/1", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == "_POST"
|
||||
|
||||
request, response = await app.asgi_client.post("/post/1", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
||||
|
||||
request, response = await app.asgi_client.put("/_put", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == "_PUT"
|
||||
|
||||
request, response = await app.asgi_client.put("/put", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
||||
|
||||
request, response = await app.asgi_client.patch("/_patch", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == "_PATCH"
|
||||
|
||||
request, response = await app.asgi_client.patch("/patch", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
||||
|
||||
|
||||
def test_request_stream_handle_exception(app):
|
||||
"""for handling exceptions properly"""
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked(
|
|||
async def mock_drain():
|
||||
pass
|
||||
|
||||
def mock_push_data(data):
|
||||
async def mock_push_data(data):
|
||||
response.protocol.transport.write(data)
|
||||
|
||||
response.protocol.push_data = mock_push_data
|
||||
|
@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
|
|||
async def mock_drain():
|
||||
pass
|
||||
|
||||
def mock_push_data(data):
|
||||
async def mock_push_data(data):
|
||||
response.protocol.transport.write(data)
|
||||
|
||||
response.protocol.push_data = mock_push_data
|
||||
|
|
|
@ -474,6 +474,19 @@ def test_websocket_route(app, url):
|
|||
assert ev.is_set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("url", ["/ws", "ws"])
|
||||
async def test_websocket_route_asgi(app, url):
|
||||
ev = asyncio.Event()
|
||||
|
||||
@app.websocket(url)
|
||||
async def handler(request, ws):
|
||||
ev.set()
|
||||
|
||||
request, response = await app.asgi_client.websocket(url)
|
||||
assert ev.is_set()
|
||||
|
||||
|
||||
def test_websocket_route_with_subprotocols(app):
|
||||
results = []
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@ def test_all_listeners(app):
|
|||
assert app.name + listener_name == output.pop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_before_events_create_server(app):
|
||||
class MySanicDb:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue
Block a user