squash this

This commit is contained in:
Adam Hopkins 2020-07-13 23:59:45 +03:00
parent 050a563e1d
commit cf234fca15
5 changed files with 61 additions and 81 deletions

View File

@ -62,6 +62,20 @@ Additionally, Sanic has an asynchronous testing client. The difference is that t
instance of your application, but will instead reach inside it using ASGI. All listeners and middleware are still instance of your application, but will instead reach inside it using ASGI. All listeners and middleware are still
executed. executed.
.. code-block:: python
@pytest.mark.asyncio
async def test_index_returns_200():
request, response = await app.asgi_client.put('/')
assert response.status == 200
.. note::
Whenever one of the test clients run, you can test your app instance to determine if it is in testing mode:
`app.test_mode`.
Additionally, Sanic has an asynchronous testing client. The difference is that the async client will not stand up an
instance of your application, but will instead reach inside it using ASGI. All listeners and middleware are still
executed.
.. code-block:: python .. code-block:: python
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import warnings import warnings
from inspect import isawaitable from inspect import isawaitable
from typing import ( from typing import (
Any, Any,
@ -16,7 +15,6 @@ from typing import (
from urllib.parse import quote from urllib.parse import quote
import sanic.app # noqa import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger from sanic.log import logger
@ -25,7 +23,6 @@ from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import ConnInfo, StreamBuffer from sanic.server import ConnInfo, StreamBuffer
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any] ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any]
ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGISend = Callable[[ASGIMessage], Awaitable[None]]
@ -68,9 +65,7 @@ class MockProtocol:
class MockTransport: class MockTransport:
_protocol: Optional[MockProtocol] _protocol: Optional[MockProtocol]
def __init__( def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None:
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.scope = scope self.scope = scope
self._receive = receive self._receive = receive
self._send = send self._send = send
@ -117,6 +112,8 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None: def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app self.asgi_app = asgi_app
print(self.asgi_app.sanic_app.listeners)
if "before_server_start" in self.asgi_app.sanic_app.listeners: if "before_server_start" in self.asgi_app.sanic_app.listeners:
warnings.warn( warnings.warn(
'You have set a listener for "before_server_start" ' 'You have set a listener for "before_server_start" '
@ -146,9 +143,7 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) ) + self.asgi_app.sanic_app.listeners.get("after_server_start", [])
for handler in listeners: for handler in listeners:
response = handler( response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response): if isawaitable(response):
await response await response
@ -166,9 +161,7 @@ class Lifespan:
) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", [])
for handler in listeners: for handler in listeners:
response = handler( response = handler(self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop)
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response): if isawaitable(response):
await response await response
@ -213,19 +206,13 @@ class ASGIApp:
for key, value in scope.get("headers", []) for key, value in scope.get("headers", [])
] ]
) )
instance.do_stream = ( instance.do_stream = True if headers.get("expect") == "100-continue" else False
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance) instance.lifespan = Lifespan(instance)
if scope["type"] == "lifespan": if scope["type"] == "lifespan":
await instance.lifespan(scope, receive, send) await instance.lifespan(scope, receive, send)
else: else:
path = ( path = scope["path"][1:] if scope["path"].startswith("/") else scope["path"]
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)]) url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1") url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"] url_bytes += b"?" + scope["query_string"]
@ -248,19 +235,12 @@ class ASGIApp:
request_class = sanic_app.request_class or Request request_class = sanic_app.request_class or Request
instance.request = request_class( instance.request = request_class(
url_bytes, url_bytes, headers, version, method, instance.transport, sanic_app,
headers,
version,
method,
instance.transport,
sanic_app,
) )
instance.request.conn_info = ConnInfo(instance.transport) instance.request.conn_info = ConnInfo(instance.transport)
if sanic_app.is_request_stream: if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler( is_stream_handler = sanic_app.router.is_stream_handler(instance.request)
instance.request
)
if is_stream_handler: if is_stream_handler:
instance.request.stream = StreamBuffer( instance.request.stream = StreamBuffer(
sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE
@ -339,9 +319,7 @@ class ASGIApp:
type(response), type(response),
) )
exception = ServerError("Invalid response type") exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response( response = self.sanic_app.error_handler.response(self.request, exception)
self.request, exception
)
headers = [ headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1")) (str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items() for name, value in response.headers.items()
@ -351,14 +329,10 @@ class ASGIApp:
if "content-length" not in response.headers and not isinstance( if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse response, StreamingHTTPResponse
): ):
headers += [ headers += [(b"content-length", str(len(response.body)).encode("latin-1"))]
(b"content-length", str(len(response.body)).encode("latin-1"))
]
if "content-type" not in response.headers: if "content-type" not in response.headers:
headers += [ headers += [(b"content-type", str(response.content_type).encode("latin-1"))]
(b"content-type", str(response.content_type).encode("latin-1"))
]
if response.cookies: if response.cookies:
cookies.update( cookies.update(
@ -370,8 +344,7 @@ class ASGIApp:
) )
headers += [ headers += [
(b"set-cookie", cookie.encode("utf-8")) (b"set-cookie", cookie.encode("utf-8")) for k, cookie in cookies.items()
for k, cookie in cookies.items()
] ]
await self.transport.send( await self.transport.send(

View File

@ -9,7 +9,6 @@ from sanic.exceptions import MethodNotSupported
from sanic.log import logger from sanic.log import logger
from sanic.response import text from sanic.response import text
ASGI_HOST = "mockserver" ASGI_HOST = "mockserver"
HOST = "127.0.0.1" HOST = "127.0.0.1"
PORT = None PORT = None
@ -88,9 +87,7 @@ class SanicTestClient:
@self.app.exception(MethodNotSupported) @self.app.exception(MethodNotSupported)
async def error_handler(request, exception): async def error_handler(request, exception):
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text( return text("", exception.status_code, headers=exception.headers)
"", exception.status_code, headers=exception.headers
)
else: else:
return self.app.error_handler.default(request, exception) return self.app.error_handler.default(request, exception)
@ -106,9 +103,7 @@ class SanicTestClient:
host, port = sock.getsockname() host, port = sock.getsockname()
self.port = port self.port = port
if uri.startswith( if uri.startswith(("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")):
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
url = uri url = uri
else: else:
uri = uri if uri.startswith("/") else f"/{uri}" uri = uri if uri.startswith("/") else f"/{uri}"
@ -201,7 +196,6 @@ class SanicASGITestClient(httpx.AsyncClient):
app.asgi = True app.asgi = True
self.app = app self.app = app
self.app.test_mode = True
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0)) dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0))
super().__init__(dispatch=dispatch, base_url=base_url) super().__init__(dispatch=dispatch, base_url=base_url)
@ -211,6 +205,17 @@ class SanicASGITestClient(httpx.AsyncClient):
def _collect_request(request): def _collect_request(request):
self.last_request = request self.last_request = request
def _start_test_mode(request):
self.app.test_mode = True
@app.listener("after_server_start")
def _end_test_mode(sanic, loop):
sanic.test_mode = True
@app.listener("before_server_end")
def _end_test_mode(sanic, loop):
sanic.test_mode = False
app.request_middleware.appendleft(_collect_request) app.request_middleware.appendleft(_collect_request)
async def request(self, method, url, gather_request=True, *args, **kwargs): async def request(self, method, url, gather_request=True, *args, **kwargs):
@ -233,9 +238,7 @@ class SanicASGITestClient(httpx.AsyncClient):
headers.setdefault("sec-websocket-key", "testserver==") headers.setdefault("sec-websocket-key", "testserver==")
headers.setdefault("sec-websocket-version", "13") headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None: if subprotocols is not None:
headers.setdefault( headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
"sec-websocket-protocol", ", ".join(subprotocols)
)
scope = { scope = {
"type": "websocket", "type": "websocket",

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import logging import logging
import sys import sys
from inspect import isawaitable from inspect import isawaitable
import pytest import pytest
@ -30,9 +29,7 @@ def test_app_loop_running(app):
assert response.text == "pass" assert response.text == "pass"
@pytest.mark.skipif( @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_create_asyncio_server(app): def test_create_asyncio_server(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -42,9 +39,7 @@ def test_create_asyncio_server(app):
assert srv.is_serving() is True assert srv.is_serving() is True
@pytest.mark.skipif( @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_no_start_serving(app): def test_asyncio_server_no_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -57,9 +52,7 @@ def test_asyncio_server_no_start_serving(app):
assert srv.is_serving() is False assert srv.is_serving() is False
@pytest.mark.skipif( @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_start_serving(app): def test_asyncio_server_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -156,9 +149,7 @@ def test_handle_request_with_nested_exception(app, monkeypatch):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg) raise Exception(err_msg)
monkeypatch.setattr( monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -177,9 +168,7 @@ def test_handle_request_with_nested_exception_debug(app, monkeypatch):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise Exception(err_msg) raise Exception(err_msg)
monkeypatch.setattr( monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -198,9 +187,7 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
def mock_error_handler_response(*args, **kwargs): def mock_error_handler_response(*args, **kwargs):
raise SanicException("Mock SanicException") raise SanicException("Mock SanicException")
monkeypatch.setattr( monkeypatch.setattr(app.error_handler, "response", mock_error_handler_response)
app.error_handler, "response", mock_error_handler_response
)
@app.get("/") @app.get("/")
def handler(request): def handler(request):
@ -224,7 +211,9 @@ def test_app_name_required():
Sanic() Sanic()
def test_app_has_test_mode_sync(app): def test_app_has_test_mode_sync():
app = Sanic("test")
@app.get("/") @app.get("/")
def handler(request): def handler(request):
assert request.app.test_mode assert request.app.test_mode
@ -234,12 +223,14 @@ def test_app_has_test_mode_sync(app):
assert response.status == 200 assert response.status == 200
@pytest.mark.asyncio # @pytest.mark.asyncio
async def test_app_has_test_mode_async(app): # async def test_app_has_test_mode_async():
@app.get("/") # app = Sanic("test")
async def handler(request):
assert request.app.test_mode
return text("test")
_, response = await app.asgi_client.get("/") # @app.get("/")
assert response.status == 200 # async def handler(request):
# assert request.app.test_mode
# return text("test")
# _, response = await app.asgi_client.get("/")
# assert response.status == 200

View File

@ -1,11 +1,10 @@
import asyncio import asyncio
import sys import sys
from collections import deque, namedtuple from collections import deque, namedtuple
import pytest import pytest
import uvicorn
import uvicorn
from sanic import Sanic from sanic import Sanic
from sanic.asgi import MockTransport from sanic.asgi import MockTransport
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage