Fix Ctrl+C and tests on Windows. (#1808)

* Fix Ctrl+C on Windows.

* Disable testing of a function N/A on Windows.

* Add test for coverage, avoid crash on missing _stopping.

* Initialise StreamingHTTPResponse.protocol = None

* Improved comments.

* Reduce amount of data in test_request_stream to avoid failures on Windows.

* The Windows test doesn't work on Windows :(

* Use port numbers more likely to be free than 8000.

* Disable the other signal tests on Windows as well.

* Windows doesn't properly support SO_REUSEADDR, so that's disabled in Python, and thus rebinding fails. For successful testing, reuse port instead.

* app.run argument handling: added server kwargs (alike create_server), added warning on extra kwargs, made auto_reload explicit argument. Another go at Windows tests

* Revert "app.run argument handling: added server kwargs (alike create_server), added warning on extra kwargs, made auto_reload explicit argument. Another go at Windows tests"

This reverts commit dc5d682448.

* Use random test server port on most tests. Should avoid port/addr reuse issues.

* Another test to random port instead of 8000.

* Fix deprecation warnings about missing name on Sanic() in tests.

* Linter and typing

* Increase test coverage

* Rewrite test for ctrlc_windows_workaround

* py36 compat

* py36 compat

* py36 compat

* Don't rely on loop internals but add a stopping flag to app.

* App may be restarted.

* py36 compat

* Linter

* Add a constant for OS checking.

Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
This commit is contained in:
L. Kärkkäinen 2020-03-26 06:42:46 +02:00 committed by GitHub
parent 4db075ffc1
commit 120f0262f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 248 additions and 72 deletions

View File

@ -81,6 +81,7 @@ class Sanic:
self.sock = None self.sock = None
self.strict_slashes = strict_slashes self.strict_slashes = strict_slashes
self.listeners = defaultdict(list) self.listeners = defaultdict(list)
self.is_stopping = False
self.is_running = False self.is_running = False
self.is_request_stream = False self.is_request_stream = False
self.websocket_enabled = False self.websocket_enabled = False
@ -1177,6 +1178,7 @@ class Sanic:
try: try:
self.is_running = True self.is_running = True
self.is_stopping = False
if workers > 1 and os.name != "posix": if workers > 1 and os.name != "posix":
logger.warn( logger.warn(
f"Multiprocessing is currently not supported on {os.name}," f"Multiprocessing is currently not supported on {os.name},"
@ -1209,7 +1211,9 @@ class Sanic:
def stop(self): def stop(self):
"""This kills the Sanic""" """This kills the Sanic"""
get_event_loop().stop() if not self.is_stopping:
self.is_stopping = True
get_event_loop().stop()
async def create_server( async def create_server(
self, self,

View File

@ -1,3 +1,6 @@
import asyncio
import signal
from sys import argv from sys import argv
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
@ -23,3 +26,27 @@ else:
async def open_async(file, mode="r", **kwargs): async def open_async(file, mode="r", **kwargs):
return aio_open(file, mode, **kwargs) return aio_open(file, mode, **kwargs)
def ctrlc_workaround_for_windows(app):
async def stay_active(app):
"""Asyncio wakeups to allow receiving SIGINT in Python"""
while not die:
# If someone else stopped the app, just exit
if app.is_stopping:
return
# Windows Python blocks signal handlers while the event loop is
# waiting for I/O. Frequent wakeups keep interrupts flowing.
await asyncio.sleep(0.1)
# Can't be called from signal handler, so call it from here
app.stop()
def ctrlc_handler(sig, frame):
nonlocal die
if die:
raise KeyboardInterrupt("Non-graceful Ctrl+C")
die = True
die = False
signal.signal(signal.SIGINT, ctrlc_handler)
app.add_task(stay_active)

View File

@ -91,6 +91,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self.headers = Header(headers or {}) self.headers = Header(headers or {})
self.chunked = chunked self.chunked = chunked
self._cookies = None self._cookies = None
self.protocol = None
async def write(self, data): async def write(self, data):
"""Writes a chunk of data to the streaming response. """Writes a chunk of data to the streaming response.

View File

@ -15,7 +15,7 @@ from time import time
from httptools import HttpRequestParser # type: ignore from httptools import HttpRequestParser # type: ignore
from httptools.parser.errors import HttpParserError # type: ignore from httptools.parser.errors import HttpParserError # type: ignore
from sanic.compat import Header from sanic.compat import Header, ctrlc_workaround_for_windows
from sanic.exceptions import ( from sanic.exceptions import (
HeaderExpectationFailed, HeaderExpectationFailed,
InvalidUsage, InvalidUsage,
@ -37,6 +37,8 @@ try:
except ImportError: except ImportError:
pass pass
OS_IS_WINDOWS = os.name == "nt"
class Signal: class Signal:
stopped = False stopped = False
@ -929,15 +931,11 @@ def serve(
# Register signals for graceful termination # Register signals for graceful termination
if register_sys_signals: if register_sys_signals:
_singals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM) if OS_IS_WINDOWS:
for _signal in _singals: ctrlc_workaround_for_windows(app)
try: else:
loop.add_signal_handler(_signal, loop.stop) for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
except NotImplementedError: loop.add_signal_handler(_signal, app.stop)
logger.warning(
"Sanic tried to use loop.add_signal_handler "
"but it is not implemented on this platform."
)
pid = os.getpid() pid = os.getpid()
try: try:
logger.info("Starting worker [%s]", pid) logger.info("Starting worker [%s]", pid)

View File

@ -12,7 +12,7 @@ from sanic.response import text
ASGI_HOST = "mockserver" ASGI_HOST = "mockserver"
HOST = "127.0.0.1" HOST = "127.0.0.1"
PORT = 42101 PORT = None
class SanicTestClient: class SanicTestClient:
@ -95,7 +95,7 @@ class SanicTestClient:
if self.port: if self.port:
server_kwargs = dict( server_kwargs = dict(
host=host or self.host, port=self.port, **server_kwargs host=host or self.host, port=self.port, **server_kwargs,
) )
host, port = host or self.host, self.port host, port = host or self.host, self.port
else: else:
@ -103,6 +103,7 @@ class SanicTestClient:
sock.bind((host or self.host, 0)) sock.bind((host or self.host, 0))
server_kwargs = dict(sock=sock, **server_kwargs) server_kwargs = dict(sock=sock, **server_kwargs)
host, port = sock.getsockname() host, port = sock.getsockname()
self.port = port
if uri.startswith( if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
@ -114,6 +115,9 @@ class SanicTestClient:
url = "{scheme}://{host}:{port}{uri}".format( url = "{scheme}://{host}:{port}{uri}".format(
scheme=scheme, host=host, port=port, uri=uri scheme=scheme, host=host, port=port, uri=uri
) )
# Tests construct URLs using PORT = None, which means random port not
# known until this function is called, so fix that here
url = url.replace(":None/", f":{port}/")
@self.app.listener("after_server_start") @self.app.listener("after_server_start")
async def _collect_response(sanic, loop): async def _collect_response(sanic, loop):
@ -203,7 +207,7 @@ class SanicASGITestClient(httpx.AsyncClient):
self.app = app self.app = app
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT)) dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0))
super().__init__(dispatch=dispatch, base_url=base_url) super().__init__(dispatch=dispatch, base_url=base_url)
self.last_request = None self.last_request = None

View File

@ -6,6 +6,7 @@ from inspect import isawaitable
import pytest import pytest
from sanic import Sanic
from sanic.exceptions import SanicException from sanic.exceptions import SanicException
from sanic.response import text from sanic.response import text
@ -48,6 +49,7 @@ def test_asyncio_server_no_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server( asyncio_srv_coro = app.create_server(
port=43123,
return_asyncio_server=True, return_asyncio_server=True,
asyncio_server_kwargs=dict(start_serving=False), asyncio_server_kwargs=dict(start_serving=False),
) )
@ -61,6 +63,7 @@ def test_asyncio_server_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server( asyncio_srv_coro = app.create_server(
port=43124,
return_asyncio_server=True, return_asyncio_server=True,
asyncio_server_kwargs=dict(start_serving=False), asyncio_server_kwargs=dict(start_serving=False),
) )
@ -199,10 +202,17 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
port = request.server_port
assert port > 0
assert response.status == 500 assert response.status == 500
assert "Mock SanicException" in response.text assert "Mock SanicException" in response.text
assert ( assert (
"sanic.root", "sanic.root",
logging.ERROR, logging.ERROR,
"Exception occurred while handling uri: 'http://127.0.0.1:42101/'", f"Exception occurred while handling uri: 'http://127.0.0.1:{port}/'",
) in caplog.record_tuples ) in caplog.record_tuples
def test_app_name_required():
with pytest.deprecated_call():
Sanic()

View File

@ -221,7 +221,7 @@ async def test_request_class_custom():
class MyCustomRequest(Request): class MyCustomRequest(Request):
pass pass
app = Sanic(request_class=MyCustomRequest) app = Sanic(name=__name__, request_class=MyCustomRequest)
@app.get("/custom") @app.get("/custom")
def custom_request(request): def custom_request(request):

View File

@ -44,42 +44,42 @@ def test_load_from_object_string_exception(app):
def test_auto_load_env(): def test_auto_load_env():
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic() app = Sanic(name=__name__)
assert app.config.TEST_ANSWER == 42 assert app.config.TEST_ANSWER == 42
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_auto_load_bool_env(): def test_auto_load_bool_env():
environ["SANIC_TEST_ANSWER"] = "True" environ["SANIC_TEST_ANSWER"] = "True"
app = Sanic() app = Sanic(name=__name__)
assert app.config.TEST_ANSWER == True assert app.config.TEST_ANSWER == True
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_dont_load_env(): def test_dont_load_env():
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(load_env=False) app = Sanic(name=__name__, load_env=False)
assert getattr(app.config, "TEST_ANSWER", None) is None assert getattr(app.config, "TEST_ANSWER", None) is None
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_load_env_prefix(): def test_load_env_prefix():
environ["MYAPP_TEST_ANSWER"] = "42" environ["MYAPP_TEST_ANSWER"] = "42"
app = Sanic(load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_ANSWER == 42 assert app.config.TEST_ANSWER == 42
del environ["MYAPP_TEST_ANSWER"] del environ["MYAPP_TEST_ANSWER"]
def test_load_env_prefix_float_values(): def test_load_env_prefix_float_values():
environ["MYAPP_TEST_ROI"] = "2.3" environ["MYAPP_TEST_ROI"] = "2.3"
app = Sanic(load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_ROI == 2.3 assert app.config.TEST_ROI == 2.3
del environ["MYAPP_TEST_ROI"] del environ["MYAPP_TEST_ROI"]
def test_load_env_prefix_string_value(): def test_load_env_prefix_string_value():
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken" environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
app = Sanic(load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_TOKEN == "somerandomtesttoken" assert app.config.TEST_TOKEN == "somerandomtesttoken"
del environ["MYAPP_TEST_TOKEN"] del environ["MYAPP_TEST_TOKEN"]

View File

@ -20,7 +20,7 @@ class CustomRequest(Request):
def test_custom_request(): def test_custom_request():
app = Sanic(request_class=CustomRequest) app = Sanic(name=__name__, request_class=CustomRequest)
@app.route("/post", methods=["POST"]) @app.route("/post", methods=["POST"])
async def post_handler(request): async def post_handler(request):

View File

@ -7,12 +7,12 @@ import httpx
from sanic import Sanic, server from sanic import Sanic, server
from sanic.response import text from sanic.response import text
from sanic.testing import HOST, PORT, SanicTestClient from sanic.testing import HOST, SanicTestClient
CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
old_conn = None old_conn = None
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
class ReusableSanicConnectionPool( class ReusableSanicConnectionPool(

View File

@ -1,4 +1,5 @@
import logging import logging
import os
import uuid import uuid
from importlib import reload from importlib import reload
@ -12,6 +13,7 @@ import sanic
from sanic import Sanic from sanic import Sanic
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.response import text from sanic.response import text
from sanic.testing import SanicTestClient
logging_format = """module: %(module)s; \ logging_format = """module: %(module)s; \
@ -127,7 +129,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
def test_logger(caplog): def test_logger(caplog):
rand_string = str(uuid.uuid4()) rand_string = str(uuid.uuid4())
app = Sanic() app = Sanic(name=__name__)
@app.get("/") @app.get("/")
def log_info(request): def log_info(request):
@ -137,15 +139,67 @@ def test_logger(caplog):
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
port = request.server_port
# Note: testing with random port doesn't show the banner because it doesn't
# define host and port. This test supports both modes.
if caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ http://127.0.0.1:{port}",
):
caplog.record_tuples.pop(0)
assert caplog.record_tuples[0] == ( assert caplog.record_tuples[0] == (
"sanic.root", "sanic.root",
logging.INFO, logging.INFO,
"Goin' Fast @ http://127.0.0.1:42101", f"http://127.0.0.1:{port}/",
)
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == (
"sanic.root",
logging.INFO,
"Server Stopped",
)
def test_logger_static_and_secure(caplog):
# Same as test_logger, except for more coverage:
# - test_client initialised separately for static port
# - using ssl
rand_string = str(uuid.uuid4())
app = Sanic(name=__name__)
@app.get("/")
def log_info(request):
logger.info(rand_string)
return text("hello")
current_dir = os.path.dirname(os.path.realpath(__file__))
ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert")
ssl_key = os.path.join(current_dir, "certs/selfsigned.key")
ssl_dict = {"cert": ssl_cert, "key": ssl_key}
test_client = SanicTestClient(app, port=42101)
with caplog.at_level(logging.INFO):
request, response = test_client.get(
f"https://127.0.0.1:{test_client.port}/",
server_kwargs=dict(ssl=ssl_dict),
)
port = test_client.port
assert caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ https://127.0.0.1:{port}",
) )
assert caplog.record_tuples[1] == ( assert caplog.record_tuples[1] == (
"sanic.root", "sanic.root",
logging.INFO, logging.INFO,
"http://127.0.0.1:42101/", f"https://127.0.0.1:{port}/",
) )
assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string) assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == ( assert caplog.record_tuples[-1] == (

View File

@ -49,10 +49,10 @@ def test_logo_false(app, caplog):
loop.run_until_complete(_server.wait_closed()) loop.run_until_complete(_server.wait_closed())
app.stop() app.stop()
banner, port = caplog.record_tuples[ROW][2].rsplit(":", 1)
assert caplog.record_tuples[ROW][1] == logging.INFO assert caplog.record_tuples[ROW][1] == logging.INFO
assert caplog.record_tuples[ROW][ assert banner == "Goin' Fast @ http://127.0.0.1"
2 assert int(port) > 0
] == f"Goin' Fast @ http://127.0.0.1:{PORT}"
def test_logo_true(app, caplog): def test_logo_true(app, caplog):

View File

@ -8,7 +8,7 @@ from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator from sanic.views import stream as stream_decorator
data = "abc" * 10000000 data = "abc" * 1_000_000
def test_request_stream_method_view(app): def test_request_stream_method_view(app):

View File

@ -12,7 +12,7 @@ from sanic import Blueprint, Sanic
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters
from sanic.response import html, json, text from sanic.response import html, json, text
from sanic.testing import ASGI_HOST, HOST, PORT from sanic.testing import ASGI_HOST, HOST, PORT, SanicTestClient
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
@ -1029,7 +1029,7 @@ def test_url_attributes_no_ssl(app, path, query, expected_url):
app.add_route(handler, path) app.add_route(handler, path)
request, response = app.test_client.get(path + f"?{query}") request, response = app.test_client.get(path + f"?{query}")
assert request.url == expected_url.format(HOST, PORT) assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url) parsed = urlparse(request.url)
@ -1086,11 +1086,12 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url):
app.add_route(handler, path) app.add_route(handler, path)
port = app.test_client.port
request, response = app.test_client.get( request, response = app.test_client.get(
f"https://{HOST}:{PORT}" + path + f"?{query}", f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": context}, server_kwargs={"ssl": context},
) )
assert request.url == expected_url.format(HOST, PORT) assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url) parsed = urlparse(request.url)
@ -1125,7 +1126,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
f"https://{HOST}:{PORT}" + path + f"?{query}", f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": ssl_dict}, server_kwargs={"ssl": ssl_dict},
) )
assert request.url == expected_url.format(HOST, PORT) assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url) parsed = urlparse(request.url)
@ -1917,8 +1918,9 @@ def test_request_server_port(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={"Host": "my-server"}) test_client = SanicTestClient(app)
assert request.server_port == app.test_client.port request, response = test_client.get("/", headers={"Host": "my-server"})
assert request.server_port == test_client.port
def test_request_server_port_in_host_header(app): def test_request_server_port_in_host_header(app):
@ -1939,7 +1941,10 @@ def test_request_server_port_in_host_header(app):
request, response = app.test_client.get( request, response = app.test_client.get(
"/", headers={"Host": "mal_formed:5555"} "/", headers={"Host": "mal_formed:5555"}
) )
assert request.server_port == app.test_client.port if PORT is None:
assert request.server_port != 5555
else:
assert request.server_port == app.test_client.port
def test_request_server_port_forwarded(app): def test_request_server_port_forwarded(app):
@ -1979,7 +1984,7 @@ def test_server_name_and_url_for(app):
request, response = app.test_client.get("/foo") request, response = app.test_client.get("/foo")
assert ( assert (
request.url_for("handler") request.url_for("handler")
== f"http://my-server:{app.test_client.port}/foo" == f"http://my-server:{request.server_port}/foo"
) )
app.config.SERVER_NAME = "https://my-server/path" app.config.SERVER_NAME = "https://my-server/path"
@ -2040,7 +2045,7 @@ async def test_request_form_invalid_content_type_asgi(app):
def test_endpoint_basic(): def test_endpoint_basic():
app = Sanic() app = Sanic(name=__name__)
@app.route("/") @app.route("/")
def my_unique_handler(request): def my_unique_handler(request):
@ -2053,7 +2058,7 @@ def test_endpoint_basic():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_endpoint_basic_asgi(): async def test_endpoint_basic_asgi():
app = Sanic() app = Sanic(name=__name__)
@app.route("/") @app.route("/")
def my_unique_handler(request): def my_unique_handler(request):
@ -2132,5 +2137,5 @@ def test_url_for_without_server_name(app):
request, response = app.test_client.get("/sample") request, response = app.test_client.get("/sample")
assert ( assert (
response.json["url"] response.json["url"]
== f"http://127.0.0.1:{app.test_client.port}/url-for" == f"http://127.0.0.1:{request.server_port}/url-for"
) )

View File

@ -6,6 +6,7 @@ from sanic import Sanic
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.response import json, text from sanic.response import json, text
from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists
from sanic.testing import SanicTestClient
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
@ -167,35 +168,36 @@ def test_route_optional_slash(app):
def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): def test_route_strict_slashes_set_to_false_and_host_is_a_list(app):
# Part of regression test for issue #1120 # Part of regression test for issue #1120
site1 = f"127.0.0.1:{app.test_client.port}" test_client = SanicTestClient(app, port=42101)
site1 = f"127.0.0.1:{test_client.port}"
# before fix, this raises a RouteExists error # before fix, this raises a RouteExists error
@app.get("/get", host=[site1, "site2.com"], strict_slashes=False) @app.get("/get", host=[site1, "site2.com"], strict_slashes=False)
def get_handler(request): def get_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("http://" + site1 + "/get") request, response = test_client.get("http://" + site1 + "/get")
assert response.text == "OK" assert response.text == "OK"
@app.post("/post", host=[site1, "site2.com"], strict_slashes=False) @app.post("/post", host=[site1, "site2.com"], strict_slashes=False)
def post_handler(request): def post_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.post("http://" + site1 + "/post") request, response = test_client.post("http://" + site1 + "/post")
assert response.text == "OK" assert response.text == "OK"
@app.put("/put", host=[site1, "site2.com"], strict_slashes=False) @app.put("/put", host=[site1, "site2.com"], strict_slashes=False)
def put_handler(request): def put_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.put("http://" + site1 + "/put") request, response = test_client.put("http://" + site1 + "/put")
assert response.text == "OK" assert response.text == "OK"
@app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False) @app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False)
def delete_handler(request): def delete_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.delete("http://" + site1 + "/delete") request, response = test_client.delete("http://" + site1 + "/delete")
assert response.text == "OK" assert response.text == "OK"

