Increase testing coverage for ASGI

Beautify
This commit is contained in:
Adam Hopkins 2019-06-19 00:15:41 +03:00
parent fb61834a2e
commit 62e0e5b9ec
9 changed files with 429 additions and 34 deletions

View File

@ -5,3 +5,11 @@ omit = site-packages, sanic/utils.py, sanic/__main__.py
[html] [html]
directory = coverage directory = coverage
[report]
exclude_lines =
no cov
no qa
noqa
NOQA
pragma: no cover

View File

@ -88,7 +88,7 @@ class MockTransport:
self._websocket_connection = WebSocketConnection(send, receive) self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection return self._websocket_connection
def add_task(self) -> None: def add_task(self) -> None: # noqa
raise NotImplementedError raise NotImplementedError
async def send(self, data) -> None: async def send(self, data) -> None:
@ -119,15 +119,15 @@ class Lifespan:
"the ASGI server is stopped." "the ASGI server is stopped."
) )
async def pre_startup(self) -> None: # async def pre_startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[ # for handler in self.asgi_app.sanic_app.listeners[
"before_server_start" # "before_server_start"
]: # ]:
response = handler( # response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop # self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
) # )
if isawaitable(response): # if isawaitable(response):
await response # await response
async def startup(self) -> None: async def startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[ for handler in self.asgi_app.sanic_app.listeners[
@ -233,7 +233,14 @@ class ASGIApp:
) )
if sanic_app.is_request_stream: if sanic_app.is_request_stream:
instance.request.stream = StreamBuffer() is_stream_handler = sanic_app.router.is_stream_handler(
instance.request
)
if is_stream_handler:
instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
)
instance.do_stream = True
return instance return instance

View File

@ -136,7 +136,7 @@ class SanicTestClient:
try: try:
request, response = results request, response = results
return request, response return request, response
except BaseException: except BaseException: # noqa
raise ValueError( raise ValueError(
"Request and response object expected, got ({})".format( "Request and response object expected, got ({})".format(
results results
@ -145,7 +145,7 @@ class SanicTestClient:
else: else:
try: try:
return results[-1] return results[-1]
except BaseException: except BaseException: # noqa
raise ValueError( raise ValueError(
"Request object expected, got ({})".format(results) "Request object expected, got ({})".format(results)
) )
@ -175,7 +175,7 @@ class SanicTestClient:
return self._sanic_endpoint_test("websocket", *args, **kwargs) return self._sanic_endpoint_test("websocket", *args, **kwargs)
class SanicASGIAdapter(requests.asgi.ASGIAdapter): class SanicASGIAdapter(requests.asgi.ASGIAdapter): # noqa
async def send( # type: ignore async def send( # type: ignore
self, self,
request: requests.PreparedRequest, request: requests.PreparedRequest,
@ -218,19 +218,43 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
for key, value in request.headers.items() for key, value in request.headers.items()
] ]
scope = { no_response = False
"type": "http", if scheme in {"ws", "wss"}:
"http_version": "1.1", subprotocol = request.headers.get("sec-websocket-protocol", None)
"method": request.method, if subprotocol is None:
"path": unquote(path), subprotocols = [] # type: typing.Sequence[str]
"root_path": "", else:
"scheme": scheme, subprotocols = [
"query_string": query.encode(), value.strip() for value in subprotocol.split(",")
"headers": headers, ]
"client": ["testclient", 50000],
"server": [host, port], scope = {
"extensions": {"http.response.template": {}}, "type": "websocket",
} "path": unquote(path),
"root_path": "",
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
}
no_response = True
else:
scope = {
"type": "http",
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"root_path": "",
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.template": {}},
}
async def receive(): async def receive():
nonlocal request_complete, response_complete nonlocal request_complete, response_complete
@ -306,6 +330,10 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
if not self.suppress_exceptions: if not self.suppress_exceptions:
raise exc from None raise exc from None
if no_response:
response_started = True
raw_kwargs = {"status_code": 204, "headers": []}
if not self.suppress_exceptions: if not self.suppress_exceptions:
assert response_started, "TestClient did not receive any response." assert response_started, "TestClient did not receive any response."
elif not response_started: elif not response_started:
@ -349,13 +377,15 @@ class SanicASGITestClient(requests.ASGISession):
) )
self.mount("http://", adapter) self.mount("http://", adapter)
self.mount("https://", adapter) self.mount("https://", adapter)
self.mount("ws://", adapter)
self.mount("wss://", adapter)
self.headers.update({"user-agent": "testclient"}) self.headers.update({"user-agent": "testclient"})
self.app = app self.app = app
self.base_url = base_url self.base_url = base_url
async def request(self, method, url, gather_request=True, *args, **kwargs): async def request(self, method, url, gather_request=True, *args, **kwargs):
self.gather_request = gather_request self.gather_request = gather_request
print(url)
response = await super().request(method, url, *args, **kwargs) response = await super().request(method, url, *args, **kwargs)
response.status = response.status_code response.status = response.status_code
response.body = response.content response.body = response.content
@ -372,3 +402,22 @@ class SanicASGITestClient(requests.ASGISession):
settings = super().merge_environment_settings(*args, **kwargs) settings = super().merge_environment_settings(*args, **kwargs)
settings.update({"gather_return": self.gather_request}) settings.update({"gather_return": self.gather_request})
return settings return settings
async def websocket(self, uri, subprotocols=None, *args, **kwargs):
if uri.startswith(("ws:", "wss:")):
url = uri
else:
uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri)
url = "ws://testserver{uri}".format(uri=uri)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
headers.setdefault("sec-websocket-key", "testserver==")
headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None:
headers.setdefault(
"sec-websocket-protocol", ", ".join(subprotocols)
)
kwargs["headers"] = headers
return await self.request("websocket", url, **kwargs)

