From efbacc17cfd8dc8df64c744b6245e80b5fb38fed Mon Sep 17 00:00:00 2001 From: Alec Buckenheimer Date: Sat, 29 Sep 2018 13:54:47 -0400 Subject: [PATCH 1/5] unittests passing on windows again --- sanic/testing.py | 18 ++ sanic/worker.py | 319 ++++++++++++++++--------------- tests/test_config.py | 46 +++-- tests/test_keep_alive_timeout.py | 16 +- tests/test_multiprocessing.py | 8 +- tests/test_request_timeout.py | 18 +- tests/test_response.py | 13 +- tests/test_routes.py | 19 +- tests/test_server_events.py | 12 +- tests/test_worker.py | 19 +- 10 files changed, 264 insertions(+), 224 deletions(-) diff --git a/sanic/testing.py b/sanic/testing.py index 3a1d15c5..a05c9ea7 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,12 +1,30 @@ +import sys import traceback from json import JSONDecodeError from sanic.log import logger from sanic.exceptions import MethodNotSupported from sanic.response import text +try: + try: + # direct use + import packaging as _packaging + version = _packaging.version + except (ImportError, AttributeError): + # setuptools v39.0 and above. + try: + from setuptools.extern import packaging as _packaging + except ImportError: + # Before setuptools v39.0 + from pkg_resources.extern import packaging as _packaging + version = _packaging.version +except ImportError: + raise RuntimeError("The 'packaging' library is missing.") + HOST = '127.0.0.1' PORT = 42101 +is_windows = sys.platform in ['win32', 'cygwin'] class SanicTestClient: diff --git a/sanic/worker.py b/sanic/worker.py index d367a7c3..a5a39c82 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -12,199 +12,202 @@ except ImportError: try: import uvloop + import gunicorn.workers.base as base asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: + base = None pass -import gunicorn.workers.base as base from sanic.server import trigger_events, serve, HttpProtocol, Signal from sanic.websocket import WebSocketProtocol -class GunicornWorker(base.Worker): +# gunicorn is not available on windows +if base is not None: + class GunicornWorker(base.Worker): - http_protocol = HttpProtocol - websocket_protocol = WebSocketProtocol + http_protocol = HttpProtocol + websocket_protocol = WebSocketProtocol - def __init__(self, *args, **kw): # pragma: no cover - super().__init__(*args, **kw) - cfg = self.cfg - if cfg.is_ssl: - self.ssl_context = self._create_ssl_context(cfg) - else: - self.ssl_context = None - self.servers = {} - self.connections = set() - self.exit_code = 0 - self.signal = Signal() + def __init__(self, *args, **kw): # pragma: no cover + super().__init__(*args, **kw) + cfg = self.cfg + if cfg.is_ssl: + self.ssl_context = self._create_ssl_context(cfg) + else: + self.ssl_context = None + self.servers = {} + self.connections = set() + self.exit_code = 0 + self.signal = Signal() - def init_process(self): - # create new event_loop after fork - asyncio.get_event_loop().close() + def init_process(self): + # create new event_loop after fork + asyncio.get_event_loop().close() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) - super().init_process() + super().init_process() - def run(self): - is_debug = self.log.loglevel == logging.DEBUG - protocol = ( - self.websocket_protocol if self.app.callable.websocket_enabled - else self.http_protocol) - self._server_settings = self.app.callable._helper( - loop=self.loop, - debug=is_debug, - protocol=protocol, - ssl=self.ssl_context, - run_async=True) - 
self._server_settings['signal'] = self.signal - self._server_settings.pop('sock') - trigger_events(self._server_settings.get('before_start', []), - self.loop) - self._server_settings['before_start'] = () - - self._runner = asyncio.ensure_future(self._run(), loop=self.loop) - try: - self.loop.run_until_complete(self._runner) - self.app.callable.is_running = True - trigger_events(self._server_settings.get('after_start', []), + def run(self): + is_debug = self.log.loglevel == logging.DEBUG + protocol = ( + self.websocket_protocol if self.app.callable.websocket_enabled + else self.http_protocol) + self._server_settings = self.app.callable._helper( + loop=self.loop, + debug=is_debug, + protocol=protocol, + ssl=self.ssl_context, + run_async=True) + self._server_settings['signal'] = self.signal + self._server_settings.pop('sock') + trigger_events(self._server_settings.get('before_start', []), self.loop) - self.loop.run_until_complete(self._check_alive()) - trigger_events(self._server_settings.get('before_stop', []), - self.loop) - self.loop.run_until_complete(self.close()) - except BaseException: - traceback.print_exc() - finally: + self._server_settings['before_start'] = () + + self._runner = asyncio.ensure_future(self._run(), loop=self.loop) try: - trigger_events(self._server_settings.get('after_stop', []), + self.loop.run_until_complete(self._runner) + self.app.callable.is_running = True + trigger_events(self._server_settings.get('after_start', []), self.loop) + self.loop.run_until_complete(self._check_alive()) + trigger_events(self._server_settings.get('before_stop', []), + self.loop) + self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: - self.loop.close() + try: + trigger_events(self._server_settings.get('after_stop', []), + self.loop) + except BaseException: + traceback.print_exc() + finally: + self.loop.close() - sys.exit(self.exit_code) + sys.exit(self.exit_code) - async def close(self): - if self.servers: - # stop accepting connections - self.log.info("Stopping server: %s, connections: %s", - self.pid, len(self.connections)) - for server in self.servers: - server.close() - await server.wait_closed() - self.servers.clear() + async def close(self): + if self.servers: + # stop accepting connections + self.log.info("Stopping server: %s, connections: %s", + self.pid, len(self.connections)) + for server in self.servers: + server.close() + await server.wait_closed() + self.servers.clear() - # prepare connections for closing - self.signal.stopped = True - for conn in self.connections: - conn.close_if_idle() + # prepare connections for closing + self.signal.stopped = True + for conn in self.connections: + conn.close_if_idle() - # gracefully shutdown timeout - start_shutdown = 0 - graceful_shutdown_timeout = self.cfg.graceful_timeout - while self.connections and \ - (start_shutdown < graceful_shutdown_timeout): - await asyncio.sleep(0.1) - start_shutdown = start_shutdown + 0.1 + # gracefully shutdown timeout + start_shutdown = 0 + graceful_shutdown_timeout = self.cfg.graceful_timeout + while self.connections and \ + (start_shutdown < graceful_shutdown_timeout): + await asyncio.sleep(0.1) + start_shutdown = start_shutdown + 0.1 - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - coros = [] - for conn in self.connections: - if hasattr(conn, "websocket") and conn.websocket: - coros.append( - conn.websocket.close_connection() - ) - else: - conn.close() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown + 
# Force close non-idle connection after waiting for + # graceful_shutdown_timeout + coros = [] + for conn in self.connections: + if hasattr(conn, "websocket") and conn.websocket: + coros.append( + conn.websocket.close_connection() + ) + else: + conn.close() + _shutdown = asyncio.gather(*coros, loop=self.loop) + await _shutdown - async def _run(self): - for sock in self.sockets: - state = dict(requests_count=0) - self._server_settings["host"] = None - self._server_settings["port"] = None - server = await serve( - sock=sock, - connections=self.connections, - state=state, - **self._server_settings - ) - self.servers[server] = state - - async def _check_alive(self): - # If our parent changed then we shut down. - pid = os.getpid() - try: - while self.alive: - self.notify() - - req_count = sum( - self.servers[srv]["requests_count"] for srv in self.servers + async def _run(self): + for sock in self.sockets: + state = dict(requests_count=0) + self._server_settings["host"] = None + self._server_settings["port"] = None + server = await serve( + sock=sock, + connections=self.connections, + state=state, + **self._server_settings ) - if self.max_requests and req_count > self.max_requests: - self.alive = False - self.log.info("Max requests exceeded, shutting down: %s", - self) - elif pid == os.getpid() and self.ppid != os.getppid(): - self.alive = False - self.log.info("Parent changed, shutting down: %s", self) - else: - await asyncio.sleep(1.0, loop=self.loop) - except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): - pass + self.servers[server] = state - @staticmethod - def _create_ssl_context(cfg): - """ Creates SSLContext instance for usage in asyncio.create_server. - See ssl.SSLSocket.__init__ for more details. - """ - ctx = ssl.SSLContext(cfg.ssl_version) - ctx.load_cert_chain(cfg.certfile, cfg.keyfile) - ctx.verify_mode = cfg.cert_reqs - if cfg.ca_certs: - ctx.load_verify_locations(cfg.ca_certs) - if cfg.ciphers: - ctx.set_ciphers(cfg.ciphers) - return ctx + async def _check_alive(self): + # If our parent changed then we shut down. + pid = os.getpid() + try: + while self.alive: + self.notify() - def init_signals(self): - # Set up signals through the event loop API. + req_count = sum( + self.servers[srv]["requests_count"] for srv in self.servers + ) + if self.max_requests and req_count > self.max_requests: + self.alive = False + self.log.info("Max requests exceeded, shutting down: %s", + self) + elif pid == os.getpid() and self.ppid != os.getppid(): + self.alive = False + self.log.info("Parent changed, shutting down: %s", self) + else: + await asyncio.sleep(1.0, loop=self.loop) + except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): + pass - self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, - signal.SIGQUIT, None) + @staticmethod + def _create_ssl_context(cfg): + """ Creates SSLContext instance for usage in asyncio.create_server. + See ssl.SSLSocket.__init__ for more details. + """ + ctx = ssl.SSLContext(cfg.ssl_version) + ctx.load_cert_chain(cfg.certfile, cfg.keyfile) + ctx.verify_mode = cfg.cert_reqs + if cfg.ca_certs: + ctx.load_verify_locations(cfg.ca_certs) + if cfg.ciphers: + ctx.set_ciphers(cfg.ciphers) + return ctx - self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, - signal.SIGTERM, None) + def init_signals(self): + # Set up signals through the event loop API. 
- self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, - signal.SIGINT, None) + self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, + signal.SIGQUIT, None) - self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, - signal.SIGWINCH, None) + self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, + signal.SIGTERM, None) - self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, - signal.SIGUSR1, None) + self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, + signal.SIGINT, None) - self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, - signal.SIGABRT, None) + self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, + signal.SIGWINCH, None) - # Don't let SIGTERM and SIGUSR1 disturb active requests - # by interrupting system calls - signal.siginterrupt(signal.SIGTERM, False) - signal.siginterrupt(signal.SIGUSR1, False) + self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, + signal.SIGUSR1, None) - def handle_quit(self, sig, frame): - self.alive = False - self.app.callable.is_running = False - self.cfg.worker_int(self) + self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, + signal.SIGABRT, None) - def handle_abort(self, sig, frame): - self.alive = False - self.exit_code = 1 - self.cfg.worker_abort(self) - sys.exit(1) + # Don't let SIGTERM and SIGUSR1 disturb active requests + # by interrupting system calls + signal.siginterrupt(signal.SIGTERM, False) + signal.siginterrupt(signal.SIGUSR1, False) + + def handle_quit(self, sig, frame): + self.alive = False + self.app.callable.is_running = False + self.cfg.worker_int(self) + + def handle_abort(self, sig, frame): + self.alive = False + self.exit_code = 1 + self.cfg.worker_abort(self) + sys.exit(1) diff --git a/tests/test_config.py b/tests/test_config.py index 3db35f4d..178bf12a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,20 @@ from os import environ +from pathlib import Path +from contextlib import contextmanager +from tempfile import TemporaryDirectory +from textwrap import dedent import pytest -from tempfile import NamedTemporaryFile from sanic import Sanic +@contextmanager +def temp_path(): + """ a simple cross platform replacement for NamedTemporaryFile """ + with TemporaryDirectory() as td: + yield Path(td, 'file') + + def test_load_from_object(app): class Config: not_for_config = 'should not be used' @@ -15,35 +25,38 @@ def test_load_from_object(app): assert app.config.CONFIG_VALUE == 'should be used' assert 'not_for_config' not in app.config + def test_auto_load_env(): environ["SANIC_TEST_ANSWER"] = "42" app = Sanic() assert app.config.TEST_ANSWER == 42 del environ["SANIC_TEST_ANSWER"] + def test_dont_load_env(): environ["SANIC_TEST_ANSWER"] = "42" app = Sanic(load_env=False) - assert getattr(app.config, 'TEST_ANSWER', None) == None + assert getattr(app.config, 'TEST_ANSWER', None) is None del environ["SANIC_TEST_ANSWER"] + def test_load_env_prefix(): environ["MYAPP_TEST_ANSWER"] = "42" app = Sanic(load_env='MYAPP_') assert app.config.TEST_ANSWER == 42 del environ["MYAPP_TEST_ANSWER"] + def test_load_from_file(app): - config = b""" -VALUE = 'some value' -condition = 1 == 1 -if condition: - CONDITIONAL = 'should be set' - """ - with NamedTemporaryFile() as config_file: - config_file.write(config) - config_file.seek(0) - app.config.from_pyfile(config_file.name) + config = dedent(""" + VALUE = 'some value' + condition = 1 == 1 + if condition: + CONDITIONAL = 'should be set' + """) + with temp_path() as config_path: + 
config_path.write_text(config) + app.config.from_pyfile(str(config_path)) assert 'VALUE' in app.config assert app.config.VALUE == 'some value' assert 'CONDITIONAL' in app.config @@ -57,11 +70,10 @@ def test_load_from_missing_file(app): def test_load_from_envvar(app): - config = b"VALUE = 'some value'" - with NamedTemporaryFile() as config_file: - config_file.write(config) - config_file.seek(0) - environ['APP_CONFIG'] = config_file.name + config = "VALUE = 'some value'" + with temp_path() as config_path: + config_path.write_text(config) + environ['APP_CONFIG'] = str(config_path) app.config.from_envvar('APP_CONFIG') assert 'VALUE' in app.config assert app.config.VALUE == 'some value' diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index 53a2872e..48a98ac5 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -7,24 +7,12 @@ from sanic.config import Config from sanic import server import aiohttp from aiohttp import TCPConnector -from sanic.testing import SanicTestClient, HOST, PORT +from sanic.testing import SanicTestClient, HOST, PORT, version -try: - try: - import packaging # direct use - except ImportError: - # setuptools v39.0 and above. - try: - from setuptools.extern import packaging - except ImportError: - # Before setuptools v39.0 - from pkg_resources.extern import packaging - version = packaging.version -except ImportError: - raise RuntimeError("The 'packaging' library is missing.") aiohttp_version = version.parse(aiohttp.__version__) + class ReuseableTCPConnector(TCPConnector): def __init__(self, *args, **kwargs): super(ReuseableTCPConnector, self).__init__(*args, **kwargs) diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index c78e19fd..d247001f 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,14 +1,20 @@ import multiprocessing import random import signal +import pytest from sanic.testing import HOST, PORT +@pytest.mark.skipif( + not hasattr(signal, 'SIGALRM'), + reason='SIGALRM is not implemented for this platform, we have to come ' + 'up with another timeout strategy to test these' +) def test_multiprocessing(app): """Tests that the number of children we produce is correct""" # Selects a number at random so we can spot check - num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) + num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) process_list = set() def stop_on_alarm(*args): diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 672d0588..d31fa3d1 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -5,25 +5,13 @@ import asyncio from sanic.response import text from sanic.config import Config import aiohttp -from aiohttp import TCPConnector, ClientResponse -from sanic.testing import SanicTestClient, HOST, PORT +from aiohttp import TCPConnector +from sanic.testing import SanicTestClient, HOST, version -try: - try: - import packaging # direct use - except ImportError: - # setuptools v39.0 and above. 
- try: - from setuptools.extern import packaging - except ImportError: - # Before setuptools v39.0 - from pkg_resources.extern import packaging - version = packaging.version -except ImportError: - raise RuntimeError("The 'packaging' library is missing.") aiohttp_version = version.parse(aiohttp.__version__) + class DelayableTCPConnector(TCPConnector): class RequestContextManager(object): diff --git a/tests/test_response.py b/tests/test_response.py index 78f9f103..a4d38fad 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -8,9 +8,11 @@ from urllib.parse import unquote import pytest from random import choice -from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json +from sanic.response import ( + HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json +) from sanic.server import HttpProtocol -from sanic.testing import HOST, PORT +from sanic.testing import HOST, PORT, is_windows from unittest.mock import MagicMock JSON_DATA = {'ok': True} @@ -75,8 +77,11 @@ def test_response_header(app): request, response = app.test_client.get('/') assert dict(response.headers) == { 'Connection': 'keep-alive', - 'Keep-Alive': '2', - 'Content-Length': '11', + 'Keep-Alive': str(app.config.KEEP_ALIVE_TIMEOUT), + # response body contains an extra \r at the end if its windows + # TODO: this is the only place this difference shows up in our tests + # we should figure out a way to unify testing on both platforms + 'Content-Length': '12' if is_windows else '11', 'Content-Type': 'application/json', } diff --git a/tests/test_routes.py b/tests/test_routes.py index d70bf975..68f4b0c7 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -165,12 +165,13 @@ def test_route_optional_slash(app): request, response = app.test_client.get('/get/') assert response.text == 'OK' + def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): - #Part of regression test for issue #1120 + # Part of regression test for issue #1120 - site1 = 'localhost:{}'.format(app.test_client.port) + site1 = '127.0.0.1:{}'.format(app.test_client.port) - #before fix, this raises a RouteExists error + # before fix, this raises a RouteExists error @app.get('/get', host=[site1, 'site2.com'], strict_slashes=False) def handler(request): return text('OK') @@ -178,25 +179,25 @@ def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): request, response = app.test_client.get('http://' + site1 + '/get') assert response.text == 'OK' - @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) + @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) # noqa def handler(request): return text('OK') - request, response = app.test_client.post('http://' + site1 +'/post') + request, response = app.test_client.post('http://' + site1 + '/post') assert response.text == 'OK' - @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) + @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) # noqa def handler(request): return text('OK') - request, response = app.test_client.put('http://' + site1 +'/put') + request, response = app.test_client.put('http://' + site1 + '/put') assert response.text == 'OK' - @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) + @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) # noqa def handler(request): return text('OK') - request, response = app.test_client.delete('http://' + site1 +'/delete') + request, response = app.test_client.delete('http://' + site1 + '/delete') 
assert response.text == 'OK' def test_shorthand_routes_post(app): diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 68e097eb..64972271 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -11,6 +11,12 @@ AVAILABLE_LISTENERS = [ 'after_server_stop' ] +skipif_no_alarm = pytest.mark.skipif( + not hasattr(signal, 'SIGALRM'), + reason='SIGALRM is not implemented for this platform, we have to come ' + 'up with another timeout strategy to test these' +) + def create_listener(listener_name, in_list): async def _listener(app, loop): @@ -32,6 +38,7 @@ def start_stop_app(random_name_app, **run_kwargs): pass +@skipif_no_alarm @pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS) def test_single_listener(app, listener_name): """Test that listeners on their own work""" @@ -43,6 +50,7 @@ def test_single_listener(app, listener_name): assert app.name + listener_name == output.pop() +@skipif_no_alarm @pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS) def test_register_listener(app, listener_name): """ @@ -52,12 +60,12 @@ def test_register_listener(app, listener_name): output = [] # Register listener listener = create_listener(listener_name, output) - app.register_listener(listener, - event=listener_name) + app.register_listener(listener, event=listener_name) start_stop_app(app) assert app.name + listener_name == output.pop() +@skipif_no_alarm def test_all_listeners(app): output = [] for listener_name in AVAILABLE_LISTENERS: diff --git a/tests/test_worker.py b/tests/test_worker.py index 0dcd1f38..c960168e 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -4,16 +4,27 @@ import shlex import subprocess import urllib.request from unittest import mock -from sanic.worker import GunicornWorker -from sanic.app import Sanic import asyncio -import logging import pytest +from sanic.app import Sanic +try: + from sanic.worker import GunicornWorker +except ImportError: + pytestmark = pytest.mark.skip( + reason="GunicornWorker Not supported on this platform" + ) + # this has to be defined or pytest will err on import + GunicornWorker = object @pytest.fixture(scope='module') def gunicorn_worker(): - command = 'gunicorn --bind 127.0.0.1:1337 --worker-class sanic.worker.GunicornWorker examples.simple_server:app' + command = ( + 'gunicorn ' + '--bind 127.0.0.1:1337 ' + '--worker-class sanic.worker.GunicornWorker ' + 'examples.simple_server:app' + ) worker = subprocess.Popen(shlex.split(command)) time.sleep(3) yield From 9a08bdae4a6c9385f9292a06b2256893c7b82cb6 Mon Sep 17 00:00:00 2001 From: Alec Buckenheimer Date: Mon, 1 Oct 2018 09:46:18 -0400 Subject: [PATCH 2/5] fix flake8 linelength errors --- sanic/worker.py | 15 ++++++++++----- tests/test_routes.py | 15 ++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sanic/worker.py b/sanic/worker.py index a5a39c82..05319d1f 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -147,18 +147,23 @@ if base is not None: self.notify() req_count = sum( - self.servers[srv]["requests_count"] for srv in self.servers + srv['requests_count'] for srv in self.servers.values() ) if self.max_requests and req_count > self.max_requests: self.alive = False - self.log.info("Max requests exceeded, shutting down: %s", - self) + self.log.info( + "Max requests exceeded, shutting down: %s", + self + ) elif pid == os.getpid() and self.ppid != os.getppid(): self.alive = False - self.log.info("Parent changed, shutting down: %s", self) + self.log.info( + "Parent changed, shutting down: %s", + self 
+ ) else: await asyncio.sleep(1.0, loop=self.loop) - except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): + except (BaseException, GeneratorExit, KeyboardInterrupt): pass @staticmethod diff --git a/tests/test_routes.py b/tests/test_routes.py index 68f4b0c7..d5f1c90a 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -173,33 +173,34 @@ def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): # before fix, this raises a RouteExists error @app.get('/get', host=[site1, 'site2.com'], strict_slashes=False) - def handler(request): + def get_handler(request): return text('OK') request, response = app.test_client.get('http://' + site1 + '/get') assert response.text == 'OK' - @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) # noqa - def handler(request): + @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) + def post_handler(request): return text('OK') request, response = app.test_client.post('http://' + site1 + '/post') assert response.text == 'OK' - @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) # noqa - def handler(request): + @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) + def put_handler(request): return text('OK') request, response = app.test_client.put('http://' + site1 + '/put') assert response.text == 'OK' - @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) # noqa - def handler(request): + @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) + def delete_handler(request): return text('OK') request, response = app.test_client.delete('http://' + site1 + '/delete') assert response.text == 'OK' + def test_shorthand_routes_post(app): @app.post('/post') From a16842f7bc3b5d6cff546ed20ee752656f1ade71 Mon Sep 17 00:00:00 2001 From: Lewis Date: Mon, 8 Oct 2018 18:59:15 +0900 Subject: [PATCH 3/5] Fix missing quotes in decorator example --- docs/sanic/decorators.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sanic/decorators.md b/docs/sanic/decorators.md index db2d369b..68495a22 100644 --- a/docs/sanic/decorators.md +++ b/docs/sanic/decorators.md @@ -34,6 +34,6 @@ def authorized(): @app.route("/") @authorized() async def test(request): - return json({status: 'authorized'}) + return json({'status': 'authorized'}) ``` From c3b31a6fb0e9a1220bc1897e782080773802f2f3 Mon Sep 17 00:00:00 2001 From: Arun Babu Neelicattu Date: Mon, 8 Oct 2018 19:29:52 +0200 Subject: [PATCH 4/5] Simplify request ip and port retrieval logic This change also ensures that cases where transport stream is already closed is handled gracefully. 
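
For illustration, a minimal stand-alone sketch of the fallback behaviour this
change relies on. FakeTransport and get_address below are hypothetical stubs
that only mirror the simplified Request._get_address; they are not part of
Sanic:

    # Illustration stub -- mimics an asyncio transport whose peer
    # information may no longer be available.
    class FakeTransport:
        def __init__(self, peername=None):
            self._peername = peername

        def get_extra_info(self, name, default=None):
            # get_extra_info('peername') may return None, e.g. once the
            # underlying stream has already been closed
            return self._peername if name == 'peername' else default


    def get_address(transport):
        # mirrors the simplified logic: no AF_INET / AF_INET6 branching,
        # just a (None, None) fallback when the peer name is gone
        peer = transport.get_extra_info('peername') or (None, None)
        return peer[0], peer[1]


    assert get_address(FakeTransport(('127.0.0.1', 51234))) == ('127.0.0.1', 51234)
    assert get_address(FakeTransport(('::1', 51234, 0, 0))) == ('::1', 51234)  # IPv6 4-tuple
    assert get_address(FakeTransport(None)) == (None, None)                    # closed stream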
--- sanic/request.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/sanic/request.py b/sanic/request.py index c8b470d4..70240207 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -1,6 +1,5 @@ import sys import json -import socket from cgi import parse_header from collections import namedtuple from http.cookies import SimpleCookie @@ -192,18 +191,10 @@ class Request(dict): return self._socket def _get_address(self): - sock = self.transport.get_extra_info('socket') - - if sock.family == socket.AF_INET: - self._socket = (self.transport.get_extra_info('peername') or - (None, None)) - self._ip, self._port = self._socket - elif sock.family == socket.AF_INET6: - self._socket = (self.transport.get_extra_info('peername') or - (None, None, None, None)) - self._ip, self._port, *_ = self._socket - else: - self._ip, self._port = (None, None) + self._socket = self.transport.get_extra_info('peername') or \ + (None, None) + self._ip = self._socket[0] + self._port = self._socket[1] @property def remote_addr(self): From b7d74c82ba2e53b78eac25314e2161b0c782cd50 Mon Sep 17 00:00:00 2001 From: Alec Buckenheimer Date: Mon, 8 Oct 2018 22:40:36 -0400 Subject: [PATCH 5/5] simplified aiohttp version diffs, reverted worker import policy --- sanic/testing.py | 18 -- sanic/worker.py | 326 +++++++++++++++---------------- tests/conftest.py | 4 + tests/test_keep_alive_timeout.py | 52 ++--- tests/test_request_timeout.py | 138 ++++++------- tests/test_response.py | 4 +- tests/test_worker.py | 22 +-- 7 files changed, 243 insertions(+), 321 deletions(-) diff --git a/sanic/testing.py b/sanic/testing.py index a05c9ea7..3a1d15c5 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,30 +1,12 @@ -import sys import traceback from json import JSONDecodeError from sanic.log import logger from sanic.exceptions import MethodNotSupported from sanic.response import text -try: - try: - # direct use - import packaging as _packaging - version = _packaging.version - except (ImportError, AttributeError): - # setuptools v39.0 and above. 
- try: - from setuptools.extern import packaging as _packaging - except ImportError: - # Before setuptools v39.0 - from pkg_resources.extern import packaging as _packaging - version = _packaging.version -except ImportError: - raise RuntimeError("The 'packaging' library is missing.") - HOST = '127.0.0.1' PORT = 42101 -is_windows = sys.platform in ['win32', 'cygwin'] class SanicTestClient: diff --git a/sanic/worker.py b/sanic/worker.py index 05319d1f..d367a7c3 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -12,207 +12,199 @@ except ImportError: try: import uvloop - import gunicorn.workers.base as base asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: - base = None pass +import gunicorn.workers.base as base from sanic.server import trigger_events, serve, HttpProtocol, Signal from sanic.websocket import WebSocketProtocol -# gunicorn is not available on windows -if base is not None: - class GunicornWorker(base.Worker): +class GunicornWorker(base.Worker): - http_protocol = HttpProtocol - websocket_protocol = WebSocketProtocol + http_protocol = HttpProtocol + websocket_protocol = WebSocketProtocol - def __init__(self, *args, **kw): # pragma: no cover - super().__init__(*args, **kw) - cfg = self.cfg - if cfg.is_ssl: - self.ssl_context = self._create_ssl_context(cfg) - else: - self.ssl_context = None - self.servers = {} - self.connections = set() - self.exit_code = 0 - self.signal = Signal() + def __init__(self, *args, **kw): # pragma: no cover + super().__init__(*args, **kw) + cfg = self.cfg + if cfg.is_ssl: + self.ssl_context = self._create_ssl_context(cfg) + else: + self.ssl_context = None + self.servers = {} + self.connections = set() + self.exit_code = 0 + self.signal = Signal() - def init_process(self): - # create new event_loop after fork - asyncio.get_event_loop().close() + def init_process(self): + # create new event_loop after fork + asyncio.get_event_loop().close() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) - super().init_process() + super().init_process() - def run(self): - is_debug = self.log.loglevel == logging.DEBUG - protocol = ( - self.websocket_protocol if self.app.callable.websocket_enabled - else self.http_protocol) - self._server_settings = self.app.callable._helper( - loop=self.loop, - debug=is_debug, - protocol=protocol, - ssl=self.ssl_context, - run_async=True) - self._server_settings['signal'] = self.signal - self._server_settings.pop('sock') - trigger_events(self._server_settings.get('before_start', []), + def run(self): + is_debug = self.log.loglevel == logging.DEBUG + protocol = ( + self.websocket_protocol if self.app.callable.websocket_enabled + else self.http_protocol) + self._server_settings = self.app.callable._helper( + loop=self.loop, + debug=is_debug, + protocol=protocol, + ssl=self.ssl_context, + run_async=True) + self._server_settings['signal'] = self.signal + self._server_settings.pop('sock') + trigger_events(self._server_settings.get('before_start', []), + self.loop) + self._server_settings['before_start'] = () + + self._runner = asyncio.ensure_future(self._run(), loop=self.loop) + try: + self.loop.run_until_complete(self._runner) + self.app.callable.is_running = True + trigger_events(self._server_settings.get('after_start', []), self.loop) - self._server_settings['before_start'] = () - - self._runner = asyncio.ensure_future(self._run(), loop=self.loop) + self.loop.run_until_complete(self._check_alive()) + 
trigger_events(self._server_settings.get('before_stop', []), + self.loop) + self.loop.run_until_complete(self.close()) + except BaseException: + traceback.print_exc() + finally: try: - self.loop.run_until_complete(self._runner) - self.app.callable.is_running = True - trigger_events(self._server_settings.get('after_start', []), + trigger_events(self._server_settings.get('after_stop', []), self.loop) - self.loop.run_until_complete(self._check_alive()) - trigger_events(self._server_settings.get('before_stop', []), - self.loop) - self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: - try: - trigger_events(self._server_settings.get('after_stop', []), - self.loop) - except BaseException: - traceback.print_exc() - finally: - self.loop.close() + self.loop.close() - sys.exit(self.exit_code) + sys.exit(self.exit_code) - async def close(self): - if self.servers: - # stop accepting connections - self.log.info("Stopping server: %s, connections: %s", - self.pid, len(self.connections)) - for server in self.servers: - server.close() - await server.wait_closed() - self.servers.clear() + async def close(self): + if self.servers: + # stop accepting connections + self.log.info("Stopping server: %s, connections: %s", + self.pid, len(self.connections)) + for server in self.servers: + server.close() + await server.wait_closed() + self.servers.clear() - # prepare connections for closing - self.signal.stopped = True - for conn in self.connections: - conn.close_if_idle() + # prepare connections for closing + self.signal.stopped = True + for conn in self.connections: + conn.close_if_idle() - # gracefully shutdown timeout - start_shutdown = 0 - graceful_shutdown_timeout = self.cfg.graceful_timeout - while self.connections and \ - (start_shutdown < graceful_shutdown_timeout): - await asyncio.sleep(0.1) - start_shutdown = start_shutdown + 0.1 + # gracefully shutdown timeout + start_shutdown = 0 + graceful_shutdown_timeout = self.cfg.graceful_timeout + while self.connections and \ + (start_shutdown < graceful_shutdown_timeout): + await asyncio.sleep(0.1) + start_shutdown = start_shutdown + 0.1 - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - coros = [] - for conn in self.connections: - if hasattr(conn, "websocket") and conn.websocket: - coros.append( - conn.websocket.close_connection() - ) - else: - conn.close() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown - - async def _run(self): - for sock in self.sockets: - state = dict(requests_count=0) - self._server_settings["host"] = None - self._server_settings["port"] = None - server = await serve( - sock=sock, - connections=self.connections, - state=state, - **self._server_settings - ) - self.servers[server] = state - - async def _check_alive(self): - # If our parent changed then we shut down. 
- pid = os.getpid() - try: - while self.alive: - self.notify() - - req_count = sum( - srv['requests_count'] for srv in self.servers.values() + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + coros = [] + for conn in self.connections: + if hasattr(conn, "websocket") and conn.websocket: + coros.append( + conn.websocket.close_connection() ) - if self.max_requests and req_count > self.max_requests: - self.alive = False - self.log.info( - "Max requests exceeded, shutting down: %s", - self - ) - elif pid == os.getpid() and self.ppid != os.getppid(): - self.alive = False - self.log.info( - "Parent changed, shutting down: %s", - self - ) - else: - await asyncio.sleep(1.0, loop=self.loop) - except (BaseException, GeneratorExit, KeyboardInterrupt): - pass + else: + conn.close() + _shutdown = asyncio.gather(*coros, loop=self.loop) + await _shutdown - @staticmethod - def _create_ssl_context(cfg): - """ Creates SSLContext instance for usage in asyncio.create_server. - See ssl.SSLSocket.__init__ for more details. - """ - ctx = ssl.SSLContext(cfg.ssl_version) - ctx.load_cert_chain(cfg.certfile, cfg.keyfile) - ctx.verify_mode = cfg.cert_reqs - if cfg.ca_certs: - ctx.load_verify_locations(cfg.ca_certs) - if cfg.ciphers: - ctx.set_ciphers(cfg.ciphers) - return ctx + async def _run(self): + for sock in self.sockets: + state = dict(requests_count=0) + self._server_settings["host"] = None + self._server_settings["port"] = None + server = await serve( + sock=sock, + connections=self.connections, + state=state, + **self._server_settings + ) + self.servers[server] = state - def init_signals(self): - # Set up signals through the event loop API. + async def _check_alive(self): + # If our parent changed then we shut down. + pid = os.getpid() + try: + while self.alive: + self.notify() - self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, - signal.SIGQUIT, None) + req_count = sum( + self.servers[srv]["requests_count"] for srv in self.servers + ) + if self.max_requests and req_count > self.max_requests: + self.alive = False + self.log.info("Max requests exceeded, shutting down: %s", + self) + elif pid == os.getpid() and self.ppid != os.getppid(): + self.alive = False + self.log.info("Parent changed, shutting down: %s", self) + else: + await asyncio.sleep(1.0, loop=self.loop) + except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): + pass - self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, - signal.SIGTERM, None) + @staticmethod + def _create_ssl_context(cfg): + """ Creates SSLContext instance for usage in asyncio.create_server. + See ssl.SSLSocket.__init__ for more details. + """ + ctx = ssl.SSLContext(cfg.ssl_version) + ctx.load_cert_chain(cfg.certfile, cfg.keyfile) + ctx.verify_mode = cfg.cert_reqs + if cfg.ca_certs: + ctx.load_verify_locations(cfg.ca_certs) + if cfg.ciphers: + ctx.set_ciphers(cfg.ciphers) + return ctx - self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, - signal.SIGINT, None) + def init_signals(self): + # Set up signals through the event loop API. 
- self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, - signal.SIGWINCH, None) + self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, + signal.SIGQUIT, None) - self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, - signal.SIGUSR1, None) + self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, + signal.SIGTERM, None) - self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, - signal.SIGABRT, None) + self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, + signal.SIGINT, None) - # Don't let SIGTERM and SIGUSR1 disturb active requests - # by interrupting system calls - signal.siginterrupt(signal.SIGTERM, False) - signal.siginterrupt(signal.SIGUSR1, False) + self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, + signal.SIGWINCH, None) - def handle_quit(self, sig, frame): - self.alive = False - self.app.callable.is_running = False - self.cfg.worker_int(self) + self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, + signal.SIGUSR1, None) - def handle_abort(self, sig, frame): - self.alive = False - self.exit_code = 1 - self.cfg.worker_abort(self) - sys.exit(1) + self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, + signal.SIGABRT, None) + + # Don't let SIGTERM and SIGUSR1 disturb active requests + # by interrupting system calls + signal.siginterrupt(signal.SIGTERM, False) + signal.siginterrupt(signal.SIGUSR1, False) + + def handle_quit(self, sig, frame): + self.alive = False + self.app.callable.is_running = False + self.cfg.worker_int(self) + + def handle_abort(self, sig, frame): + self.alive = False + self.exit_code = 1 + self.cfg.worker_abort(self) + sys.exit(1) diff --git a/tests/conftest.py b/tests/conftest.py index 5844e3a1..ac47aceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,11 @@ +import sys import pytest from sanic import Sanic +if sys.platform in ['win32', 'cygwin']: + collect_ignore = ["test_worker.py"] + @pytest.fixture def app(request): diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index 48a98ac5..ca8cbd0a 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -7,10 +7,7 @@ from sanic.config import Config from sanic import server import aiohttp from aiohttp import TCPConnector -from sanic.testing import SanicTestClient, HOST, PORT, version - - -aiohttp_version = version.parse(aiohttp.__version__) +from sanic.testing import SanicTestClient, HOST, PORT class ReuseableTCPConnector(TCPConnector): @@ -18,39 +15,16 @@ class ReuseableTCPConnector(TCPConnector): super(ReuseableTCPConnector, self).__init__(*args, **kwargs) self.old_proto = None - if aiohttp_version >= version.parse('3.3.0'): - async def connect(self, req, traces, timeout): - new_conn = await super(ReuseableTCPConnector, self)\ - .connect(req, traces, timeout) - if self.old_proto is not None: - if self.old_proto != new_conn._protocol: - raise RuntimeError( - "We got a new connection, wanted the same one!") - print(new_conn.__dict__) - self.old_proto = new_conn._protocol - return new_conn - elif aiohttp_version >= version.parse('3.0.0'): - async def connect(self, req, traces=None): - new_conn = await super(ReuseableTCPConnector, self)\ - .connect(req, traces=traces) - if self.old_proto is not None: - if self.old_proto != new_conn._protocol: - raise RuntimeError( - "We got a new connection, wanted the same one!") - print(new_conn.__dict__) - self.old_proto = new_conn._protocol - return new_conn - else: - async def connect(self, req): - new_conn = await 
super(ReuseableTCPConnector, self)\ - .connect(req) - if self.old_proto is not None: - if self.old_proto != new_conn._protocol: - raise RuntimeError( - "We got a new connection, wanted the same one!") - print(new_conn.__dict__) - self.old_proto = new_conn._protocol - return new_conn + async def connect(self, req, *args, **kwargs): + new_conn = await super(ReuseableTCPConnector, self)\ + .connect(req, *args, **kwargs) + if self.old_proto is not None: + if self.old_proto != new_conn._protocol: + raise RuntimeError( + "We got a new connection, wanted the same one!") + print(new_conn.__dict__) + self.old_proto = new_conn._protocol + return new_conn class ReuseableSanicTestClient(SanicTestClient): @@ -135,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient): try: request, response = results return request, response - except: + except Exception: raise ValueError( "Request and response object expected, got ({})".format( results)) else: try: return results[-1] - except: + except Exception: raise ValueError( "Request object expected, got ({})".format(results)) diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index d31fa3d1..f81ed61c 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -6,7 +6,23 @@ from sanic.response import text from sanic.config import Config import aiohttp from aiohttp import TCPConnector -from sanic.testing import SanicTestClient, HOST, version +from sanic.testing import SanicTestClient, HOST + +try: + try: + # direct use + import packaging + version = packaging.version + except (ImportError, AttributeError): + # setuptools v39.0 and above. + try: + from setuptools.extern import packaging + except ImportError: + # Before setuptools v39.0 + from pkg_resources.extern import packaging + version = packaging.version +except ImportError: + raise RuntimeError("The 'packaging' library is missing.") aiohttp_version = version.parse(aiohttp.__version__) @@ -44,8 +60,7 @@ class DelayableTCPConnector(TCPConnector): if aiohttp_version >= version.parse("3.3.0"): ret = await self.orig_start(connection) else: - ret = await self.orig_start(connection, - read_until_eof) + ret = await self.orig_start(connection, read_until_eof) except Exception as e: raise e return ret @@ -59,57 +74,43 @@ class DelayableTCPConnector(TCPConnector): async def delayed_send(self, *args, **kwargs): req = self.req if self.delay and self.delay > 0: - #sync_sleep(self.delay) + # sync_sleep(self.delay) await asyncio.sleep(self.delay) t = req.loop.time() print("sending at {}".format(t), flush=True) conn = next(iter(args)) # first arg is connection - if aiohttp_version >= version.parse("3.1.0"): - try: - delayed_resp = await self.orig_send(*args, **kwargs) - except Exception as e: - if aiohttp_version >= version.parse("3.3.0"): - return aiohttp.ClientResponse(req.method, req.url, - writer=None, - continue100=None, - timer=None, - request_info=None, - traces=[], - loop=req.loop, - session=None) - else: - return aiohttp.ClientResponse(req.method, req.url, - writer=None, - continue100=None, - timer=None, - request_info=None, - auto_decompress=None, - traces=[], - loop=req.loop, - session=None) - else: - try: - delayed_resp = self.orig_send(*args, **kwargs) - except Exception as e: + try: + return await self.orig_send(*args, **kwargs) + except Exception as e: + if aiohttp_version < version.parse("3.1.0"): return aiohttp.ClientResponse(req.method, req.url) - return delayed_resp + kw = dict( + writer=None, + continue100=None, + timer=None, + request_info=None, + 
traces=[], + loop=req.loop, + session=None + ) + if aiohttp_version < version.parse("3.3.0"): + kw['auto_decompress'] = None + return aiohttp.ClientResponse(req.method, req.url, **kw) + + def _send(self, *args, **kwargs): + gen = self.delayed_send(*args, **kwargs) + task = self.req.loop.create_task(gen) + self.send_task = task + self._acting_as = task + return self if aiohttp_version >= version.parse("3.1.0"): # aiohttp changed the request.send method to async async def send(self, *args, **kwargs): - gen = self.delayed_send(*args, **kwargs) - task = self.req.loop.create_task(gen) - self.send_task = task - self._acting_as = task - return self + return self._send(*args, **kwargs) else: - def send(self, *args, **kwargs): - gen = self.delayed_send(*args, **kwargs) - task = self.req.loop.create_task(gen) - self.send_task = task - self._acting_as = task - return self + send = _send def __init__(self, *args, **kwargs): _post_connect_delay = kwargs.pop('post_connect_delay', 0) @@ -118,45 +119,18 @@ class DelayableTCPConnector(TCPConnector): self._post_connect_delay = _post_connect_delay self._pre_request_delay = _pre_request_delay - if aiohttp_version >= version.parse("3.3.0"): - async def connect(self, req, traces, timeout): - d_req = DelayableTCPConnector.\ - RequestContextManager(req, self._pre_request_delay) - conn = await super(DelayableTCPConnector, self).\ - connect(req, traces, timeout) - if self._post_connect_delay and self._post_connect_delay > 0: - await asyncio.sleep(self._post_connect_delay, - loop=self._loop) - req.send = d_req.send - t = req.loop.time() - print("Connected at {}".format(t), flush=True) - return conn - elif aiohttp_version >= version.parse("3.0.0"): - async def connect(self, req, traces=None): - d_req = DelayableTCPConnector.\ - RequestContextManager(req, self._pre_request_delay) - conn = await super(DelayableTCPConnector, self).\ - connect(req, traces=traces) - if self._post_connect_delay and self._post_connect_delay > 0: - await asyncio.sleep(self._post_connect_delay, - loop=self._loop) - req.send = d_req.send - t = req.loop.time() - print("Connected at {}".format(t), flush=True) - return conn - else: - - async def connect(self, req): - d_req = DelayableTCPConnector.\ - RequestContextManager(req, self._pre_request_delay) - conn = await super(DelayableTCPConnector, self).connect(req) - if self._post_connect_delay and self._post_connect_delay > 0: - await asyncio.sleep(self._post_connect_delay, - loop=self._loop) - req.send = d_req.send - t = req.loop.time() - print("Connected at {}".format(t), flush=True) - return conn + async def connect(self, req, *args, **kwargs): + d_req = DelayableTCPConnector.\ + RequestContextManager(req, self._pre_request_delay) + conn = await super(DelayableTCPConnector, self).\ + connect(req, *args, **kwargs) + if self._post_connect_delay and self._post_connect_delay > 0: + await asyncio.sleep(self._post_connect_delay, + loop=self._loop) + req.send = d_req.send + t = req.loop.time() + print("Connected at {}".format(t), flush=True) + return conn class DelayableSanicTestClient(SanicTestClient): diff --git a/tests/test_response.py b/tests/test_response.py index a4d38fad..99cb8950 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,3 +1,4 @@ +import sys import asyncio import inspect import os @@ -12,7 +13,7 @@ from sanic.response import ( HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json ) from sanic.server import HttpProtocol -from sanic.testing import HOST, PORT, is_windows +from sanic.testing import 
HOST, PORT from unittest.mock import MagicMock JSON_DATA = {'ok': True} @@ -74,6 +75,7 @@ def test_response_header(app): 'CONTENT-TYPE': 'application/json' }) + is_windows = sys.platform in ['win32', 'cygwin'] request, response = app.test_client.get('/') assert dict(response.headers) == { 'Connection': 'keep-alive', diff --git a/tests/test_worker.py b/tests/test_worker.py index c960168e..7bfab84c 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -4,17 +4,10 @@ import shlex import subprocess import urllib.request from unittest import mock +from sanic.worker import GunicornWorker +from sanic.app import Sanic import asyncio import pytest -from sanic.app import Sanic -try: - from sanic.worker import GunicornWorker -except ImportError: - pytestmark = pytest.mark.skip( - reason="GunicornWorker Not supported on this platform" - ) - # this has to be defined or pytest will err on import - GunicornWorker = object @pytest.fixture(scope='module') @@ -107,11 +100,12 @@ def test_run_max_requests_exceeded(worker): _runner = asyncio.ensure_future(worker._check_alive(), loop=loop) loop.run_until_complete(_runner) - assert worker.alive == False + assert not worker.alive worker.notify.assert_called_with() worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s", worker) + def test_worker_close(worker): loop = asyncio.new_event_loop() asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None)) @@ -124,8 +118,8 @@ def test_worker_close(worker): conn = mock.Mock() conn.websocket = mock.Mock() conn.websocket.close_connection = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: None) - ) + wraps=asyncio.coroutine(lambda *a, **kw: None) + ) worker.connections = set([conn]) worker.log = mock.Mock() worker.loop = loop @@ -141,6 +135,6 @@ def test_worker_close(worker): _close = asyncio.ensure_future(worker.close(), loop=loop) loop.run_until_complete(_close) - assert worker.signal.stopped == True - assert conn.websocket.close_connection.called == True + assert worker.signal.stopped + assert conn.websocket.close_connection.called assert len(worker.servers) == 0
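
Note on the connector overrides above: rather than branching on
aiohttp_version, ReuseableTCPConnector.connect and DelayableTCPConnector.connect
accept *args/**kwargs and forward them to super(), so a single override
tolerates the extra traces/timeout parameters that newer aiohttp releases pass.
A minimal stand-alone sketch of the same pattern follows; LoggingConnector,
fetch_status and the 127.0.0.1:42101 URL are illustrative only, not part of the
test suite:

    import asyncio

    import aiohttp
    from aiohttp import TCPConnector


    class LoggingConnector(TCPConnector):
        async def connect(self, req, *args, **kwargs):
            # *args/**kwargs absorb whatever extra positional parameters
            # (traces, timeout, ...) the installed aiohttp version passes
            conn = await super().connect(req, *args, **kwargs)
            print("connected to", req.url)
            return conn


    async def fetch_status(url):
        async with aiohttp.ClientSession(connector=LoggingConnector()) as session:
            async with session.get(url) as resp:
                return resp.status


    # assumes a server is listening on the test client's default port:
    # print(asyncio.get_event_loop().run_until_complete(
    #     fetch_status('http://127.0.0.1:42101/')))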