diff --git a/sanic/app.py b/sanic/app.py index 3b93b374..60f845b6 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -81,6 +81,7 @@ class Sanic: self.sock = None self.strict_slashes = strict_slashes self.listeners = defaultdict(list) + self.is_stopping = False self.is_running = False self.is_request_stream = False self.websocket_enabled = False @@ -1177,6 +1178,7 @@ class Sanic: try: self.is_running = True + self.is_stopping = False if workers > 1 and os.name != "posix": logger.warn( f"Multiprocessing is currently not supported on {os.name}," @@ -1209,7 +1211,9 @@ class Sanic: def stop(self): """This kills the Sanic""" - get_event_loop().stop() + if not self.is_stopping: + self.is_stopping = True + get_event_loop().stop() async def create_server( self, diff --git a/sanic/compat.py b/sanic/compat.py index ebf25bb0..28c91b97 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -1,3 +1,6 @@ +import asyncio +import signal + from sys import argv from multidict import CIMultiDict # type: ignore @@ -23,3 +26,27 @@ else: async def open_async(file, mode="r", **kwargs): return aio_open(file, mode, **kwargs) + + +def ctrlc_workaround_for_windows(app): + async def stay_active(app): + """Asyncio wakeups to allow receiving SIGINT in Python""" + while not die: + # If someone else stopped the app, just exit + if app.is_stopping: + return + # Windows Python blocks signal handlers while the event loop is + # waiting for I/O. Frequent wakeups keep interrupts flowing. + await asyncio.sleep(0.1) + # Can't be called from signal handler, so call it from here + app.stop() + + def ctrlc_handler(sig, frame): + nonlocal die + if die: + raise KeyboardInterrupt("Non-graceful Ctrl+C") + die = True + + die = False + signal.signal(signal.SIGINT, ctrlc_handler) + app.add_task(stay_active) diff --git a/sanic/response.py b/sanic/response.py index 9e1a4437..7b7c8521 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -91,6 +91,7 @@ class StreamingHTTPResponse(BaseHTTPResponse): self.headers = Header(headers or {}) self.chunked = chunked self._cookies = None + self.protocol = None async def write(self, data): """Writes a chunk of data to the streaming response. diff --git a/sanic/server.py b/sanic/server.py index 4251d674..31dedd40 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -15,7 +15,7 @@ from time import time from httptools import HttpRequestParser # type: ignore from httptools.parser.errors import HttpParserError # type: ignore -from sanic.compat import Header +from sanic.compat import Header, ctrlc_workaround_for_windows from sanic.exceptions import ( HeaderExpectationFailed, InvalidUsage, @@ -37,6 +37,8 @@ try: except ImportError: pass +OS_IS_WINDOWS = os.name == "nt" + class Signal: stopped = False @@ -929,15 +931,11 @@ def serve( # Register signals for graceful termination if register_sys_signals: - _singals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM) - for _signal in _singals: - try: - loop.add_signal_handler(_signal, loop.stop) - except NotImplementedError: - logger.warning( - "Sanic tried to use loop.add_signal_handler " - "but it is not implemented on this platform." - ) + if OS_IS_WINDOWS: + ctrlc_workaround_for_windows(app) + else: + for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: + loop.add_signal_handler(_signal, app.stop) pid = os.getpid() try: logger.info("Starting worker [%s]", pid) diff --git a/sanic/testing.py b/sanic/testing.py index 98ff29e6..41985041 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -12,7 +12,7 @@ from sanic.response import text ASGI_HOST = "mockserver" HOST = "127.0.0.1" -PORT = 42101 +PORT = None class SanicTestClient: @@ -95,7 +95,7 @@ class SanicTestClient: if self.port: server_kwargs = dict( - host=host or self.host, port=self.port, **server_kwargs + host=host or self.host, port=self.port, **server_kwargs, ) host, port = host or self.host, self.port else: @@ -103,6 +103,7 @@ class SanicTestClient: sock.bind((host or self.host, 0)) server_kwargs = dict(sock=sock, **server_kwargs) host, port = sock.getsockname() + self.port = port if uri.startswith( ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") @@ -114,6 +115,9 @@ class SanicTestClient: url = "{scheme}://{host}:{port}{uri}".format( scheme=scheme, host=host, port=port, uri=uri ) + # Tests construct URLs using PORT = None, which means random port not + # known until this function is called, so fix that here + url = url.replace(":None/", f":{port}/") @self.app.listener("after_server_start") async def _collect_response(sanic, loop): @@ -203,7 +207,7 @@ class SanicASGITestClient(httpx.AsyncClient): self.app = app - dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT)) + dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0)) super().__init__(dispatch=dispatch, base_url=base_url) self.last_request = None diff --git a/tests/test_app.py b/tests/test_app.py index f639d9f3..c9cb8329 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -6,6 +6,7 @@ from inspect import isawaitable import pytest +from sanic import Sanic from sanic.exceptions import SanicException from sanic.response import text @@ -48,6 +49,7 @@ def test_asyncio_server_no_start_serving(app): if not uvloop_installed(): loop = asyncio.get_event_loop() asyncio_srv_coro = app.create_server( + port=43123, return_asyncio_server=True, asyncio_server_kwargs=dict(start_serving=False), ) @@ -61,6 +63,7 @@ def test_asyncio_server_start_serving(app): if not uvloop_installed(): loop = asyncio.get_event_loop() asyncio_srv_coro = app.create_server( + port=43124, return_asyncio_server=True, asyncio_server_kwargs=dict(start_serving=False), ) @@ -199,10 +202,17 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog): with caplog.at_level(logging.ERROR): request, response = app.test_client.get("/") + port = request.server_port + assert port > 0 assert response.status == 500 assert "Mock SanicException" in response.text assert ( "sanic.root", logging.ERROR, - "Exception occurred while handling uri: 'http://127.0.0.1:42101/'", + f"Exception occurred while handling uri: 'http://127.0.0.1:{port}/'", ) in caplog.record_tuples + + +def test_app_name_required(): + with pytest.deprecated_call(): + Sanic() diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 4b9ed58b..2bf7edb1 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -221,7 +221,7 @@ async def test_request_class_custom(): class MyCustomRequest(Request): pass - app = Sanic(request_class=MyCustomRequest) + app = Sanic(name=__name__, request_class=MyCustomRequest) @app.get("/custom") def custom_request(request): diff --git a/tests/test_config.py b/tests/test_config.py index 7d2a8395..7c232d75 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -44,42 +44,42 @@ def test_load_from_object_string_exception(app): def test_auto_load_env(): environ["SANIC_TEST_ANSWER"] = "42" - app = Sanic() + app = Sanic(name=__name__) assert app.config.TEST_ANSWER == 42 del environ["SANIC_TEST_ANSWER"] def test_auto_load_bool_env(): environ["SANIC_TEST_ANSWER"] = "True" - app = Sanic() + app = Sanic(name=__name__) assert app.config.TEST_ANSWER == True del environ["SANIC_TEST_ANSWER"] def test_dont_load_env(): environ["SANIC_TEST_ANSWER"] = "42" - app = Sanic(load_env=False) + app = Sanic(name=__name__, load_env=False) assert getattr(app.config, "TEST_ANSWER", None) is None del environ["SANIC_TEST_ANSWER"] def test_load_env_prefix(): environ["MYAPP_TEST_ANSWER"] = "42" - app = Sanic(load_env="MYAPP_") + app = Sanic(name=__name__, load_env="MYAPP_") assert app.config.TEST_ANSWER == 42 del environ["MYAPP_TEST_ANSWER"] def test_load_env_prefix_float_values(): environ["MYAPP_TEST_ROI"] = "2.3" - app = Sanic(load_env="MYAPP_") + app = Sanic(name=__name__, load_env="MYAPP_") assert app.config.TEST_ROI == 2.3 del environ["MYAPP_TEST_ROI"] def test_load_env_prefix_string_value(): environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken" - app = Sanic(load_env="MYAPP_") + app = Sanic(name=__name__, load_env="MYAPP_") assert app.config.TEST_TOKEN == "somerandomtesttoken" del environ["MYAPP_TEST_TOKEN"] diff --git a/tests/test_custom_request.py b/tests/test_custom_request.py index d0ae48e7..54c32ff1 100644 --- a/tests/test_custom_request.py +++ b/tests/test_custom_request.py @@ -20,7 +20,7 @@ class CustomRequest(Request): def test_custom_request(): - app = Sanic(request_class=CustomRequest) + app = Sanic(name=__name__, request_class=CustomRequest) @app.route("/post", methods=["POST"]) async def post_handler(request): diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index 78f95655..8ad6bbc0 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -7,12 +7,12 @@ import httpx from sanic import Sanic, server from sanic.response import text -from sanic.testing import HOST, PORT, SanicTestClient - +from sanic.testing import HOST, SanicTestClient CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} old_conn = None +PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port class ReusableSanicConnectionPool( diff --git a/tests/test_logging.py b/tests/test_logging.py index 5a54b75a..faa83571 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,4 +1,5 @@ import logging +import os import uuid from importlib import reload @@ -12,6 +13,7 @@ import sanic from sanic import Sanic from sanic.log import LOGGING_CONFIG_DEFAULTS, logger from sanic.response import text +from sanic.testing import SanicTestClient logging_format = """module: %(module)s; \ @@ -127,7 +129,7 @@ def test_log_connection_lost(app, debug, monkeypatch): def test_logger(caplog): rand_string = str(uuid.uuid4()) - app = Sanic() + app = Sanic(name=__name__) @app.get("/") def log_info(request): @@ -137,15 +139,67 @@ def test_logger(caplog): with caplog.at_level(logging.INFO): request, response = app.test_client.get("/") + port = request.server_port + + # Note: testing with random port doesn't show the banner because it doesn't + # define host and port. This test supports both modes. + if caplog.record_tuples[0] == ( + "sanic.root", + logging.INFO, + f"Goin' Fast @ http://127.0.0.1:{port}", + ): + caplog.record_tuples.pop(0) + assert caplog.record_tuples[0] == ( "sanic.root", logging.INFO, - "Goin' Fast @ http://127.0.0.1:42101", + f"http://127.0.0.1:{port}/", + ) + assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string) + assert caplog.record_tuples[-1] == ( + "sanic.root", + logging.INFO, + "Server Stopped", + ) + + +def test_logger_static_and_secure(caplog): + # Same as test_logger, except for more coverage: + # - test_client initialised separately for static port + # - using ssl + rand_string = str(uuid.uuid4()) + + app = Sanic(name=__name__) + + @app.get("/") + def log_info(request): + logger.info(rand_string) + return text("hello") + + current_dir = os.path.dirname(os.path.realpath(__file__)) + ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert") + ssl_key = os.path.join(current_dir, "certs/selfsigned.key") + + ssl_dict = {"cert": ssl_cert, "key": ssl_key} + + test_client = SanicTestClient(app, port=42101) + with caplog.at_level(logging.INFO): + request, response = test_client.get( + f"https://127.0.0.1:{test_client.port}/", + server_kwargs=dict(ssl=ssl_dict), + ) + + port = test_client.port + + assert caplog.record_tuples[0] == ( + "sanic.root", + logging.INFO, + f"Goin' Fast @ https://127.0.0.1:{port}", ) assert caplog.record_tuples[1] == ( "sanic.root", logging.INFO, - "http://127.0.0.1:42101/", + f"https://127.0.0.1:{port}/", ) assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string) assert caplog.record_tuples[-1] == ( diff --git a/tests/test_logo.py b/tests/test_logo.py index 0c34cdb5..e8df2ea5 100644 --- a/tests/test_logo.py +++ b/tests/test_logo.py @@ -49,10 +49,10 @@ def test_logo_false(app, caplog): loop.run_until_complete(_server.wait_closed()) app.stop() + banner, port = caplog.record_tuples[ROW][2].rsplit(":", 1) assert caplog.record_tuples[ROW][1] == logging.INFO - assert caplog.record_tuples[ROW][ - 2 - ] == f"Goin' Fast @ http://127.0.0.1:{PORT}" + assert banner == "Goin' Fast @ http://127.0.0.1" + assert int(port) > 0 def test_logo_true(app, caplog): diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 17b0ac82..3fbf3c3c 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -8,7 +8,7 @@ from sanic.views import CompositionView, HTTPMethodView from sanic.views import stream as stream_decorator -data = "abc" * 10000000 +data = "abc" * 1_000_000 def test_request_stream_method_view(app): diff --git a/tests/test_requests.py b/tests/test_requests.py index c1dbd1a3..a405eebc 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -12,7 +12,7 @@ from sanic import Blueprint, Sanic from sanic.exceptions import ServerError from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters from sanic.response import html, json, text -from sanic.testing import ASGI_HOST, HOST, PORT +from sanic.testing import ASGI_HOST, HOST, PORT, SanicTestClient # ------------------------------------------------------------ # @@ -1029,7 +1029,7 @@ def test_url_attributes_no_ssl(app, path, query, expected_url): app.add_route(handler, path) request, response = app.test_client.get(path + f"?{query}") - assert request.url == expected_url.format(HOST, PORT) + assert request.url == expected_url.format(HOST, request.server_port) parsed = urlparse(request.url) @@ -1086,11 +1086,12 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url): app.add_route(handler, path) + port = app.test_client.port request, response = app.test_client.get( f"https://{HOST}:{PORT}" + path + f"?{query}", server_kwargs={"ssl": context}, ) - assert request.url == expected_url.format(HOST, PORT) + assert request.url == expected_url.format(HOST, request.server_port) parsed = urlparse(request.url) @@ -1125,7 +1126,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url): f"https://{HOST}:{PORT}" + path + f"?{query}", server_kwargs={"ssl": ssl_dict}, ) - assert request.url == expected_url.format(HOST, PORT) + assert request.url == expected_url.format(HOST, request.server_port) parsed = urlparse(request.url) @@ -1917,8 +1918,9 @@ def test_request_server_port(app): def handler(request): return text("OK") - request, response = app.test_client.get("/", headers={"Host": "my-server"}) - assert request.server_port == app.test_client.port + test_client = SanicTestClient(app) + request, response = test_client.get("/", headers={"Host": "my-server"}) + assert request.server_port == test_client.port def test_request_server_port_in_host_header(app): @@ -1939,7 +1941,10 @@ def test_request_server_port_in_host_header(app): request, response = app.test_client.get( "/", headers={"Host": "mal_formed:5555"} ) - assert request.server_port == app.test_client.port + if PORT is None: + assert request.server_port != 5555 + else: + assert request.server_port == app.test_client.port def test_request_server_port_forwarded(app): @@ -1979,7 +1984,7 @@ def test_server_name_and_url_for(app): request, response = app.test_client.get("/foo") assert ( request.url_for("handler") - == f"http://my-server:{app.test_client.port}/foo" + == f"http://my-server:{request.server_port}/foo" ) app.config.SERVER_NAME = "https://my-server/path" @@ -2040,7 +2045,7 @@ async def test_request_form_invalid_content_type_asgi(app): def test_endpoint_basic(): - app = Sanic() + app = Sanic(name=__name__) @app.route("/") def my_unique_handler(request): @@ -2053,7 +2058,7 @@ def test_endpoint_basic(): @pytest.mark.asyncio async def test_endpoint_basic_asgi(): - app = Sanic() + app = Sanic(name=__name__) @app.route("/") def my_unique_handler(request): @@ -2132,5 +2137,5 @@ def test_url_for_without_server_name(app): request, response = app.test_client.get("/sample") assert ( response.json["url"] - == f"http://127.0.0.1:{app.test_client.port}/url-for" + == f"http://127.0.0.1:{request.server_port}/url-for" ) diff --git a/tests/test_routes.py b/tests/test_routes.py index cdb8d78d..7db136fd 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -6,6 +6,7 @@ from sanic import Sanic from sanic.constants import HTTP_METHODS from sanic.response import json, text from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists +from sanic.testing import SanicTestClient # ------------------------------------------------------------ # @@ -167,35 +168,36 @@ def test_route_optional_slash(app): def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): # Part of regression test for issue #1120 - site1 = f"127.0.0.1:{app.test_client.port}" + test_client = SanicTestClient(app, port=42101) + site1 = f"127.0.0.1:{test_client.port}" # before fix, this raises a RouteExists error @app.get("/get", host=[site1, "site2.com"], strict_slashes=False) def get_handler(request): return text("OK") - request, response = app.test_client.get("http://" + site1 + "/get") + request, response = test_client.get("http://" + site1 + "/get") assert response.text == "OK" @app.post("/post", host=[site1, "site2.com"], strict_slashes=False) def post_handler(request): return text("OK") - request, response = app.test_client.post("http://" + site1 + "/post") + request, response = test_client.post("http://" + site1 + "/post") assert response.text == "OK" @app.put("/put", host=[site1, "site2.com"], strict_slashes=False) def put_handler(request): return text("OK") - request, response = app.test_client.put("http://" + site1 + "/put") + request, response = test_client.put("http://" + site1 + "/put") assert response.text == "OK" @app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False) def delete_handler(request): return text("OK") - request, response = app.test_client.delete("http://" + site1 + "/delete") + request, response = test_client.delete("http://" + site1 + "/delete") assert response.text == "OK" diff --git a/tests/test_server_events.py b/tests/test_server_events.py index edc3d00d..560e9417 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -1,6 +1,9 @@ import asyncio import signal +from contextlib import closing +from socket import socket + import pytest from sanic.testing import HOST, PORT @@ -118,25 +121,30 @@ def test_create_server_trigger_events(app): app.listener("after_server_stop")(after_stop) loop = asyncio.get_event_loop() - serv_coro = app.create_server(return_asyncio_server=True) - serv_task = asyncio.ensure_future(serv_coro, loop=loop) - server = loop.run_until_complete(serv_task) - server.after_start() - try: - loop.run_forever() - except KeyboardInterrupt as e: - loop.stop() - finally: - # Run the on_stop function if provided - server.before_stop() - # Wait for server to close - close_task = server.close() - loop.run_until_complete(close_task) + # Use random port for tests + with closing(socket()) as sock: + sock.bind(("127.0.0.1", 0)) - # Complete all tasks on the loop - signal.stopped = True - for connection in server.connections: - connection.close_if_idle() - server.after_stop() - assert flag1 and flag2 and flag3 + serv_coro = app.create_server(return_asyncio_server=True, sock=sock) + serv_task = asyncio.ensure_future(serv_coro, loop=loop) + server = loop.run_until_complete(serv_task) + server.after_start() + try: + loop.run_forever() + except KeyboardInterrupt as e: + loop.stop() + finally: + # Run the on_stop function if provided + server.before_stop() + + # Wait for server to close + close_task = server.close() + loop.run_until_complete(close_task) + + # Complete all tasks on the loop + signal.stopped = True + for connection in server.connections: + connection.close_if_idle() + server.after_stop() + assert flag1 and flag2 and flag3 diff --git a/tests/test_signal_handlers.py b/tests/test_signal_handlers.py index 262f41cb..6d2b15fe 100644 --- a/tests/test_signal_handlers.py +++ b/tests/test_signal_handlers.py @@ -1,8 +1,13 @@ import asyncio +import os +import signal from queue import Queue from unittest.mock import MagicMock +import pytest + +from sanic.compat import ctrlc_workaround_for_windows from sanic.response import HTTPResponse from sanic.testing import HOST, PORT @@ -16,13 +21,21 @@ calledq = Queue() def set_loop(app, loop): - loop.add_signal_handler = MagicMock() + global mock + mock = MagicMock() + if os.name == "nt": + signal.signal = mock + else: + loop.add_signal_handler = mock def after(app, loop): - calledq.put(loop.add_signal_handler.called) + calledq.put(mock.called) +@pytest.mark.skipif( + os.name == "nt", reason="May hang CI on py38/windows" +) def test_register_system_signals(app): """Test if sanic register system signals""" @@ -38,6 +51,9 @@ def test_register_system_signals(app): assert calledq.get() is True +@pytest.mark.skipif( + os.name == "nt", reason="May hang CI on py38/windows" +) def test_dont_register_system_signals(app): """Test if sanic don't register system signals""" @@ -51,3 +67,49 @@ def test_dont_register_system_signals(app): app.run(HOST, PORT, register_sys_signals=False) assert calledq.get() is False + + +@pytest.mark.skipif( + os.name == "nt", reason="windows cannot SIGINT processes" +) +def test_windows_workaround(): + """Test Windows workaround (on any other OS)""" + # At least some code coverage, even though this test doesn't work on + # Windows... + class MockApp: + def __init__(self): + self.is_stopping = False + + def stop(self): + assert not self.is_stopping + self.is_stopping = True + + def add_task(self, func): + loop = asyncio.get_event_loop() + self.stay_active_task = loop.create_task(func(self)) + + async def atest(stop_first): + app = MockApp() + ctrlc_workaround_for_windows(app) + await asyncio.sleep(0.05) + if stop_first: + app.stop() + await asyncio.sleep(0.2) + assert app.is_stopping == stop_first + # First Ctrl+C: should call app.stop() within 0.1 seconds + os.kill(os.getpid(), signal.SIGINT) + await asyncio.sleep(0.2) + assert app.is_stopping + assert app.stay_active_task.result() == None + # Second Ctrl+C should raise + with pytest.raises(KeyboardInterrupt): + os.kill(os.getpid(), signal.SIGINT) + return "OK" + + # Run in our private loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + res = loop.run_until_complete(atest(False)) + assert res == "OK" + res = loop.run_until_complete(atest(True)) + assert res == "OK" diff --git a/tests/test_test_client_port.py b/tests/test_test_client_port.py index 1b4f300b..2940ba0d 100644 --- a/tests/test_test_client_port.py +++ b/tests/test_test_client_port.py @@ -27,7 +27,8 @@ def test_test_client_port_default(app): return json(request.transport.get_extra_info("sockname")[1]) test_client = SanicTestClient(app) - assert test_client.port == PORT + assert test_client.port == PORT # Can be None before request request, response = test_client.get("/get") - assert response.json == PORT + assert test_client.port > 0 + assert response.json == test_client.port