Merge branch 'main' into zhiwei/bp-copy

This commit is contained in:
Zhiwei
2021-08-06 02:10:40 -05:00
committed by GitHub
36 changed files with 976 additions and 530 deletions

View File

@@ -1,3 +1,5 @@
import asyncio
import logging
import random
import re
import string
@@ -9,10 +11,12 @@ from typing import Tuple
import pytest
from sanic_routing.exceptions import RouteExists
from sanic_testing.testing import PORT
from sanic import Sanic
from sanic.constants import HTTP_METHODS
from sanic.router import Router
from sanic.touchup.service import TouchUp
slugify = re.compile(r"[^a-zA-Z0-9_\-]")
@@ -23,11 +27,6 @@ if sys.platform in ["win32", "cygwin"]:
collect_ignore = ["test_worker.py"]
@pytest.fixture
def caplog(caplog):
yield caplog
async def _handler(request):
"""
Dummy placeholder method used for route resolver when creating a new
@@ -41,33 +40,32 @@ async def _handler(request):
TYPE_TO_GENERATOR_MAP = {
"string": lambda: "".join(
"str": lambda: "".join(
[random.choice(string.ascii_lowercase) for _ in range(4)]
),
"int": lambda: random.choice(range(1000000)),
"number": lambda: random.random(),
"float": lambda: random.random(),
"alpha": lambda: "".join(
[random.choice(string.ascii_lowercase) for _ in range(4)]
),
"uuid": lambda: str(uuid.uuid1()),
}
CACHE = {}
class RouteStringGenerator:
ROUTE_COUNT_PER_DEPTH = 100
HTTP_METHODS = HTTP_METHODS
ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"]
ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"]
def generate_random_direct_route(self, max_route_depth=4):
routes = []
for depth in range(1, max_route_depth + 1):
for _ in range(self.ROUTE_COUNT_PER_DEPTH):
route = "/".join(
[
TYPE_TO_GENERATOR_MAP.get("string")()
for _ in range(depth)
]
[TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)]
)
route = route.replace(".", "", -1)
route_detail = (random.choice(self.HTTP_METHODS), route)
@@ -83,7 +81,7 @@ class RouteStringGenerator:
new_route_part = "/".join(
[
"<{}:{}>".format(
TYPE_TO_GENERATOR_MAP.get("string")(),
TYPE_TO_GENERATOR_MAP.get("str")(),
random.choice(self.ROUTE_PARAM_TYPES),
)
for _ in range(max_route_depth - current_length)
@@ -98,7 +96,7 @@ class RouteStringGenerator:
def generate_url_for_template(template):
url = template
for pattern, param_type in re.findall(
re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"),
re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"),
template,
):
value = TYPE_TO_GENERATOR_MAP.get(param_type)()
@@ -141,5 +139,33 @@ def url_param_generator():
@pytest.fixture(scope="function")
def app(request):
if not CACHE:
for target, method_name in TouchUp._registry:
CACHE[method_name] = getattr(target, method_name)
app = Sanic(slugify.sub("-", request.node.name))
return app
yield app
for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name])
@pytest.fixture(scope="function")
def run_startup(caplog):
def run(app):
nonlocal caplog
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
with caplog.at_level(logging.DEBUG):
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop._stopping = False
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
return caplog.record_tuples
return run

View File

@@ -7,7 +7,7 @@ import uvicorn
from sanic import Sanic
from sanic.asgi import MockTransport
from sanic.exceptions import InvalidUsage
from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable
from sanic.request import Request
from sanic.response import json, text
from sanic.websocket import WebSocketConnection
@@ -346,3 +346,32 @@ async def test_content_type(app):
_, response = await app.asgi_client.get("/custom")
assert response.headers.get("content-type") == "somethingelse"
@pytest.mark.asyncio
async def test_request_handle_exception(app):
@app.get("/error-prone")
def _request(request):
raise ServiceUnavailable(message="Service unavailable")
_, response = await app.asgi_client.get("/wrong-path")
assert response.status_code == 404
_, response = await app.asgi_client.get("/error-prone")
assert response.status_code == 503
@pytest.mark.asyncio
async def test_request_exception_suppressed_by_middleware(app):
@app.get("/error-prone")
def _request(request):
raise ServiceUnavailable(message="Service unavailable")
@app.on_request
def forbidden(request):
raise Forbidden(message="forbidden")
_, response = await app.asgi_client.get("/wrong-path")
assert response.status_code == 403
_, response = await app.asgi_client.get("/error-prone")
assert response.status_code == 403

View File

@@ -89,7 +89,7 @@ def test_debug(cmd):
out, err, exitcode = capture(command)
lines = out.split(b"\n")
app_info = lines[9]
app_info = lines[26]
info = json.loads(app_info)
assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO
@@ -103,7 +103,7 @@ def test_auto_reload(cmd):
out, err, exitcode = capture(command)
lines = out.split(b"\n")
app_info = lines[9]
app_info = lines[26]
info = json.loads(app_info)
assert info["debug"] is False
@@ -118,7 +118,7 @@ def test_access_logs(cmd, expected):
out, err, exitcode = capture(command)
lines = out.split(b"\n")
app_info = lines[9]
app_info = lines[26]
info = json.loads(app_info)
assert info["access_log"] is expected

View File

@@ -1,6 +1,5 @@
import asyncio
from queue import Queue
from threading import Event
from sanic.response import text
@@ -13,8 +12,6 @@ def test_create_task(app):
await asyncio.sleep(0.05)
e.set()
app.add_task(coro)
@app.route("/early")
def not_set(request):
return text(str(e.is_set()))
@@ -24,24 +21,30 @@ def test_create_task(app):
await asyncio.sleep(0.1)
return text(str(e.is_set()))
app.add_task(coro)
request, response = app.test_client.get("/early")
assert response.body == b"False"
app.signal_router.reset()
app.add_task(coro)
request, response = app.test_client.get("/late")
assert response.body == b"True"
def test_create_task_with_app_arg(app):
q = Queue()
@app.after_server_start
async def setup_q(app, _):
app.ctx.q = asyncio.Queue()
@app.route("/")
def not_set(request):
return "hello"
async def not_set(request):
return text(await request.app.ctx.q.get())
async def coro(app):
q.put(app.name)
await app.ctx.q.put(app.name)
app.add_task(coro)
request, response = app.test_client.get("/")
assert q.get() == "test_create_task_with_app_arg"
_, response = app.test_client.get("/")
assert response.text == "test_create_task_with_app_arg"

View File

@@ -127,7 +127,6 @@ def test_html_traceback_output_in_debug_mode():
soup = BeautifulSoup(response.body, "html.parser")
html = str(soup)
assert "response = handler(request, **kwargs)" in html
assert "handler_4" in html
assert "foo = bar" in html
@@ -151,7 +150,6 @@ def test_chained_exception_handler():
soup = BeautifulSoup(response.body, "html.parser")
html = str(soup)
assert "response = handler(request, **kwargs)" in html
assert "handler_6" in html
assert "foo = 1 / arg" in html
assert "ValueError" in html

View File

@@ -2,16 +2,13 @@ import asyncio
import platform
from asyncio import sleep as aio_sleep
from json import JSONDecodeError
from os import environ
import httpcore
import httpx
import pytest
from sanic_testing.testing import HOST, SanicTestClient
from sanic_testing.reusable import ReusableClient
from sanic import Sanic, server
from sanic import Sanic
from sanic.compat import OS_IS_WINDOWS
from sanic.response import text
@@ -21,164 +18,6 @@ CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool):
last_reused_connection = None
async def _get_connection_from_pool(self, *args, **kwargs):
conn = await super()._get_connection_from_pool(*args, **kwargs)
self.__class__.last_reused_connection = conn
return conn
class ResusableSanicSession(httpx.AsyncClient):
def __init__(self, *args, **kwargs) -> None:
transport = ReusableSanicConnectionPool()
super().__init__(transport=transport, *args, **kwargs)
class ReuseableSanicTestClient(SanicTestClient):
def __init__(self, app, loop=None):
super().__init__(app)
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._server = None
self._tcp_connector = None
self._session = None
def get_new_session(self):
return ResusableSanicSession()
# Copied from SanicTestClient, but with some changes to reuse the
# same loop for the same app.
def _sanic_endpoint_test(
self,
method="get",
uri="/",
gather_request=True,
debug=False,
server_kwargs=None,
*request_args,
**request_kwargs,
):
loop = self._loop
results = [None, None]
exceptions = []
server_kwargs = server_kwargs or {"return_asyncio_server": True}
if gather_request:
def _collect_request(request):
if results[0] is None:
results[0] = request
self.app.request_middleware.appendleft(_collect_request)
if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
url = uri
else:
uri = uri if uri.startswith("/") else f"/{uri}"
scheme = "http"
url = f"{scheme}://{HOST}:{PORT}{uri}"
@self.app.listener("after_server_start")
async def _collect_response(loop):
try:
response = await self._local_request(
method, url, *request_args, **request_kwargs
)
results[-1] = response
except Exception as e2:
exceptions.append(e2)
if self._server is not None:
_server = self._server
else:
_server_co = self.app.create_server(
host=HOST, debug=debug, port=PORT, **server_kwargs
)
server.trigger_events(
self.app.listeners["before_server_start"], loop
)
try:
loop._stopping = False
_server = loop.run_until_complete(_server_co)
except Exception as e1:
raise e1
self._server = _server
server.trigger_events(self.app.listeners["after_server_start"], loop)
self.app.listeners["after_server_start"].pop()
if exceptions:
raise ValueError(f"Exception during request: {exceptions}")
if gather_request:
self.app.request_middleware.pop()
try:
request, response = results
return request, response
except Exception:
raise ValueError(
f"Request and response object expected, got ({results})"
)
else:
try:
return results[-1]
except Exception:
raise ValueError(f"Request object expected, got ({results})")
def kill_server(self):
try:
if self._server:
self._server.close()
self._loop.run_until_complete(self._server.wait_closed())
self._server = None
if self._session:
self._loop.run_until_complete(self._session.aclose())
self._session = None
except Exception as e3:
raise e3
# Copied from SanicTestClient, but with some changes to reuse the
# same TCPConnection and the sane ClientSession more than once.
# Note, you cannot use the same session if you are in a _different_
# loop, so the changes above are required too.
async def _local_request(self, method, url, *args, **kwargs):
raw_cookies = kwargs.pop("raw_cookies", None)
request_keepalive = kwargs.pop(
"request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"]
)
if not self._session:
self._session = self.get_new_session()
try:
response = await getattr(self._session, method.lower())(
url, timeout=request_keepalive, *args, **kwargs
)
except NameError:
raise Exception(response.status_code)
try:
response.json = response.json()
except (JSONDecodeError, UnicodeDecodeError):
response.json = None
response.body = await response.aread()
response.status = response.status_code
response.content_type = response.headers.get("content-type")
if raw_cookies:
response.raw_cookies = {}
for cookie in response.cookies:
response.raw_cookies[cookie.name] = cookie
return response
keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse")
keep_alive_app_client_timeout = Sanic("test_ka_client_timeout")
keep_alive_app_server_timeout = Sanic("test_ka_server_timeout")
@@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully
reuse the existing connection."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT)
with client:
headers = {"Connection": "keep-alive"}
request, response = client.get("/1", headers=headers)
assert response.status == 200
assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(1))
request, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
assert ReusableSanicConnectionPool.last_reused_connection
finally:
client.kill_server()
assert request.protocol.state["requests_count"] == 2
@pytest.mark.skipif(
@@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse():
def test_keep_alive_client_timeout():
"""If the server keep-alive timeout is longer than the client
keep-alive timeout, client will try to create a new connection here."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(
keep_alive_app_client_timeout, loop=loop, port=PORT
)
with client:
headers = {"Connection": "keep-alive"}
_, response = client.get("/1", headers=headers, request_keepalive=1)
request, response = client.get("/1", headers=headers, timeout=1)
assert response.status == 200
assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(2))
_, response = client.get("/1", request_keepalive=1)
assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
request, response = client.get("/1", timeout=1)
assert request.protocol.state["requests_count"] == 1
@pytest.mark.skipif(
@@ -277,22 +117,23 @@ def test_keep_alive_server_timeout():
keep-alive timeout, the client will either a 'Connection reset' error
_or_ a new connection. Depending on how the event-loop handles the
broken server connection."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(
keep_alive_app_server_timeout, loop=loop, port=PORT
)
with client:
headers = {"Connection": "keep-alive"}
_, response = client.get("/1", headers=headers, request_keepalive=60)
request, response = client.get("/1", headers=headers, timeout=60)
assert response.status == 200
assert response.text == "OK"
assert request.protocol.state["requests_count"] == 1
loop.run_until_complete(aio_sleep(3))
_, response = client.get("/1", request_keepalive=60)
request, response = client.get("/1", timeout=60)
assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
assert request.protocol.state["requests_count"] == 1
@pytest.mark.skipif(
@@ -300,10 +141,10 @@ def test_keep_alive_server_timeout():
reason="Not testable with current client",
)
def test_keep_alive_connection_context():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_context, loop)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT)
with client:
headers = {"Connection": "keep-alive"}
request1, _ = client.post("/ctx", headers=headers)
@@ -315,5 +156,4 @@ def test_keep_alive_connection_context():
assert (
request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello"
)
finally:
client.kill_server()
assert request2.protocol.state["requests_count"] == 2

