Merge pull request #1475 from tomchristie/asgi-refactor-attempt

ASGI refactoring attempt
This commit is contained in:
7
2019-06-20 16:33:44 -07:00
committed by GitHub
23 changed files with 2068 additions and 61 deletions

View File

@@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app):
def test_app_loop_not_running(app):
with pytest.raises(SanicException) as excinfo:
_ = app.loop
app.loop
assert str(excinfo.value) == (
"Loop can only be retrieved after the app has started "

203
tests/test_asgi.py Normal file
View File

@@ -0,0 +1,203 @@
import asyncio
from collections import deque
import pytest
import uvicorn
from sanic.asgi import MockTransport
from sanic.exceptions import InvalidUsage
from sanic.websocket import WebSocketConnection
@pytest.fixture
def message_stack():
return deque()
@pytest.fixture
def receive(message_stack):
async def _receive():
return message_stack.popleft()
return _receive
@pytest.fixture
def send(message_stack):
async def _send(message):
message_stack.append(message)
return _send
@pytest.fixture
def transport(message_stack, receive, send):
return MockTransport({}, receive, send)
@pytest.fixture
# @pytest.mark.asyncio
def protocol(transport, loop):
return transport.get_protocol()
def test_listeners_triggered(app):
before_server_start = False
after_server_start = False
before_server_stop = False
after_server_stop = False
@app.listener("before_server_start")
def do_before_server_start(*args, **kwargs):
nonlocal before_server_start
before_server_start = True
@app.listener("after_server_start")
def do_after_server_start(*args, **kwargs):
nonlocal after_server_start
after_server_start = True
@app.listener("before_server_stop")
def do_before_server_stop(*args, **kwargs):
nonlocal before_server_stop
before_server_stop = True
@app.listener("after_server_stop")
def do_after_server_stop(*args, **kwargs):
nonlocal after_server_stop
after_server_stop = True
class CustomServer(uvicorn.Server):
def install_signal_handlers(self):
pass
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
server = CustomServer(config=config)
with pytest.warns(UserWarning):
server.run()
for task in asyncio.Task.all_tasks():
task.cancel()
assert before_server_start
assert after_server_start
assert before_server_stop
assert after_server_stop
def test_listeners_triggered_async(app):
before_server_start = False
after_server_start = False
before_server_stop = False
after_server_stop = False
@app.listener("before_server_start")
async def do_before_server_start(*args, **kwargs):
nonlocal before_server_start
before_server_start = True
@app.listener("after_server_start")
async def do_after_server_start(*args, **kwargs):
nonlocal after_server_start
after_server_start = True
@app.listener("before_server_stop")
async def do_before_server_stop(*args, **kwargs):
nonlocal before_server_stop
before_server_stop = True
@app.listener("after_server_stop")
async def do_after_server_stop(*args, **kwargs):
nonlocal after_server_stop
after_server_stop = True
class CustomServer(uvicorn.Server):
def install_signal_handlers(self):
pass
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
server = CustomServer(config=config)
with pytest.warns(UserWarning):
server.run()
for task in asyncio.Task.all_tasks():
task.cancel()
assert before_server_start
assert after_server_start
assert before_server_stop
assert after_server_stop
@pytest.mark.asyncio
async def test_mockprotocol_events(protocol):
assert protocol._not_paused.is_set()
protocol.pause_writing()
assert not protocol._not_paused.is_set()
protocol.resume_writing()
assert protocol._not_paused.is_set()
@pytest.mark.asyncio
async def test_protocol_push_data(protocol, message_stack):
text = b"hello"
await protocol.push_data(text)
await protocol.complete()
assert len(message_stack) == 2
message = message_stack.popleft()
assert message["type"] == "http.response.body"
assert message["more_body"]
assert message["body"] == text
message = message_stack.popleft()
assert message["type"] == "http.response.body"
assert not message["more_body"]
assert message["body"] == b""
@pytest.mark.asyncio
async def test_websocket_send(send, receive, message_stack):
text_string = "hello"
text_bytes = b"hello"
ws = WebSocketConnection(send, receive)
await ws.send(text_string)
await ws.send(text_bytes)
assert len(message_stack) == 2
message = message_stack.popleft()
assert message["type"] == "websocket.send"
assert message["text"] == text_string
assert "bytes" not in message
message = message_stack.popleft()
assert message["type"] == "websocket.send"
assert message["bytes"] == text_bytes
assert "text" not in message
@pytest.mark.asyncio
async def test_websocket_receive(send, receive, message_stack):
msg = {"text": "hello", "type": "websocket.receive"}
message_stack.append(msg)
ws = WebSocketConnection(send, receive)
text = await ws.receive()
assert text == msg["text"]
def test_improper_websocket_connection(transport, send, receive):
with pytest.raises(InvalidUsage):
transport.get_websocket_connection()
transport.create_websocket_connection(send, receive)
connection = transport.get_websocket_connection()
assert isinstance(connection, WebSocketConnection)

View File

@@ -0,0 +1,5 @@
from sanic.testing import SanicASGITestClient
def test_asgi_client_instantiation(app):
assert isinstance(app.asgi_client, SanicASGITestClient)

View File

@@ -239,6 +239,7 @@ def test_config_access_log_passing_in_run(app):
assert app.config.ACCESS_LOG == True
@pytest.mark.asyncio
async def test_config_access_log_passing_in_create_server(app):
assert app.config.ACCESS_LOG == True

View File

@@ -27,6 +27,24 @@ def test_cookies(app):
assert response_cookies["right_back"].value == "at you"
@pytest.mark.asyncio
async def test_cookies_asgi(app):
@app.route("/")
def handler(request):
response = text("Cookies are: {}".format(request.cookies["test"]))
response.cookies["right_back"] = "at you"
return response
request, response = await app.asgi_client.get(
"/", cookies={"test": "working!"}
)
response_cookies = SimpleCookie()
response_cookies.load(response.headers.get("set-cookie", {}))
assert response.text == "Cookies are: working!"
assert response_cookies["right_back"].value == "at you"
@pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)])
def test_false_cookies_encoded(app, httponly, expected):
@app.route("/")