View File

@ -143,9 +143,8 @@ class WebSocketConnection:
return message["text"] return message["text"]
elif message["type"] == "websocket.disconnect": elif message["type"] == "websocket.disconnect":
pass pass
# await self._send({
# "type": "websocket.close" receive = recv
# })
async def accept(self) -> None: async def accept(self) -> None:
await self._send({"type": "websocket.accept", "subprotocol": ""}) await self._send({"type": "websocket.accept", "subprotocol": ""})

View File

@ -1,5 +1,203 @@
from sanic.testing import SanicASGITestClient import asyncio
from collections import deque
import pytest
import uvicorn
from sanic.asgi import MockTransport
from sanic.exceptions import InvalidUsage
from sanic.websocket import WebSocketConnection
def test_asgi_client_instantiation(app): @pytest.fixture
assert isinstance(app.asgi_client, SanicASGITestClient) def message_stack():
return deque()
@pytest.fixture
def receive(message_stack):
async def _receive():
return message_stack.popleft()
return _receive
@pytest.fixture
def send(message_stack):
async def _send(message):
message_stack.append(message)
return _send
@pytest.fixture
def transport(message_stack, receive, send):
return MockTransport({}, receive, send)
@pytest.fixture
# @pytest.mark.asyncio
def protocol(transport, loop):
return transport.get_protocol()
def test_listeners_triggered(app):
before_server_start = False
after_server_start = False
before_server_stop = False
after_server_stop = False
@app.listener("before_server_start")
def do_before_server_start(*args, **kwargs):
nonlocal before_server_start
before_server_start = True
@app.listener("after_server_start")
def do_after_server_start(*args, **kwargs):
nonlocal after_server_start
after_server_start = True
@app.listener("before_server_stop")
def do_before_server_stop(*args, **kwargs):
nonlocal before_server_stop
before_server_stop = True
@app.listener("after_server_stop")
def do_after_server_stop(*args, **kwargs):
nonlocal after_server_stop
after_server_stop = True
class CustomServer(uvicorn.Server):
def install_signal_handlers(self):
pass
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
server = CustomServer(config=config)
with pytest.warns(UserWarning):
server.run()
for task in asyncio.Task.all_tasks():
task.cancel()
assert before_server_start
assert after_server_start
assert before_server_stop
assert after_server_stop
def test_listeners_triggered_async(app):
before_server_start = False
after_server_start = False
before_server_stop = False
after_server_stop = False
@app.listener("before_server_start")
async def do_before_server_start(*args, **kwargs):
nonlocal before_server_start
before_server_start = True
@app.listener("after_server_start")
async def do_after_server_start(*args, **kwargs):
nonlocal after_server_start
after_server_start = True
@app.listener("before_server_stop")
async def do_before_server_stop(*args, **kwargs):
nonlocal before_server_stop
before_server_stop = True
@app.listener("after_server_stop")
async def do_after_server_stop(*args, **kwargs):
nonlocal after_server_stop
after_server_stop = True
class CustomServer(uvicorn.Server):
def install_signal_handlers(self):
pass
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
server = CustomServer(config=config)
with pytest.warns(UserWarning):
server.run()
for task in asyncio.Task.all_tasks():
task.cancel()
assert before_server_start
assert after_server_start
assert before_server_stop
assert after_server_stop
@pytest.mark.asyncio
async def test_mockprotocol_events(protocol):
assert protocol._not_paused.is_set()
protocol.pause_writing()
assert not protocol._not_paused.is_set()
protocol.resume_writing()
assert protocol._not_paused.is_set()
@pytest.mark.asyncio
async def test_protocol_push_data(protocol, message_stack):
text = b"hello"
await protocol.push_data(text)
await protocol.complete()
assert len(message_stack) == 2
message = message_stack.popleft()
assert message["type"] == "http.response.body"
assert message["more_body"]
assert message["body"] == text
message = message_stack.popleft()
assert message["type"] == "http.response.body"
assert not message["more_body"]
assert message["body"] == b""
@pytest.mark.asyncio
async def test_websocket_send(send, receive, message_stack):
text_string = "hello"
text_bytes = b"hello"
ws = WebSocketConnection(send, receive)
await ws.send(text_string)
await ws.send(text_bytes)
assert len(message_stack) == 2
message = message_stack.popleft()
assert message["type"] == "websocket.send"
assert message["text"] == text_string
assert "bytes" not in message
message = message_stack.popleft()
assert message["type"] == "websocket.send"
assert message["bytes"] == text_bytes
assert "text" not in message
@pytest.mark.asyncio
async def test_websocket_receive(send, receive, message_stack):
msg = {"text": "hello", "type": "websocket.receive"}
message_stack.append(msg)
ws = WebSocketConnection(send, receive)
text = await ws.receive()
assert text == msg["text"]
def test_improper_websocket_connection(transport, send, receive):
with pytest.raises(InvalidUsage):
transport.get_websocket_connection()
transport.create_websocket_connection(send, receive)
connection = transport.get_websocket_connection()
assert isinstance(connection, WebSocketConnection)