View File

@@ -5,6 +5,7 @@ import uuid
from importlib import reload
from io import StringIO
from unittest.mock import Mock
import pytest
@@ -51,7 +52,7 @@ def test_log(app):
def test_logging_defaults():
# reset_logging()
app = Sanic("test_logging")
Sanic("test_logging")
for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]:
assert (
@@ -87,7 +88,7 @@ def test_logging_pass_customer_logconfig():
"format"
] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s"
app = Sanic("test_logging", log_config=modified_config)
Sanic("test_logging", log_config=modified_config)
for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]:
assert fmt._fmt == modified_config["formatters"]["generic"]["format"]
@@ -208,6 +209,56 @@ def test_logging_modified_root_logger_config():
modified_config = LOGGING_CONFIG_DEFAULTS
modified_config["loggers"]["sanic.root"]["level"] = "DEBUG"
app = Sanic("test_logging", log_config=modified_config)
Sanic("test_logging", log_config=modified_config)
assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG
def test_access_log_client_ip_remote_addr(monkeypatch):
access = Mock()
monkeypatch.setattr(sanic.http, "access_logger", access)
app = Sanic("test_logging")
app.config.PROXIES_COUNT = 2
@app.route("/")
async def handler(request):
return text(request.remote_addr)
headers = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"}
request, response = app.test_client.get("/", headers=headers)
assert request.remote_addr == "1.1.1.1"
access.info.assert_called_with(
"",
extra={
"status": 200,
"byte": len(response.content),
"host": f"{request.remote_addr}:{request.port}",
"request": f"GET {request.scheme}://{request.host}/",
},
)
def test_access_log_client_ip_reqip(monkeypatch):
access = Mock()
monkeypatch.setattr(sanic.http, "access_logger", access)
app = Sanic("test_logging")
@app.route("/")
async def handler(request):
return text(request.ip)
request, response = app.test_client.get("/")
access.info.assert_called_with(
"",
extra={
"status": 200,
"byte": len(response.content),
"host": f"{request.ip}:{request.port}",
"request": f"GET {request.scheme}://{request.host}/",
},
)

