diff --git a/sanic/testing.py b/sanic/testing.py index 3a1d15c5..a05c9ea7 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,12 +1,30 @@ +import sys import traceback from json import JSONDecodeError from sanic.log import logger from sanic.exceptions import MethodNotSupported from sanic.response import text +try: + try: + # direct use + import packaging as _packaging + version = _packaging.version + except (ImportError, AttributeError): + # setuptools v39.0 and above. + try: + from setuptools.extern import packaging as _packaging + except ImportError: + # Before setuptools v39.0 + from pkg_resources.extern import packaging as _packaging + version = _packaging.version +except ImportError: + raise RuntimeError("The 'packaging' library is missing.") + HOST = '127.0.0.1' PORT = 42101 +is_windows = sys.platform in ['win32', 'cygwin'] class SanicTestClient: diff --git a/sanic/worker.py b/sanic/worker.py index d367a7c3..a5a39c82 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -12,199 +12,202 @@ except ImportError: try: import uvloop + import gunicorn.workers.base as base asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: + base = None pass -import gunicorn.workers.base as base from sanic.server import trigger_events, serve, HttpProtocol, Signal from sanic.websocket import WebSocketProtocol -class GunicornWorker(base.Worker): +# gunicorn is not available on windows +if base is not None: + class GunicornWorker(base.Worker): - http_protocol = HttpProtocol - websocket_protocol = WebSocketProtocol + http_protocol = HttpProtocol + websocket_protocol = WebSocketProtocol - def __init__(self, *args, **kw): # pragma: no cover - super().__init__(*args, **kw) - cfg = self.cfg - if cfg.is_ssl: - self.ssl_context = self._create_ssl_context(cfg) - else: - self.ssl_context = None - self.servers = {} - self.connections = set() - self.exit_code = 0 - self.signal = Signal() + def __init__(self, *args, **kw): # pragma: no cover + super().__init__(*args, **kw) + cfg = self.cfg + if cfg.is_ssl: + self.ssl_context = self._create_ssl_context(cfg) + else: + self.ssl_context = None + self.servers = {} + self.connections = set() + self.exit_code = 0 + self.signal = Signal() - def init_process(self): - # create new event_loop after fork - asyncio.get_event_loop().close() + def init_process(self): + # create new event_loop after fork + asyncio.get_event_loop().close() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) - super().init_process() + super().init_process() - def run(self): - is_debug = self.log.loglevel == logging.DEBUG - protocol = ( - self.websocket_protocol if self.app.callable.websocket_enabled - else self.http_protocol) - self._server_settings = self.app.callable._helper( - loop=self.loop, - debug=is_debug, - protocol=protocol, - ssl=self.ssl_context, - run_async=True) - self._server_settings['signal'] = self.signal - self._server_settings.pop('sock') - trigger_events(self._server_settings.get('before_start', []), - self.loop) - self._server_settings['before_start'] = () - - self._runner = asyncio.ensure_future(self._run(), loop=self.loop) - try: - self.loop.run_until_complete(self._runner) - self.app.callable.is_running = True - trigger_events(self._server_settings.get('after_start', []), + def run(self): + is_debug = self.log.loglevel == logging.DEBUG + protocol = ( + self.websocket_protocol if self.app.callable.websocket_enabled + else self.http_protocol) + self._server_settings = self.app.callable._helper( + loop=self.loop, + debug=is_debug, + protocol=protocol, + ssl=self.ssl_context, + run_async=True) + self._server_settings['signal'] = self.signal + self._server_settings.pop('sock') + trigger_events(self._server_settings.get('before_start', []), self.loop) - self.loop.run_until_complete(self._check_alive()) - trigger_events(self._server_settings.get('before_stop', []), - self.loop) - self.loop.run_until_complete(self.close()) - except BaseException: - traceback.print_exc() - finally: + self._server_settings['before_start'] = () + + self._runner = asyncio.ensure_future(self._run(), loop=self.loop) try: - trigger_events(self._server_settings.get('after_stop', []), + self.loop.run_until_complete(self._runner) + self.app.callable.is_running = True + trigger_events(self._server_settings.get('after_start', []), self.loop) + self.loop.run_until_complete(self._check_alive()) + trigger_events(self._server_settings.get('before_stop', []), + self.loop) + self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: - self.loop.close() + try: + trigger_events(self._server_settings.get('after_stop', []), + self.loop) + except BaseException: + traceback.print_exc() + finally: + self.loop.close() - sys.exit(self.exit_code) + sys.exit(self.exit_code) - async def close(self): - if self.servers: - # stop accepting connections - self.log.info("Stopping server: %s, connections: %s", - self.pid, len(self.connections)) - for server in self.servers: - server.close() - await server.wait_closed() - self.servers.clear() + async def close(self): + if self.servers: + # stop accepting connections + self.log.info("Stopping server: %s, connections: %s", + self.pid, len(self.connections)) + for server in self.servers: + server.close() + await server.wait_closed() + self.servers.clear() - # prepare connections for closing - self.signal.stopped = True - for conn in self.connections: - conn.close_if_idle() + # prepare connections for closing + self.signal.stopped = True + for conn in self.connections: + conn.close_if_idle() - # gracefully shutdown timeout - start_shutdown = 0 - graceful_shutdown_timeout = self.cfg.graceful_timeout - while self.connections and \ - (start_shutdown < graceful_shutdown_timeout): - await asyncio.sleep(0.1) - start_shutdown = start_shutdown + 0.1 + # gracefully shutdown timeout + start_shutdown = 0 + graceful_shutdown_timeout = self.cfg.graceful_timeout + while self.connections and \ + (start_shutdown < graceful_shutdown_timeout): + await asyncio.sleep(0.1) + start_shutdown = start_shutdown + 0.1 - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - coros = [] - for conn in self.connections: - if hasattr(conn, "websocket") and conn.websocket: - coros.append( - conn.websocket.close_connection() - ) - else: - conn.close() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + coros = [] + for conn in self.connections: + if hasattr(conn, "websocket") and conn.websocket: + coros.append( + conn.websocket.close_connection() + ) + else: + conn.close() + _shutdown = asyncio.gather(*coros, loop=self.loop) + await _shutdown - async def _run(self): - for sock in self.sockets: - state = dict(requests_count=0) - self._server_settings["host"] = None - self._server_settings["port"] = None - server = await serve( - sock=sock, - connections=self.connections, - state=state, - **self._server_settings - ) - self.servers[server] = state - - async def _check_alive(self): - # If our parent changed then we shut down. - pid = os.getpid() - try: - while self.alive: - self.notify() - - req_count = sum( - self.servers[srv]["requests_count"] for srv in self.servers + async def _run(self): + for sock in self.sockets: + state = dict(requests_count=0) + self._server_settings["host"] = None + self._server_settings["port"] = None + server = await serve( + sock=sock, + connections=self.connections, + state=state, + **self._server_settings ) - if self.max_requests and req_count > self.max_requests: - self.alive = False - self.log.info("Max requests exceeded, shutting down: %s", - self) - elif pid == os.getpid() and self.ppid != os.getppid(): - self.alive = False - self.log.info("Parent changed, shutting down: %s", self) - else: - await asyncio.sleep(1.0, loop=self.loop) - except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): - pass + self.servers[server] = state - @staticmethod - def _create_ssl_context(cfg): - """ Creates SSLContext instance for usage in asyncio.create_server. - See ssl.SSLSocket.__init__ for more details. - """ - ctx = ssl.SSLContext(cfg.ssl_version) - ctx.load_cert_chain(cfg.certfile, cfg.keyfile) - ctx.verify_mode = cfg.cert_reqs - if cfg.ca_certs: - ctx.load_verify_locations(cfg.ca_certs) - if cfg.ciphers: - ctx.set_ciphers(cfg.ciphers) - return ctx + async def _check_alive(self): + # If our parent changed then we shut down. + pid = os.getpid() + try: + while self.alive: + self.notify() - def init_signals(self): - # Set up signals through the event loop API. + req_count = sum( + self.servers[srv]["requests_count"] for srv in self.servers + ) + if self.max_requests and req_count > self.max_requests: + self.alive = False + self.log.info("Max requests exceeded, shutting down: %s", + self) + elif pid == os.getpid() and self.ppid != os.getppid(): + self.alive = False + self.log.info("Parent changed, shutting down: %s", self) + else: + await asyncio.sleep(1.0, loop=self.loop) + except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): + pass - self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, - signal.SIGQUIT, None) + @staticmethod + def _create_ssl_context(cfg): + """ Creates SSLContext instance for usage in asyncio.create_server. + See ssl.SSLSocket.__init__ for more details. + """ + ctx = ssl.SSLContext(cfg.ssl_version) + ctx.load_cert_chain(cfg.certfile, cfg.keyfile) + ctx.verify_mode = cfg.cert_reqs + if cfg.ca_certs: + ctx.load_verify_locations(cfg.ca_certs) + if cfg.ciphers: + ctx.set_ciphers(cfg.ciphers) + return ctx - self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, - signal.SIGTERM, None) + def init_signals(self): + # Set up signals through the event loop API. - self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, - signal.SIGINT, None) + self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, + signal.SIGQUIT, None) - self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, - signal.SIGWINCH, None) + self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, + signal.SIGTERM, None) - self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, - signal.SIGUSR1, None) + self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, + signal.SIGINT, None) - self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, - signal.SIGABRT, None) + self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, + signal.SIGWINCH, None) - # Don't let SIGTERM and SIGUSR1 disturb active requests - # by interrupting system calls - signal.siginterrupt(signal.SIGTERM, False) - signal.siginterrupt(signal.SIGUSR1, False) + self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, + signal.SIGUSR1, None) - def handle_quit(self, sig, frame): - self.alive = False - self.app.callable.is_running = False - self.cfg.worker_int(self) + self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, + signal.SIGABRT, None) - def handle_abort(self, sig, frame): - self.alive = False - self.exit_code = 1 - self.cfg.worker_abort(self) - sys.exit(1) + # Don't let SIGTERM and SIGUSR1 disturb active requests + # by interrupting system calls + signal.siginterrupt(signal.SIGTERM, False) + signal.siginterrupt(signal.SIGUSR1, False) + + def handle_quit(self, sig, frame): + self.alive = False + self.app.callable.is_running = False + self.cfg.worker_int(self) + + def handle_abort(self, sig, frame): + self.alive = False + self.exit_code = 1 + self.cfg.worker_abort(self) + sys.exit(1) diff --git a/tests/test_config.py b/tests/test_config.py index 3db35f4d..178bf12a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,20 @@ from os import environ +from pathlib import Path +from contextlib import contextmanager +from tempfile import TemporaryDirectory +from textwrap import dedent import pytest -from tempfile import NamedTemporaryFile from sanic import Sanic +@contextmanager +def temp_path(): + """ a simple cross platform replacement for NamedTemporaryFile """ + with TemporaryDirectory() as td: + yield Path(td, 'file') + + def test_load_from_object(app): class Config: not_for_config = 'should not be used' @@ -15,35 +25,38 @@ def test_load_from_object(app): assert app.config.CONFIG_VALUE == 'should be used' assert 'not_for_config' not in app.config + def test_auto_load_env(): environ["SANIC_TEST_ANSWER"] = "42" app = Sanic() assert app.config.TEST_ANSWER == 42 del environ["SANIC_TEST_ANSWER"] + def test_dont_load_env(): environ["SANIC_TEST_ANSWER"] = "42" app = Sanic(load_env=False) - assert getattr(app.config, 'TEST_ANSWER', None) == None + assert getattr(app.config, 'TEST_ANSWER', None) is None del environ["SANIC_TEST_ANSWER"] + def test_load_env_prefix(): environ["MYAPP_TEST_ANSWER"] = "42" app = Sanic(load_env='MYAPP_') assert app.config.TEST_ANSWER == 42 del environ["MYAPP_TEST_ANSWER"] + def test_load_from_file(app): - config = b""" -VALUE = 'some value' -condition = 1 == 1 -if condition: - CONDITIONAL = 'should be set' - """ - with NamedTemporaryFile() as config_file: - config_file.write(config) - config_file.seek(0) - app.config.from_pyfile(config_file.name) + config = dedent(""" + VALUE = 'some value' + condition = 1 == 1 + if condition: + CONDITIONAL = 'should be set' + """) + with temp_path() as config_path: + config_path.write_text(config) + app.config.from_pyfile(str(config_path)) assert 'VALUE' in app.config assert app.config.VALUE == 'some value' assert 'CONDITIONAL' in app.config @@ -57,11 +70,10 @@ def test_load_from_missing_file(app): def test_load_from_envvar(app): - config = b"VALUE = 'some value'" - with NamedTemporaryFile() as config_file: - config_file.write(config) - config_file.seek(0) - environ['APP_CONFIG'] = config_file.name + config = "VALUE = 'some value'" + with temp_path() as config_path: + config_path.write_text(config) + environ['APP_CONFIG'] = str(config_path) app.config.from_envvar('APP_CONFIG') assert 'VALUE' in app.config assert app.config.VALUE == 'some value' diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index 53a2872e..48a98ac5 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -7,24 +7,12 @@ from sanic.config import Config from sanic import server import aiohttp from aiohttp import TCPConnector -from sanic.testing import SanicTestClient, HOST, PORT +from sanic.testing import SanicTestClient, HOST, PORT, version -try: - try: - import packaging # direct use - except ImportError: - # setuptools v39.0 and above. - try: - from setuptools.extern import packaging - except ImportError: - # Before setuptools v39.0 - from pkg_resources.extern import packaging - version = packaging.version -except ImportError: - raise RuntimeError("The 'packaging' library is missing.") aiohttp_version = version.parse(aiohttp.__version__) + class ReuseableTCPConnector(TCPConnector): def __init__(self, *args, **kwargs): super(ReuseableTCPConnector, self).__init__(*args, **kwargs) diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index c78e19fd..d247001f 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,14 +1,20 @@ import multiprocessing import random import signal +import pytest from sanic.testing import HOST, PORT +@pytest.mark.skipif( + not hasattr(signal, 'SIGALRM'), + reason='SIGALRM is not implemented for this platform, we have to come ' + 'up with another timeout strategy to test these' +) def test_multiprocessing(app): """Tests that the number of children we produce is correct""" # Selects a number at random so we can spot check - num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) + num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) process_list = set() def stop_on_alarm(*args): diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 672d0588..d31fa3d1 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -5,25 +5,13 @@ import asyncio from sanic.response import text from sanic.config import Config import aiohttp -from aiohttp import TCPConnector, ClientResponse -from sanic.testing import SanicTestClient, HOST, PORT +from aiohttp import TCPConnector +from sanic.testing import SanicTestClient, HOST, version -try: - try: - import packaging # direct use - except ImportError: - # setuptools v39.0 and above. - try: - from setuptools.extern import packaging - except ImportError: - # Before setuptools v39.0 - from pkg_resources.extern import packaging - version = packaging.version -except ImportError: - raise RuntimeError("The 'packaging' library is missing.") aiohttp_version = version.parse(aiohttp.__version__) + class DelayableTCPConnector(TCPConnector): class RequestContextManager(object): diff --git a/tests/test_response.py b/tests/test_response.py index 78f9f103..a4d38fad 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -8,9 +8,11 @@ from urllib.parse import unquote import pytest from random import choice -from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json +from sanic.response import ( + HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json +) from sanic.server import HttpProtocol -from sanic.testing import HOST, PORT +from sanic.testing import HOST, PORT, is_windows from unittest.mock import MagicMock JSON_DATA = {'ok': True} @@ -75,8 +77,11 @@ def test_response_header(app): request, response = app.test_client.get('/') assert dict(response.headers) == { 'Connection': 'keep-alive', - 'Keep-Alive': '2', - 'Content-Length': '11', + 'Keep-Alive': str(app.config.KEEP_ALIVE_TIMEOUT), + # response body contains an extra \r at the end if its windows + # TODO: this is the only place this difference shows up in our tests + # we should figure out a way to unify testing on both platforms + 'Content-Length': '12' if is_windows else '11', 'Content-Type': 'application/json', } diff --git a/tests/test_routes.py b/tests/test_routes.py index d70bf975..68f4b0c7 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -165,12 +165,13 @@ def test_route_optional_slash(app): request, response = app.test_client.get('/get/') assert response.text == 'OK' + def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): - #Part of regression test for issue #1120 + # Part of regression test for issue #1120 - site1 = 'localhost:{}'.format(app.test_client.port) + site1 = '127.0.0.1:{}'.format(app.test_client.port) - #before fix, this raises a RouteExists error + # before fix, this raises a RouteExists error @app.get('/get', host=[site1, 'site2.com'], strict_slashes=False) def handler(request): return text('OK') @@ -178,25 +179,25 @@ def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): request, response = app.test_client.get('http://' + site1 + '/get') assert response.text == 'OK' - @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) + @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) # noqa def handler(request): return text('OK') - request, response = app.test_client.post('http://' + site1 +'/post') + request, response = app.test_client.post('http://' + site1 + '/post') assert response.text == 'OK' - @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) + @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) # noqa def handler(request): return text('OK') - request, response = app.test_client.put('http://' + site1 +'/put') + request, response = app.test_client.put('http://' + site1 + '/put') assert response.text == 'OK' - @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) + @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) # noqa def handler(request): return text('OK') - request, response = app.test_client.delete('http://' + site1 +'/delete') + request, response = app.test_client.delete('http://' + site1 + '/delete') assert response.text == 'OK' def test_shorthand_routes_post(app): diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 68e097eb..64972271 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -11,6 +11,12 @@ AVAILABLE_LISTENERS = [ 'after_server_stop' ] +skipif_no_alarm = pytest.mark.skipif( + not hasattr(signal, 'SIGALRM'), + reason='SIGALRM is not implemented for this platform, we have to come ' + 'up with another timeout strategy to test these' +) + def create_listener(listener_name, in_list): async def _listener(app, loop): @@ -32,6 +38,7 @@ def start_stop_app(random_name_app, **run_kwargs): pass +@skipif_no_alarm @pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS) def test_single_listener(app, listener_name): """Test that listeners on their own work""" @@ -43,6 +50,7 @@ def test_single_listener(app, listener_name): assert app.name + listener_name == output.pop() +@skipif_no_alarm @pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS) def test_register_listener(app, listener_name): """ @@ -52,12 +60,12 @@ def test_register_listener(app, listener_name): output = [] # Register listener listener = create_listener(listener_name, output) - app.register_listener(listener, - event=listener_name) + app.register_listener(listener, event=listener_name) start_stop_app(app) assert app.name + listener_name == output.pop() +@skipif_no_alarm def test_all_listeners(app): output = [] for listener_name in AVAILABLE_LISTENERS: diff --git a/tests/test_worker.py b/tests/test_worker.py index 0dcd1f38..c960168e 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -4,16 +4,27 @@ import shlex import subprocess import urllib.request from unittest import mock -from sanic.worker import GunicornWorker -from sanic.app import Sanic import asyncio -import logging import pytest +from sanic.app import Sanic +try: + from sanic.worker import GunicornWorker +except ImportError: + pytestmark = pytest.mark.skip( + reason="GunicornWorker Not supported on this platform" + ) + # this has to be defined or pytest will err on import + GunicornWorker = object @pytest.fixture(scope='module') def gunicorn_worker(): - command = 'gunicorn --bind 127.0.0.1:1337 --worker-class sanic.worker.GunicornWorker examples.simple_server:app' + command = ( + 'gunicorn ' + '--bind 127.0.0.1:1337 ' + '--worker-class sanic.worker.GunicornWorker ' + 'examples.simple_server:app' + ) worker = subprocess.Popen(shlex.split(command)) time.sleep(3) yield