View File

@ -1,6 +1,9 @@
import asyncio import asyncio
import signal import signal
from contextlib import closing
from socket import socket
import pytest import pytest
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
@ -118,25 +121,30 @@ def test_create_server_trigger_events(app):
app.listener("after_server_stop")(after_stop) app.listener("after_server_stop")(after_stop)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
serv_coro = app.create_server(return_asyncio_server=True)
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task)
server.after_start()
try:
loop.run_forever()
except KeyboardInterrupt as e:
loop.stop()
finally:
# Run the on_stop function if provided
server.before_stop()
# Wait for server to close # Use random port for tests
close_task = server.close() with closing(socket()) as sock:
loop.run_until_complete(close_task) sock.bind(("127.0.0.1", 0))
# Complete all tasks on the loop serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
signal.stopped = True serv_task = asyncio.ensure_future(serv_coro, loop=loop)
for connection in server.connections: server = loop.run_until_complete(serv_task)
connection.close_if_idle() server.after_start()
server.after_stop() try:
assert flag1 and flag2 and flag3 loop.run_forever()
except KeyboardInterrupt as e:
loop.stop()
finally:
# Run the on_stop function if provided
server.before_stop()
# Wait for server to close
close_task = server.close()
loop.run_until_complete(close_task)
# Complete all tasks on the loop
signal.stopped = True
for connection in server.connections:
connection.close_if_idle()
server.after_stop()
assert flag1 and flag2 and flag3