View File

@@ -6,85 +6,37 @@ from sanic_testing.testing import PORT
from sanic.config import BASE_LOGO
def test_logo_base(app, caplog):
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
def test_logo_base(app, run_startup):
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == BASE_LOGO
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == BASE_LOGO
def test_logo_false(app, caplog):
def test_logo_false(app, caplog, run_startup):
app.config.LOGO = False
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
banner, port = caplog.record_tuples[0][2].rsplit(":", 1)
assert caplog.record_tuples[0][1] == logging.INFO
banner, port = logs[0][2].rsplit(":", 1)
assert logs[0][1] == logging.INFO
assert banner == "Goin' Fast @ http://127.0.0.1"
assert int(port) > 0
def test_logo_true(app, caplog):
def test_logo_true(app, run_startup):
app.config.LOGO = True
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == BASE_LOGO
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == BASE_LOGO
def test_logo_custom(app, caplog):
def test_logo_custom(app, run_startup):
app.config.LOGO = "My Custom Logo"
server = app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
logs = run_startup(app)
with caplog.at_level(logging.DEBUG):
_server = loop.run_until_complete(server)
_server.close()
loop.run_until_complete(_server.wait_closed())
app.stop()
assert caplog.record_tuples[0][1] == logging.DEBUG
assert caplog.record_tuples[0][2] == "My Custom Logo"
assert logs[0][1] == logging.DEBUG
assert logs[0][2] == "My Custom Logo"

