Merge branch 'master' into streaming
This commit is contained in:
commit
1aac4f546b
|
@ -95,6 +95,7 @@ class Sanic:
|
|||
self.sock = None
|
||||
self.strict_slashes = strict_slashes
|
||||
self.listeners = defaultdict(list)
|
||||
self.is_stopping = False
|
||||
self.is_running = False
|
||||
self.websocket_enabled = False
|
||||
self.websocket_tasks = set()
|
||||
|
@ -1181,6 +1182,7 @@ class Sanic:
|
|||
|
||||
try:
|
||||
self.is_running = True
|
||||
self.is_stopping = False
|
||||
if workers > 1 and os.name != "posix":
|
||||
logger.warn(
|
||||
f"Multiprocessing is currently not supported on {os.name},"
|
||||
|
@ -1213,7 +1215,9 @@ class Sanic:
|
|||
|
||||
def stop(self):
|
||||
"""This kills the Sanic"""
|
||||
get_event_loop().stop()
|
||||
if not self.is_stopping:
|
||||
self.is_stopping = True
|
||||
get_event_loop().stop()
|
||||
|
||||
async def create_server(
|
||||
self,
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
from asyncio import CancelledError
|
||||
import asyncio
|
||||
import signal
|
||||
|
||||
from sys import argv
|
||||
|
||||
from multidict import CIMultiDict # type: ignore
|
||||
|
||||
|
||||
try:
|
||||
from trio import Cancelled # type: ignore
|
||||
|
||||
CancelledErrors = tuple([CancelledError, Cancelled])
|
||||
except ImportError:
|
||||
CancelledErrors = tuple([CancelledError])
|
||||
|
||||
|
||||
class Header(CIMultiDict):
|
||||
def get_all(self, key):
|
||||
return self.getall(key, default=[])
|
||||
|
@ -20,15 +14,42 @@ class Header(CIMultiDict):
|
|||
use_trio = argv[0].endswith("hypercorn") and "trio" in argv
|
||||
|
||||
if use_trio:
|
||||
from trio import open_file as open_async, Path # type: ignore
|
||||
import trio # type: ignore
|
||||
|
||||
def stat_async(path):
|
||||
return Path(path).stat()
|
||||
|
||||
return trio.Path(path).stat()
|
||||
|
||||
open_async = trio.open_file
|
||||
CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled])
|
||||
else:
|
||||
from aiofiles import open as aio_open # type: ignore
|
||||
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
|
||||
|
||||
async def open_async(file, mode="r", **kwargs):
|
||||
return aio_open(file, mode, **kwargs)
|
||||
|
||||
CancelledErrors = tuple([asyncio.CancelledError])
|
||||
|
||||
|
||||
def ctrlc_workaround_for_windows(app):
|
||||
async def stay_active(app):
|
||||
"""Asyncio wakeups to allow receiving SIGINT in Python"""
|
||||
while not die:
|
||||
# If someone else stopped the app, just exit
|
||||
if app.is_stopping:
|
||||
return
|
||||
# Windows Python blocks signal handlers while the event loop is
|
||||
# waiting for I/O. Frequent wakeups keep interrupts flowing.
|
||||
await asyncio.sleep(0.1)
|
||||
# Can't be called from signal handler, so call it from here
|
||||
app.stop()
|
||||
|
||||
def ctrlc_handler(sig, frame):
|
||||
nonlocal die
|
||||
if die:
|
||||
raise KeyboardInterrupt("Non-graceful Ctrl+C")
|
||||
die = True
|
||||
|
||||
die = False
|
||||
signal.signal(signal.SIGINT, ctrlc_handler)
|
||||
app.add_task(stay_active)
|
||||
|
|
|
@ -84,6 +84,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
|||
self.status = status
|
||||
self.headers = Header(headers or {})
|
||||
self._cookies = None
|
||||
self.protocol = None
|
||||
|
||||
async def write(self, data):
|
||||
"""Writes a chunk of data to the streaming response.
|
||||
|
|
|
@ -11,6 +11,7 @@ from signal import signal as signal_func
|
|||
from socket import SO_REUSEADDR, SOL_SOCKET, socket
|
||||
from time import monotonic as current_time
|
||||
|
||||
from sanic.compat import ctrlc_workaround_for_windows
|
||||
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
||||
from sanic.http import Http, Stage
|
||||
from sanic.log import logger
|
||||
|
@ -25,6 +26,8 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
OS_IS_WINDOWS = os.name == "nt"
|
||||
|
||||
|
||||
class Signal:
|
||||
stopped = False
|
||||
|
@ -530,15 +533,11 @@ def serve(
|
|||
|
||||
# Register signals for graceful termination
|
||||
if register_sys_signals:
|
||||
_singals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM)
|
||||
for _signal in _singals:
|
||||
try:
|
||||
loop.add_signal_handler(_signal, loop.stop)
|
||||
except NotImplementedError:
|
||||
logger.warning(
|
||||
"Sanic tried to use loop.add_signal_handler "
|
||||
"but it is not implemented on this platform."
|
||||
)
|
||||
if OS_IS_WINDOWS:
|
||||
ctrlc_workaround_for_windows(app)
|
||||
else:
|
||||
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
|
||||
loop.add_signal_handler(_signal, app.stop)
|
||||
pid = os.getpid()
|
||||
try:
|
||||
logger.info("Starting worker [%s]", pid)
|
||||
|
|
|
@ -12,7 +12,7 @@ from sanic.response import text
|
|||
|
||||
ASGI_HOST = "mockserver"
|
||||
HOST = "127.0.0.1"
|
||||
PORT = 42101
|
||||
PORT = None
|
||||
|
||||
|
||||
class SanicTestClient:
|
||||
|
@ -99,7 +99,7 @@ class SanicTestClient:
|
|||
|
||||
if self.port:
|
||||
server_kwargs = dict(
|
||||
host=host or self.host, port=self.port, **server_kwargs
|
||||
host=host or self.host, port=self.port, **server_kwargs,
|
||||
)
|
||||
host, port = host or self.host, self.port
|
||||
else:
|
||||
|
@ -107,6 +107,7 @@ class SanicTestClient:
|
|||
sock.bind((host or self.host, 0))
|
||||
server_kwargs = dict(sock=sock, **server_kwargs)
|
||||
host, port = sock.getsockname()
|
||||
self.port = port
|
||||
|
||||
if uri.startswith(
|
||||
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
|
||||
|
@ -118,6 +119,9 @@ class SanicTestClient:
|
|||
url = "{scheme}://{host}:{port}{uri}".format(
|
||||
scheme=scheme, host=host, port=port, uri=uri
|
||||
)
|
||||
# Tests construct URLs using PORT = None, which means random port not
|
||||
# known until this function is called, so fix that here
|
||||
url = url.replace(":None/", f":{port}/")
|
||||
|
||||
@self.app.listener("after_server_start")
|
||||
async def _collect_response(sanic, loop):
|
||||
|
@ -207,7 +211,7 @@ class SanicASGITestClient(httpx.AsyncClient):
|
|||
|
||||
self.app = app
|
||||
|
||||
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT))
|
||||
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0))
|
||||
super().__init__(dispatch=dispatch, base_url=base_url)
|
||||
|
||||
self.last_request = None
|
||||
|
|
|
@ -6,6 +6,7 @@ from inspect import isawaitable
|
|||
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.exceptions import SanicException
|
||||
from sanic.response import text
|
||||
|
||||
|
@ -48,6 +49,7 @@ def test_asyncio_server_no_start_serving(app):
|
|||
if not uvloop_installed():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio_srv_coro = app.create_server(
|
||||
port=43123,
|
||||
return_asyncio_server=True,
|
||||
asyncio_server_kwargs=dict(start_serving=False),
|
||||
)
|
||||
|
@ -61,6 +63,7 @@ def test_asyncio_server_start_serving(app):
|
|||
if not uvloop_installed():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio_srv_coro = app.create_server(
|
||||
port=43124,
|
||||
return_asyncio_server=True,
|
||||
asyncio_server_kwargs=dict(start_serving=False),
|
||||
)
|
||||
|
@ -199,10 +202,17 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
|
|||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
request, response = app.test_client.get("/")
|
||||
port = request.server_port
|
||||
assert port > 0
|
||||
assert response.status == 500
|
||||
assert "Mock SanicException" in response.text
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.ERROR,
|
||||
"Exception occurred while handling uri: 'http://127.0.0.1:42101/'",
|
||||
f"Exception occurred while handling uri: 'http://127.0.0.1:{port}/'",
|
||||
) in caplog.record_tuples
|
||||
|
||||
|
||||
def test_app_name_required():
|
||||
with pytest.deprecated_call():
|
||||
Sanic()
|
||||
|
|
|
@ -221,7 +221,7 @@ async def test_request_class_custom():
|
|||
class MyCustomRequest(Request):
|
||||
pass
|
||||
|
||||
app = Sanic(request_class=MyCustomRequest)
|
||||
app = Sanic(name=__name__, request_class=MyCustomRequest)
|
||||
|
||||
@app.get("/custom")
|
||||
def custom_request(request):
|
||||
|
|
|
@ -44,42 +44,42 @@ def test_load_from_object_string_exception(app):
|
|||
|
||||
def test_auto_load_env():
|
||||
environ["SANIC_TEST_ANSWER"] = "42"
|
||||
app = Sanic()
|
||||
app = Sanic(name=__name__)
|
||||
assert app.config.TEST_ANSWER == 42
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_auto_load_bool_env():
|
||||
environ["SANIC_TEST_ANSWER"] = "True"
|
||||
app = Sanic()
|
||||
app = Sanic(name=__name__)
|
||||
assert app.config.TEST_ANSWER == True
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_dont_load_env():
|
||||
environ["SANIC_TEST_ANSWER"] = "42"
|
||||
app = Sanic(load_env=False)
|
||||
app = Sanic(name=__name__, load_env=False)
|
||||
assert getattr(app.config, "TEST_ANSWER", None) is None
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_load_env_prefix():
|
||||
environ["MYAPP_TEST_ANSWER"] = "42"
|
||||
app = Sanic(load_env="MYAPP_")
|
||||
app = Sanic(name=__name__, load_env="MYAPP_")
|
||||
assert app.config.TEST_ANSWER == 42
|
||||
del environ["MYAPP_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_load_env_prefix_float_values():
|
||||
environ["MYAPP_TEST_ROI"] = "2.3"
|
||||
app = Sanic(load_env="MYAPP_")
|
||||
app = Sanic(name=__name__, load_env="MYAPP_")
|
||||
assert app.config.TEST_ROI == 2.3
|
||||
del environ["MYAPP_TEST_ROI"]
|
||||
|
||||
|
||||
def test_load_env_prefix_string_value():
|
||||
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
|
||||
app = Sanic(load_env="MYAPP_")
|
||||
app = Sanic(name=__name__, load_env="MYAPP_")
|
||||
assert app.config.TEST_TOKEN == "somerandomtesttoken"
|
||||
del environ["MYAPP_TEST_TOKEN"]
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ def test_deprecated_custom_request():
|
|||
Sanic(request_class=DeprecCustomRequest)
|
||||
|
||||
def test_custom_request():
|
||||
app = Sanic(request_class=CustomRequest)
|
||||
app = Sanic(name=__name__, request_class=CustomRequest)
|
||||
|
||||
@app.route("/post", methods=["POST"])
|
||||
async def post_handler(request):
|
||||
|
|
|
@ -7,12 +7,12 @@ import httpx
|
|||
|
||||
from sanic import Sanic, server
|
||||
from sanic.response import text
|
||||
from sanic.testing import HOST, PORT, SanicTestClient
|
||||
|
||||
from sanic.testing import HOST, SanicTestClient
|
||||
|
||||
CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
|
||||
|
||||
old_conn = None
|
||||
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
|
||||
|
||||
|
||||
class ReusableSanicConnectionPool(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from importlib import reload
|
||||
|
@ -12,6 +13,7 @@ import sanic
|
|||
from sanic import Sanic
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
|
||||
from sanic.response import text
|
||||
from sanic.testing import SanicTestClient
|
||||
|
||||
|
||||
logging_format = """module: %(module)s; \
|
||||
|
@ -126,7 +128,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
|
|||
def test_logger(caplog):
|
||||
rand_string = str(uuid.uuid4())
|
||||
|
||||
app = Sanic()
|
||||
app = Sanic(name=__name__)
|
||||
|
||||
@app.get("/")
|
||||
def log_info(request):
|
||||
|
@ -136,15 +138,67 @@ def test_logger(caplog):
|
|||
with caplog.at_level(logging.INFO):
|
||||
request, response = app.test_client.get("/")
|
||||
|
||||
port = request.server_port
|
||||
|
||||
# Note: testing with random port doesn't show the banner because it doesn't
|
||||
# define host and port. This test supports both modes.
|
||||
if caplog.record_tuples[0] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
f"Goin' Fast @ http://127.0.0.1:{port}",
|
||||
):
|
||||
caplog.record_tuples.pop(0)
|
||||
|
||||
assert caplog.record_tuples[0] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"Goin' Fast @ http://127.0.0.1:42101",
|
||||
f"http://127.0.0.1:{port}/",
|
||||
)
|
||||
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
|
||||
assert caplog.record_tuples[-1] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"Server Stopped",
|
||||
)
|
||||
|
||||
|
||||
def test_logger_static_and_secure(caplog):
|
||||
# Same as test_logger, except for more coverage:
|
||||
# - test_client initialised separately for static port
|
||||
# - using ssl
|
||||
rand_string = str(uuid.uuid4())
|
||||
|
||||
app = Sanic(name=__name__)
|
||||
|
||||
@app.get("/")
|
||||
def log_info(request):
|
||||
logger.info(rand_string)
|
||||
return text("hello")
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert")
|
||||
ssl_key = os.path.join(current_dir, "certs/selfsigned.key")
|
||||
|
||||
ssl_dict = {"cert": ssl_cert, "key": ssl_key}
|
||||
|
||||
test_client = SanicTestClient(app, port=42101)
|
||||
with caplog.at_level(logging.INFO):
|
||||
request, response = test_client.get(
|
||||
f"https://127.0.0.1:{test_client.port}/",
|
||||
server_kwargs=dict(ssl=ssl_dict),
|
||||
)
|
||||
|
||||
port = test_client.port
|
||||
|
||||
assert caplog.record_tuples[0] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
f"Goin' Fast @ https://127.0.0.1:{port}",
|
||||
)
|
||||
assert caplog.record_tuples[1] == (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"http://127.0.0.1:42101/",
|
||||
f"https://127.0.0.1:{port}/",
|
||||
)
|
||||
assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string)
|
||||
assert caplog.record_tuples[-1] == (
|
||||
|
|
|
@ -49,10 +49,10 @@ def test_logo_false(app, caplog):
|
|||
loop.run_until_complete(_server.wait_closed())
|
||||
app.stop()
|
||||
|
||||
banner, port = caplog.record_tuples[ROW][2].rsplit(":", 1)
|
||||
assert caplog.record_tuples[ROW][1] == logging.INFO
|
||||
assert caplog.record_tuples[ROW][
|
||||
2
|
||||
] == f"Goin' Fast @ http://127.0.0.1:{PORT}"
|
||||
assert banner == "Goin' Fast @ http://127.0.0.1"
|
||||
assert int(port) > 0
|
||||
|
||||
|
||||
def test_logo_true(app, caplog):
|
||||
|
|
|
@ -7,7 +7,7 @@ from sanic.views import CompositionView, HTTPMethodView
|
|||
from sanic.views import stream as stream_decorator
|
||||
|
||||
|
||||
data = "abc" * 10000000
|
||||
data = "abc" * 1_000_000
|
||||
|
||||
|
||||
def test_request_stream_method_view(app):
|
||||
|
|
|
@ -12,7 +12,7 @@ from sanic import Blueprint, Sanic
|
|||
from sanic.exceptions import ServerError
|
||||
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters
|
||||
from sanic.response import html, json, text
|
||||
from sanic.testing import ASGI_HOST, HOST, PORT
|
||||
from sanic.testing import ASGI_HOST, HOST, PORT, SanicTestClient
|
||||
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
|
@ -1029,7 +1029,7 @@ def test_url_attributes_no_ssl(app, path, query, expected_url):
|
|||
app.add_route(handler, path)
|
||||
|
||||
request, response = app.test_client.get(path + f"?{query}")
|
||||
assert request.url == expected_url.format(HOST, PORT)
|
||||
assert request.url == expected_url.format(HOST, request.server_port)
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
|
@ -1086,11 +1086,12 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url):
|
|||
|
||||
app.add_route(handler, path)
|
||||
|
||||
port = app.test_client.port
|
||||
request, response = app.test_client.get(
|
||||
f"https://{HOST}:{PORT}" + path + f"?{query}",
|
||||
server_kwargs={"ssl": context},
|
||||
)
|
||||
assert request.url == expected_url.format(HOST, PORT)
|
||||
assert request.url == expected_url.format(HOST, request.server_port)
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
|
@ -1125,7 +1126,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
|
|||
f"https://{HOST}:{PORT}" + path + f"?{query}",
|
||||
server_kwargs={"ssl": ssl_dict},
|
||||
)
|
||||
assert request.url == expected_url.format(HOST, PORT)
|
||||
assert request.url == expected_url.format(HOST, request.server_port)
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
|
@ -1917,8 +1918,9 @@ def test_request_server_port(app):
|
|||
def handler(request):
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.get("/", headers={"Host": "my-server"})
|
||||
assert request.server_port == app.test_client.port
|
||||
test_client = SanicTestClient(app)
|
||||
request, response = test_client.get("/", headers={"Host": "my-server"})
|
||||
assert request.server_port == test_client.port
|
||||
|
||||
|
||||
def test_request_server_port_in_host_header(app):
|
||||
|
@ -1939,7 +1941,10 @@ def test_request_server_port_in_host_header(app):
|
|||
request, response = app.test_client.get(
|
||||
"/", headers={"Host": "mal_formed:5555"}
|
||||
)
|
||||
assert request.server_port == app.test_client.port
|
||||
if PORT is None:
|
||||
assert request.server_port != 5555
|
||||
else:
|
||||
assert request.server_port == app.test_client.port
|
||||
|
||||
|
||||
def test_request_server_port_forwarded(app):
|
||||
|
@ -1979,7 +1984,7 @@ def test_server_name_and_url_for(app):
|
|||
request, response = app.test_client.get("/foo")
|
||||
assert (
|
||||
request.url_for("handler")
|
||||
== f"http://my-server:{app.test_client.port}/foo"
|
||||
== f"http://my-server:{request.server_port}/foo"
|
||||
)
|
||||
|
||||
app.config.SERVER_NAME = "https://my-server/path"
|
||||
|
@ -2040,7 +2045,7 @@ async def test_request_form_invalid_content_type_asgi(app):
|
|||
|
||||
|
||||
def test_endpoint_basic():
|
||||
app = Sanic()
|
||||
app = Sanic(name=__name__)
|
||||
|
||||
@app.route("/")
|
||||
def my_unique_handler(request):
|
||||
|
@ -2053,7 +2058,7 @@ def test_endpoint_basic():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_basic_asgi():
|
||||
app = Sanic()
|
||||
app = Sanic(name=__name__)
|
||||
|
||||
@app.route("/")
|
||||
def my_unique_handler(request):
|
||||
|
@ -2132,5 +2137,5 @@ def test_url_for_without_server_name(app):
|
|||
request, response = app.test_client.get("/sample")
|
||||
assert (
|
||||
response.json["url"]
|
||||
== f"http://127.0.0.1:{app.test_client.port}/url-for"
|
||||
== f"http://127.0.0.1:{request.server_port}/url-for"
|
||||
)
|
||||
|
|
|
@ -6,6 +6,7 @@ from sanic import Sanic
|
|||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.response import json, text
|
||||
from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists
|
||||
from sanic.testing import SanicTestClient
|
||||
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
|
@ -163,35 +164,36 @@ def test_route_optional_slash(app):
|
|||
def test_route_strict_slashes_set_to_false_and_host_is_a_list(app):
|
||||
# Part of regression test for issue #1120
|
||||
|
||||
site1 = f"127.0.0.1:{app.test_client.port}"
|
||||
test_client = SanicTestClient(app, port=42101)
|
||||
site1 = f"127.0.0.1:{test_client.port}"
|
||||
|
||||
# before fix, this raises a RouteExists error
|
||||
@app.get("/get", host=[site1, "site2.com"], strict_slashes=False)
|
||||
def get_handler(request):
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.get("http://" + site1 + "/get")
|
||||
request, response = test_client.get("http://" + site1 + "/get")
|
||||
assert response.text == "OK"
|
||||
|
||||
@app.post("/post", host=[site1, "site2.com"], strict_slashes=False)
|
||||
def post_handler(request):
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.post("http://" + site1 + "/post")
|
||||
request, response = test_client.post("http://" + site1 + "/post")
|
||||
assert response.text == "OK"
|
||||
|
||||
@app.put("/put", host=[site1, "site2.com"], strict_slashes=False)
|
||||
def put_handler(request):
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.put("http://" + site1 + "/put")
|
||||
request, response = test_client.put("http://" + site1 + "/put")
|
||||
assert response.text == "OK"
|
||||
|
||||
@app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False)
|
||||
def delete_handler(request):
|
||||
return text("OK")
|
||||
|
||||
request, response = app.test_client.delete("http://" + site1 + "/delete")
|
||||
request, response = test_client.delete("http://" + site1 + "/delete")
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import asyncio
|
||||
import signal
|
||||
|
||||
from contextlib import closing
|
||||
from socket import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.testing import HOST, PORT
|
||||
|
@ -118,25 +121,30 @@ def test_create_server_trigger_events(app):
|
|||
app.listener("after_server_stop")(after_stop)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
serv_coro = app.create_server(return_asyncio_server=True)
|
||||
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
|
||||
server = loop.run_until_complete(serv_task)
|
||||
server.after_start()
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt as e:
|
||||
loop.stop()
|
||||
finally:
|
||||
# Run the on_stop function if provided
|
||||
server.before_stop()
|
||||
|
||||
# Wait for server to close
|
||||
close_task = server.close()
|
||||
loop.run_until_complete(close_task)
|
||||
# Use random port for tests
|
||||
with closing(socket()) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
|
||||
# Complete all tasks on the loop
|
||||
signal.stopped = True
|
||||
for connection in server.connections:
|
||||
connection.close_if_idle()
|
||||
server.after_stop()
|
||||
assert flag1 and flag2 and flag3
|
||||
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
|
||||
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
|
||||
server = loop.run_until_complete(serv_task)
|
||||
server.after_start()
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt as e:
|
||||
loop.stop()
|
||||
finally:
|
||||
# Run the on_stop function if provided
|
||||
server.before_stop()
|
||||
|
||||
# Wait for server to close
|
||||
close_task = server.close()
|
||||
loop.run_until_complete(close_task)
|
||||
|
||||
# Complete all tasks on the loop
|
||||
signal.stopped = True
|
||||
for connection in server.connections:
|
||||
connection.close_if_idle()
|
||||
server.after_stop()
|
||||
assert flag1 and flag2 and flag3
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
|
||||
from queue import Queue
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.compat import ctrlc_workaround_for_windows
|
||||
from sanic.response import HTTPResponse
|
||||
from sanic.testing import HOST, PORT
|
||||
|
||||
|
@ -16,13 +21,21 @@ calledq = Queue()
|
|||
|
||||
|
||||
def set_loop(app, loop):
|
||||
loop.add_signal_handler = MagicMock()
|
||||
global mock
|
||||
mock = MagicMock()
|
||||
if os.name == "nt":
|
||||
signal.signal = mock
|
||||
else:
|
||||
loop.add_signal_handler = mock
|
||||
|
||||
|
||||
def after(app, loop):
|
||||
calledq.put(loop.add_signal_handler.called)
|
||||
calledq.put(mock.called)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.name == "nt", reason="May hang CI on py38/windows"
|
||||
)
|
||||
def test_register_system_signals(app):
|
||||
"""Test if sanic register system signals"""
|
||||
|
||||
|
@ -38,6 +51,9 @@ def test_register_system_signals(app):
|
|||
assert calledq.get() is True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.name == "nt", reason="May hang CI on py38/windows"
|
||||
)
|
||||
def test_dont_register_system_signals(app):
|
||||
"""Test if sanic don't register system signals"""
|
||||
|
||||
|
@ -51,3 +67,49 @@ def test_dont_register_system_signals(app):
|
|||
|
||||
app.run(HOST, PORT, register_sys_signals=False)
|
||||
assert calledq.get() is False
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.name == "nt", reason="windows cannot SIGINT processes"
|
||||
)
|
||||
def test_windows_workaround():
|
||||
"""Test Windows workaround (on any other OS)"""
|
||||
# At least some code coverage, even though this test doesn't work on
|
||||
# Windows...
|
||||
class MockApp:
|
||||
def __init__(self):
|
||||
self.is_stopping = False
|
||||
|
||||
def stop(self):
|
||||
assert not self.is_stopping
|
||||
self.is_stopping = True
|
||||
|
||||
def add_task(self, func):
|
||||
loop = asyncio.get_event_loop()
|
||||
self.stay_active_task = loop.create_task(func(self))
|
||||
|
||||
async def atest(stop_first):
|
||||
app = MockApp()
|
||||
ctrlc_workaround_for_windows(app)
|
||||
await asyncio.sleep(0.05)
|
||||
if stop_first:
|
||||
app.stop()
|
||||
await asyncio.sleep(0.2)
|
||||
assert app.is_stopping == stop_first
|
||||
# First Ctrl+C: should call app.stop() within 0.1 seconds
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
await asyncio.sleep(0.2)
|
||||
assert app.is_stopping
|
||||
assert app.stay_active_task.result() == None
|
||||
# Second Ctrl+C should raise
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
return "OK"
|
||||
|
||||
# Run in our private loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
res = loop.run_until_complete(atest(False))
|
||||
assert res == "OK"
|
||||
res = loop.run_until_complete(atest(True))
|
||||
assert res == "OK"
|
||||
|
|
|
@ -27,7 +27,8 @@ def test_test_client_port_default(app):
|
|||
return json(request.transport.get_extra_info("sockname")[1])
|
||||
|
||||
test_client = SanicTestClient(app)
|
||||
assert test_client.port == PORT
|
||||
assert test_client.port == PORT # Can be None before request
|
||||
|
||||
request, response = test_client.get("/get")
|
||||
assert response.json == PORT
|
||||
assert test_client.port > 0
|
||||
assert response.json == test_client.port
|
||||
|
|
Loading…
Reference in New Issue
Block a user