View File

@ -0,0 +1,5 @@
from sanic.testing import SanicASGITestClient
def test_asgi_client_instantiation(app):
assert isinstance(app.asgi_client, SanicASGITestClient)

View File

@ -197,6 +197,121 @@ def test_request_stream_app(app):
assert response.text == data assert response.text == data
@pytest.mark.asyncio
async def test_request_stream_app_asgi(app):
"""for self.is_request_stream = True and decorators"""
@app.get("/get")
async def get(request):
assert request.stream is None
return text("GET")
@app.head("/head")
async def head(request):
assert request.stream is None
return text("HEAD")
@app.delete("/delete")
async def delete(request):
assert request.stream is None
return text("DELETE")
@app.options("/options")
async def options(request):
assert request.stream is None
return text("OPTIONS")
@app.post("/_post/<id>")
async def _post(request, id):
assert request.stream is None
return text("_POST")
@app.post("/post/<id>", stream=True)
async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
@app.put("/_put")
async def _put(request):
assert request.stream is None
return text("_PUT")
@app.put("/put", stream=True)
async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
@app.patch("/_patch")
async def _patch(request):
assert request.stream is None
return text("_PATCH")
@app.patch("/patch", stream=True)
async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
assert app.is_request_stream is True
request, response = await app.asgi_client.get("/get")
assert response.status == 200
assert response.text == "GET"
request, response = await app.asgi_client.head("/head")
assert response.status == 200
assert response.text == ""
request, response = await app.asgi_client.delete("/delete")
assert response.status == 200
assert response.text == "DELETE"
request, response = await app.asgi_client.options("/options")
assert response.status == 200
assert response.text == "OPTIONS"
request, response = await app.asgi_client.post("/_post/1", data=data)
assert response.status == 200
assert response.text == "_POST"
request, response = await app.asgi_client.post("/post/1", data=data)
assert response.status == 200
assert response.text == data
request, response = await app.asgi_client.put("/_put", data=data)
assert response.status == 200
assert response.text == "_PUT"
request, response = await app.asgi_client.put("/put", data=data)
assert response.status == 200
assert response.text == data
request, response = await app.asgi_client.patch("/_patch", data=data)
assert response.status == 200
assert response.text == "_PATCH"
request, response = await app.asgi_client.patch("/patch", data=data)
assert response.status == 200
assert response.text == data
def test_request_stream_handle_exception(app): def test_request_stream_handle_exception(app):
"""for handling exceptions properly""" """for handling exceptions properly"""

View File

@ -474,6 +474,19 @@ def test_websocket_route(app, url):
assert ev.is_set() assert ev.is_set()
@pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url):
ev = asyncio.Event()
@app.websocket(url)
async def handler(request, ws):
ev.set()
request, response = await app.asgi_client.websocket(url)
assert ev.is_set()
def test_websocket_route_with_subprotocols(app): def test_websocket_route_with_subprotocols(app):
results = [] results = []

View File

@ -18,6 +18,7 @@ deps =
beautifulsoup4 beautifulsoup4
gunicorn gunicorn
pytest-benchmark pytest-benchmark
uvicorn
commands = commands =
pytest {posargs:tests --cov sanic} pytest {posargs:tests --cov sanic}
- coverage combine --append - coverage combine --append