View File

@@ -1,7 +1,7 @@
from httpx import AsyncByteStream
from sanic_testing.reusable import ReusableClient
from sanic.response import json
from sanic.response import json, text
def test_no_body_requests(app):
@@ -80,3 +80,26 @@ def test_streaming_body_requests(app):
assert response1.json["data"] == response2.json["data"] == data
assert response1.json["request_id"] != response2.json["request_id"]
assert response1.json["connection_id"] == response2.json["connection_id"]
def test_bad_headers(app):
@app.get("/")
async def handler(request):
return text("")
@app.on_response
async def reqid(request, response):
response.headers["x-request-id"] = request.id
client = ReusableClient(app, port=1234)
bad_headers = {"bad": "bad" * 5_000}
with client:
_, response1 = client.get("/")
_, response2 = client.get("/", headers=bad_headers)
assert response1.status == 200
assert response2.status == 413
assert (
response1.headers["x-request-id"] != response2.headers["x-request-id"]
)

View File

@@ -2,6 +2,7 @@ import asyncio
import httpcore
import httpx
import pytest
from sanic_testing.testing import SanicTestClient
@@ -48,42 +49,51 @@ class DelayableSanicTestClient(SanicTestClient):
return DelayableSanicSession(request_delay=self._request_delay)
request_timeout_default_app = Sanic("test_request_timeout_default")
request_no_timeout_app = Sanic("test_request_no_timeout")
request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6
request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6
@pytest.fixture
def request_no_timeout_app():
app = Sanic("test_request_no_timeout")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler2(request):
return text("OK")
return app
@request_timeout_default_app.route("/1")
async def handler1(request):
return text("OK")
@pytest.fixture
def request_timeout_default_app():
app = Sanic("test_request_timeout_default")
app.config.REQUEST_TIMEOUT = 0.6
@app.route("/1")
async def handler1(request):
return text("OK")
@app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
return app
@request_no_timeout_app.route("/1")
async def handler2(request):
return text("OK")
@request_timeout_default_app.websocket("/ws1")
async def ws_handler1(request, ws):
await ws.send("OK")
def test_default_server_error_request_timeout():
def test_default_server_error_request_timeout(request_timeout_default_app):
client = DelayableSanicTestClient(request_timeout_default_app, 2)
request, response = client.get("/1")
_, response = client.get("/1")
assert response.status == 408
assert "Request Timeout" in response.text
def test_default_server_error_request_dont_timeout():
def test_default_server_error_request_dont_timeout(request_no_timeout_app):
client = DelayableSanicTestClient(request_no_timeout_app, 0.2)
request, response = client.get("/1")
_, response = client.get("/1")
assert response.status == 200
assert response.text == "OK"
def test_default_server_error_websocket_request_timeout():
def test_default_server_error_websocket_request_timeout(
request_timeout_default_app,
):
headers = {
"Upgrade": "websocket",
@@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout():
}
client = DelayableSanicTestClient(request_timeout_default_app, 2)
request, response = client.get("/ws1", headers=headers)
_, response = client.get("/ws1", headers=headers)
assert response.status == 408
assert "Request Timeout" in response.text