View File

@@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app):
@pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"])
async def test_redirect_with_params(app, sanic_client, test_str):
def test_redirect_with_params(app, test_str):
use_in_uri = quote(test_str)
@app.route("/api/v1/test/<test>/")
async def init_handler(request, test):
assert test == test_str
return redirect("/api/v2/test/{}/".format(quote(test)))
return redirect("/api/v2/test/{}/".format(use_in_uri))
@app.route("/api/v2/test/<test>/")
async def target_handler(request, test):
assert test == test_str
return text("OK")
test_cli = await sanic_client(app)
response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str)))
_, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri))
assert response.status == 200
txt = await response.text()
assert txt == "OK"
assert response.content == b"OK"

View File

@@ -1,10 +1,13 @@
import asyncio
import contextlib
import pytest
from sanic.response import stream, text
async def test_request_cancel_when_connection_lost(loop, app, sanic_client):
@pytest.mark.asyncio
async def test_request_cancel_when_connection_lost(app):
app.still_serving_cancelled_request = False
@app.get("/")
@@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, sanic_client):
app.still_serving_cancelled_request = True
return text("OK")
test_cli = await sanic_client(app)
# schedule client call
task = loop.create_task(test_cli.get("/"))
loop = asyncio.get_event_loop()
task = loop.create_task(app.asgi_client.get("/"))
loop.call_later(0.01, task)
await asyncio.sleep(0.5)
@@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, sanic_client):
assert app.still_serving_cancelled_request is False
async def test_stream_request_cancel_when_conn_lost(loop, app, sanic_client):
@pytest.mark.asyncio
async def test_stream_request_cancel_when_conn_lost(app):
app.still_serving_cancelled_request = False
@app.post("/post/<id>", stream=True)
@@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, sanic_client):
return stream(streaming)
test_cli = await sanic_client(app)
# schedule client call
task = loop.create_task(test_cli.post("/post/1"))
loop = asyncio.get_event_loop()
task = loop.create_task(app.asgi_client.post("/post/1"))
loop.call_later(0.01, task)
await asyncio.sleep(0.5)

View File