View File

@ -1,8 +1,13 @@
import asyncio import asyncio
import os
import signal
from queue import Queue from queue import Queue
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from sanic.compat import ctrlc_workaround_for_windows
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
@ -16,13 +21,21 @@ calledq = Queue()
def set_loop(app, loop): def set_loop(app, loop):
loop.add_signal_handler = MagicMock() global mock
mock = MagicMock()
if os.name == "nt":
signal.signal = mock
else:
loop.add_signal_handler = mock
def after(app, loop): def after(app, loop):
calledq.put(loop.add_signal_handler.called) calledq.put(mock.called)
@pytest.mark.skipif(
os.name == "nt", reason="May hang CI on py38/windows"
)
def test_register_system_signals(app): def test_register_system_signals(app):
"""Test if sanic register system signals""" """Test if sanic register system signals"""
@ -38,6 +51,9 @@ def test_register_system_signals(app):
assert calledq.get() is True assert calledq.get() is True
@pytest.mark.skipif(
os.name == "nt", reason="May hang CI on py38/windows"
)
def test_dont_register_system_signals(app): def test_dont_register_system_signals(app):
"""Test if sanic don't register system signals""" """Test if sanic don't register system signals"""
@ -51,3 +67,49 @@ def test_dont_register_system_signals(app):
app.run(HOST, PORT, register_sys_signals=False) app.run(HOST, PORT, register_sys_signals=False)
assert calledq.get() is False assert calledq.get() is False
@pytest.mark.skipif(
os.name == "nt", reason="windows cannot SIGINT processes"
)
def test_windows_workaround():
"""Test Windows workaround (on any other OS)"""
# At least some code coverage, even though this test doesn't work on
# Windows...
class MockApp:
def __init__(self):
self.is_stopping = False
def stop(self):
assert not self.is_stopping
self.is_stopping = True
def add_task(self, func):
loop = asyncio.get_event_loop()
self.stay_active_task = loop.create_task(func(self))
async def atest(stop_first):
app = MockApp()
ctrlc_workaround_for_windows(app)
await asyncio.sleep(0.05)
if stop_first:
app.stop()
await asyncio.sleep(0.2)
assert app.is_stopping == stop_first
# First Ctrl+C: should call app.stop() within 0.1 seconds
os.kill(os.getpid(), signal.SIGINT)
await asyncio.sleep(0.2)
assert app.is_stopping
assert app.stay_active_task.result() == None
# Second Ctrl+C should raise
with pytest.raises(KeyboardInterrupt):
os.kill(os.getpid(), signal.SIGINT)
return "OK"
# Run in our private loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
res = loop.run_until_complete(atest(False))
assert res == "OK"
res = loop.run_until_complete(atest(True))
assert res == "OK"

View File

@ -27,7 +27,8 @@ def test_test_client_port_default(app):
return json(request.transport.get_extra_info("sockname")[1]) return json(request.transport.get_extra_info("sockname")[1])
test_client = SanicTestClient(app) test_client = SanicTestClient(app)
assert test_client.port == PORT assert test_client.port == PORT # Can be None before request
request, response = test_client.get("/get") request, response = test_client.get("/get")
assert response.json == PORT assert test_client.port > 0
assert response.json == test_client.port