View File

@@ -654,41 +654,46 @@ def test_websocket_route_invalid_handler(app):
@pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url):
ev = asyncio.Event()
@app.after_server_start
async def setup_ev(app, _):
app.ctx.ev = asyncio.Event()
@app.websocket(url)
async def handler(request, ws):
ev.set()
request.app.ctx.ev.set()
request, response = await app.asgi_client.websocket(url)
assert ev.is_set()
@app.get("/ev")
async def check(request):
return json({"set": request.app.ctx.ev.is_set()})
_, response = await app.asgi_client.websocket(url)
_, response = await app.asgi_client.get("/")
assert response.json["set"]
def test_websocket_route_with_subprotocols(app):
results = []
@pytest.mark.parametrize(
"subprotocols,expected",
(
(["bar"], "bar"),
(["bar", "foo"], "bar"),
(["baz"], None),
(None, None),
),
)
def test_websocket_route_with_subprotocols(app, subprotocols, expected):
results = "unset"
@app.websocket("/ws", subprotocols=["foo", "bar"])
async def handler(request, ws):
results.append(ws.subprotocol)
nonlocal results
results = ws.subprotocol
assert ws.subprotocol is not None
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"])
assert response.opened is True
assert results == ["bar"]
_, response = SanicTestClient(app).websocket(
"/ws", subprotocols=["bar", "foo"]
"/ws", subprotocols=subprotocols
)
assert response.opened is True
assert results == ["bar", "bar"]
_, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"])
assert response.opened is True
assert results == ["bar", "bar", None]
_, response = SanicTestClient(app).websocket("/ws")
assert response.opened is True
assert results == ["bar", "bar", None, None]
assert results == expected
@pytest.mark.parametrize("strict_slashes", [True, False, None])

View File