@@ -1,4 +1,5 @@
import pytest
from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed
from sanic.request import StreamBuffer
@@ -42,13 +43,15 @@ def test_request_stream_method_view(app):
assert response.text == data
@pytest.mark.parametrize("headers, expect_raise_exception", [
({"EXPECT": "100-continue"}, False),
({"EXPECT": "100-continue-extra"}, True),
])
@pytest.mark.parametrize(
"headers, expect_raise_exception",
[
({"EXPECT": "100-continue"}, False),
({"EXPECT": "100-continue-extra"}, True),
],
)
def test_request_stream_100_continue(app, headers, expect_raise_exception):
class SimpleView(HTTPMethodView):
@stream_decorator
async def post(self, request):
assert isinstance(request.stream, StreamBuffer)
@@ -65,12 +68,18 @@ def test_request_stream_100_continue(app, headers, expect_raise_exception):
assert app.is_request_stream is True
if not expect_raise_exception:
request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"})
request, response = app.test_client.post(
"/method_view", data=data, headers={"EXPECT": "100-continue"}
)
assert response.status == 200
assert response.text == data
else:
with pytest.raises(ValueError) as e:
app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"})
app.test_client.post(
"/method_view",
data=data,
headers={"EXPECT": "100-continue-extra"},
)
assert "Unknown Expect: 100-continue-extra" in str(e)
@@ -188,6 +197,121 @@ def test_request_stream_app(app):
assert response.text == data
@pytest.mark.asyncio
async def test_request_stream_app_asgi(app):
"""for self.is_request_stream = True and decorators"""
@app.get("/get")
async def get(request):
assert request.stream is None
return text("GET")
@app.head("/head")
async def head(request):
assert request.stream is None
return text("HEAD")
@app.delete("/delete")
async def delete(request):
assert request.stream is None
return text("DELETE")
@app.options("/options")
async def options(request):
assert request.stream is None
return text("OPTIONS")
@app.post("/_post/<id>")
async def _post(request, id):
assert request.stream is None
return text("_POST")
@app.post("/post/<id>", stream=True)
async def post(request, id):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
@app.put("/_put")
async def _put(request):
assert request.stream is None
return text("_PUT")
@app.put("/put", stream=True)
async def put(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
@app.patch("/_patch")
async def _patch(request):
assert request.stream is None
return text("_PATCH")
@app.patch("/patch", stream=True)
async def patch(request):
assert isinstance(request.stream, StreamBuffer)
result = ""
while True:
body = await request.stream.read()
if body is None:
break
result += body.decode("utf-8")
return text(result)
assert app.is_request_stream is True
request, response = await app.asgi_client.get("/get")
assert response.status == 200
assert response.text == "GET"
request, response = await app.asgi_client.head("/head")
assert response.status == 200
assert response.text == ""
request, response = await app.asgi_client.delete("/delete")
assert response.status == 200
assert response.text == "DELETE"
request, response = await app.asgi_client.options("/options")
assert response.status == 200
assert response.text == "OPTIONS"
request, response = await app.asgi_client.post("/_post/1", data=data)
assert response.status == 200
assert response.text == "_POST"
request, response = await app.asgi_client.post("/post/1", data=data)
assert response.status == 200
assert response.text == data
request, response = await app.asgi_client.put("/_put", data=data)
assert response.status == 200
assert response.text == "_PUT"
request, response = await app.asgi_client.put("/put", data=data)
assert response.status == 200
assert response.text == data
request, response = await app.asgi_client.patch("/_patch", data=data)
assert response.status == 200
assert response.text == "_PATCH"
request, response = await app.asgi_client.patch("/patch", data=data)
assert response.status == 200
assert response.text == data
def test_request_stream_handle_exception(app):
"""for handling exceptions properly"""

File diff suppressed because it is too large Load Diff

View File

@@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked(
async def mock_drain():
pass
def mock_push_data(data):
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data
@@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
async def mock_drain():
pass
def mock_push_data(data):
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data

View File

@@ -474,6 +474,19 @@ def test_websocket_route(app, url):
assert ev.is_set()
@pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url):
ev = asyncio.Event()
@app.websocket(url)
async def handler(request, ws):
ev.set()
request, response = await app.asgi_client.websocket(url)
assert ev.is_set()
def test_websocket_route_with_subprotocols(app):
results = []

View File

@@ -76,6 +76,7 @@ def test_all_listeners(app):
assert app.name + listener_name == output.pop()
@pytest.mark.asyncio
async def test_trigger_before_events_create_server(app):
class MySanicDb:
pass