@@ -8,7 +8,7 @@ import pytest
from sanic_testing.testing import HOST, PORT
from sanic.exceptions import InvalidUsage
from sanic.exceptions import InvalidUsage, SanicException
AVAILABLE_LISTENERS = [
@@ -103,7 +103,11 @@ async def test_trigger_before_events_create_server(app):
async def init_db(app, loop):
app.db = MySanicDb()
await app.create_server(debug=True, return_asyncio_server=True, port=PORT)
srv = await app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
await srv.startup()
await srv.before_start()
assert hasattr(app, "db")
assert isinstance(app.db, MySanicDb)
@@ -157,14 +161,15 @@ def test_create_server_trigger_events(app):
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task)
server.after_start()
loop.run_until_complete(server.startup())
loop.run_until_complete(server.after_start())
try:
loop.run_forever()
except KeyboardInterrupt as e:
except KeyboardInterrupt:
loop.stop()
finally:
# Run the on_stop function if provided
server.before_stop()
loop.run_until_complete(server.before_stop())
# Wait for server to close
close_task = server.close()
@@ -174,5 +179,19 @@ def test_create_server_trigger_events(app):
signal.stopped = True
for connection in server.connections:
connection.close_if_idle()
server.after_stop()
loop.run_until_complete(server.after_stop())
assert flag1 and flag2 and flag3
@pytest.mark.asyncio
async def test_missing_startup_raises_exception(app):
@app.listener("before_server_start")
async def init_db(app, loop):
...
srv = await app.create_server(
debug=True, return_asyncio_server=True, port=PORT
)
with pytest.raises(SanicException):
await srv.before_start()

View File

@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
import pytest
from sanic_testing.reusable import ReusableClient
from sanic_testing.testing import HOST, PORT
from sanic.compat import ctrlc_workaround_for_windows
@@ -28,9 +29,13 @@ def set_loop(app, loop):
signal.signal = mock
else:
loop.add_signal_handler = mock
print(">>>>>>>>>>>>>>>1", id(loop))
print(">>>>>>>>>>>>>>>1", loop.add_signal_handler)
def after(app, loop):
print(">>>>>>>>>>>>>>>2", id(loop))
print(">>>>>>>>>>>>>>>2", loop.add_signal_handler)
calledq.put(mock.called)

View File

@@ -68,6 +68,7 @@ async def test_dispatch_signal_triggers_multiple_handlers(app):
app.signal_router.finalize()
assert len(app.signal_router.routes) == 3
await app.dispatch("foo.bar.baz")
assert counter == 2
@@ -331,7 +332,8 @@ def test_event_on_bp_not_registered():
"event,expected",
(
("foo.bar.baz", True),
("server.init.before", False),
("server.init.before", True),
("server.init.somethingelse", False),
("http.request.start", False),
("sanic.notice.anything", True),
),

21
tests/test_touchup.py Normal file
View File

@@ -0,0 +1,21 @@
import logging
from sanic.signals import RESERVED_NAMESPACES
from sanic.touchup import TouchUp
def test_touchup_methods(app):
assert len(TouchUp._registry) == 9
async def test_ode_removes_dispatch_events(app, caplog):
with caplog.at_level(logging.DEBUG, logger="sanic.root"):
await app._startup()
logs = caplog.record_tuples
for signal in RESERVED_NAMESPACES["http"]:
assert (
"sanic.root",
logging.DEBUG,
f"Disabling event: {signal}",
) in logs

View File

@@ -43,7 +43,15 @@ def test_routes_with_multiple_hosts(app):
)
def test_websocket_bp_route_name(app):
@pytest.mark.parametrize(
"name,expected",
(
("test_route", "/bp/route"),
("test_route2", "/bp/route2"),
("foobar_3", "/bp/route3"),
),
)
def test_websocket_bp_route_name(app, name, expected):
"""Tests that blueprint websocket route is named."""
event = asyncio.Event()
bp = Blueprint("test_bp", url_prefix="/bp")
@@ -69,22 +77,12 @@ def test_websocket_bp_route_name(app):
uri = app.url_for("test_bp.main")
assert uri == "/bp/main"
uri = app.url_for("test_bp.test_route")
assert uri == "/bp/route"
uri = app.url_for(f"test_bp.{name}")
assert uri == expected
request, response = SanicTestClient(app).websocket(uri)
assert response.opened is True
assert event.is_set()
event.clear()
uri = app.url_for("test_bp.test_route2")
assert uri == "/bp/route2"
request, response = SanicTestClient(app).websocket(uri)
assert response.opened is True
assert event.is_set()
uri = app.url_for("test_bp.foobar_3")
assert uri == "/bp/route3"
# TODO: add test with a route with multiple hosts
# TODO: add test with a route with _host in url_for