simplified the aiohttp version-specific test code, reverted the gunicorn worker import policy

Alec Buckenheimer 2018-10-08 22:40:36 -04:00
parent 7a9e100b0f
commit b7d74c82ba
7 changed files with 243 additions and 321 deletions
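
The aiohttp simplification below replaces per-version method definitions with a single override that forwards whatever arguments the installed release passes. A minimal sketch of that pattern, with made-up class names (not code from this commit):

# Hypothetical base class standing in for aiohttp's TCPConnector, whose
# connect() signature has changed between releases.
class VersionedBase:
    async def connect(self, req, traces=None, timeout=None):
        return req


class PassthroughConnector(VersionedBase):
    # One override for every supported version: accept whatever the
    # installed release passes and forward it unchanged to the parent.
    async def connect(self, req, *args, **kwargs):
        return await super().connect(req, *args, **kwargs)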

View File

@@ -1,30 +1,12 @@
-import sys
 import traceback
 from json import JSONDecodeError
 
 from sanic.log import logger
 from sanic.exceptions import MethodNotSupported
 from sanic.response import text
 
-try:
-    try:
-        # direct use
-        import packaging as _packaging
-        version = _packaging.version
-    except (ImportError, AttributeError):
-        # setuptools v39.0 and above.
-        try:
-            from setuptools.extern import packaging as _packaging
-        except ImportError:
-            # Before setuptools v39.0
-            from pkg_resources.extern import packaging as _packaging
-        version = _packaging.version
-except ImportError:
-    raise RuntimeError("The 'packaging' library is missing.")
-
 HOST = '127.0.0.1'
 PORT = 42101
-is_windows = sys.platform in ['win32', 'cygwin']
 
 
 class SanicTestClient:

View File

@@ -12,207 +12,199 @@ except ImportError:
 try:
     import uvloop
-    import gunicorn.workers.base as base
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
-    base = None
     pass
 
+import gunicorn.workers.base as base
+
 from sanic.server import trigger_events, serve, HttpProtocol, Signal
 from sanic.websocket import WebSocketProtocol
 
 
-# gunicorn is not available on windows
-if base is not None:
-    class GunicornWorker(base.Worker):
+class GunicornWorker(base.Worker):
 
     http_protocol = HttpProtocol
     websocket_protocol = WebSocketProtocol
 
     def __init__(self, *args, **kw):  # pragma: no cover
         super().__init__(*args, **kw)
         cfg = self.cfg
         if cfg.is_ssl:
             self.ssl_context = self._create_ssl_context(cfg)
         else:
             self.ssl_context = None
         self.servers = {}
         self.connections = set()
         self.exit_code = 0
         self.signal = Signal()
 
     def init_process(self):
         # create new event_loop after fork
         asyncio.get_event_loop().close()
 
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
         super().init_process()
 
     def run(self):
         is_debug = self.log.loglevel == logging.DEBUG
         protocol = (
             self.websocket_protocol if self.app.callable.websocket_enabled
             else self.http_protocol)
         self._server_settings = self.app.callable._helper(
             loop=self.loop,
             debug=is_debug,
             protocol=protocol,
             ssl=self.ssl_context,
             run_async=True)
         self._server_settings['signal'] = self.signal
         self._server_settings.pop('sock')
         trigger_events(self._server_settings.get('before_start', []),
                        self.loop)
         self._server_settings['before_start'] = ()
 
         self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
         try:
             self.loop.run_until_complete(self._runner)
             self.app.callable.is_running = True
             trigger_events(self._server_settings.get('after_start', []),
                            self.loop)
             self.loop.run_until_complete(self._check_alive())
             trigger_events(self._server_settings.get('before_stop', []),
                            self.loop)
             self.loop.run_until_complete(self.close())
         except BaseException:
             traceback.print_exc()
         finally:
             try:
                 trigger_events(self._server_settings.get('after_stop', []),
                                self.loop)
             except BaseException:
                 traceback.print_exc()
             finally:
                 self.loop.close()
 
         sys.exit(self.exit_code)
 
     async def close(self):
         if self.servers:
             # stop accepting connections
             self.log.info("Stopping server: %s, connections: %s",
                           self.pid, len(self.connections))
             for server in self.servers:
                 server.close()
                 await server.wait_closed()
             self.servers.clear()
 
             # prepare connections for closing
             self.signal.stopped = True
             for conn in self.connections:
                 conn.close_if_idle()
 
             # gracefully shutdown timeout
             start_shutdown = 0
             graceful_shutdown_timeout = self.cfg.graceful_timeout
             while self.connections and \
                     (start_shutdown < graceful_shutdown_timeout):
                 await asyncio.sleep(0.1)
                 start_shutdown = start_shutdown + 0.1
 
             # Force close non-idle connection after waiting for
             # graceful_shutdown_timeout
             coros = []
             for conn in self.connections:
                 if hasattr(conn, "websocket") and conn.websocket:
                     coros.append(
                         conn.websocket.close_connection()
                     )
                 else:
                     conn.close()
             _shutdown = asyncio.gather(*coros, loop=self.loop)
             await _shutdown
 
     async def _run(self):
         for sock in self.sockets:
             state = dict(requests_count=0)
             self._server_settings["host"] = None
             self._server_settings["port"] = None
             server = await serve(
                 sock=sock,
                 connections=self.connections,
                 state=state,
                 **self._server_settings
             )
             self.servers[server] = state
 
     async def _check_alive(self):
         # If our parent changed then we shut down.
         pid = os.getpid()
         try:
             while self.alive:
                 self.notify()
 
                 req_count = sum(
                     self.servers[srv]["requests_count"] for srv in self.servers
                 )
                 if self.max_requests and req_count > self.max_requests:
                     self.alive = False
                     self.log.info("Max requests exceeded, shutting down: %s",
                                   self)
                 elif pid == os.getpid() and self.ppid != os.getppid():
                     self.alive = False
                     self.log.info("Parent changed, shutting down: %s", self)
                 else:
                     await asyncio.sleep(1.0, loop=self.loop)
         except (Exception, BaseException, GeneratorExit, KeyboardInterrupt):
             pass
 
     @staticmethod
     def _create_ssl_context(cfg):
         """ Creates SSLContext instance for usage in asyncio.create_server.
         See ssl.SSLSocket.__init__ for more details.
         """
         ctx = ssl.SSLContext(cfg.ssl_version)
         ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
         ctx.verify_mode = cfg.cert_reqs
         if cfg.ca_certs:
             ctx.load_verify_locations(cfg.ca_certs)
         if cfg.ciphers:
             ctx.set_ciphers(cfg.ciphers)
         return ctx
 
     def init_signals(self):
         # Set up signals through the event loop API.
 
         self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
                                      signal.SIGQUIT, None)
 
         self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
                                      signal.SIGTERM, None)
 
         self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
                                      signal.SIGINT, None)
 
         self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
                                      signal.SIGWINCH, None)
 
         self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
                                      signal.SIGUSR1, None)
 
         self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
                                      signal.SIGABRT, None)
 
         # Don't let SIGTERM and SIGUSR1 disturb active requests
         # by interrupting system calls
         signal.siginterrupt(signal.SIGTERM, False)
         signal.siginterrupt(signal.SIGUSR1, False)
 
     def handle_quit(self, sig, frame):
         self.alive = False
         self.app.callable.is_running = False
         self.cfg.worker_int(self)
 
     def handle_abort(self, sig, frame):
         self.alive = False
         self.exit_code = 1
         self.cfg.worker_abort(self)
         sys.exit(1)
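
For context, a worker class like the one above is normally selected on the gunicorn command line; a minimal usage sketch with a placeholder module name (myapp.py is not part of this commit):

# myapp.py -- placeholder application module
from sanic import Sanic
from sanic.response import text

app = Sanic(__name__)


@app.route('/')
async def index(request):
    return text('ok')

# Run under gunicorn with the Sanic worker, e.g.:
#   gunicorn myapp:app --worker-class sanic.worker.GunicornWorker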

View File

@@ -1,7 +1,11 @@
+import sys
 import pytest
 
 from sanic import Sanic
 
+if sys.platform in ['win32', 'cygwin']:
+    collect_ignore = ["test_worker.py"]
+
 
 @pytest.fixture
 def app(request):

View File

@@ -7,10 +7,7 @@ from sanic.config import Config
 from sanic import server
 import aiohttp
 from aiohttp import TCPConnector
-from sanic.testing import SanicTestClient, HOST, PORT, version
-
-aiohttp_version = version.parse(aiohttp.__version__)
+from sanic.testing import SanicTestClient, HOST, PORT
 
 
 class ReuseableTCPConnector(TCPConnector):
@@ -18,39 +15,16 @@ class ReuseableTCPConnector(TCPConnector):
         super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
         self.old_proto = None
 
-    if aiohttp_version >= version.parse('3.3.0'):
-        async def connect(self, req, traces, timeout):
-            new_conn = await super(ReuseableTCPConnector, self)\
-                .connect(req, traces, timeout)
-            if self.old_proto is not None:
-                if self.old_proto != new_conn._protocol:
-                    raise RuntimeError(
-                        "We got a new connection, wanted the same one!")
-            print(new_conn.__dict__)
-            self.old_proto = new_conn._protocol
-            return new_conn
-    elif aiohttp_version >= version.parse('3.0.0'):
-        async def connect(self, req, traces=None):
-            new_conn = await super(ReuseableTCPConnector, self)\
-                .connect(req, traces=traces)
-            if self.old_proto is not None:
-                if self.old_proto != new_conn._protocol:
-                    raise RuntimeError(
-                        "We got a new connection, wanted the same one!")
-            print(new_conn.__dict__)
-            self.old_proto = new_conn._protocol
-            return new_conn
-    else:
-        async def connect(self, req):
-            new_conn = await super(ReuseableTCPConnector, self)\
-                .connect(req)
-            if self.old_proto is not None:
-                if self.old_proto != new_conn._protocol:
-                    raise RuntimeError(
-                        "We got a new connection, wanted the same one!")
-            print(new_conn.__dict__)
-            self.old_proto = new_conn._protocol
-            return new_conn
+    async def connect(self, req, *args, **kwargs):
+        new_conn = await super(ReuseableTCPConnector, self)\
+            .connect(req, *args, **kwargs)
+        if self.old_proto is not None:
+            if self.old_proto != new_conn._protocol:
+                raise RuntimeError(
+                    "We got a new connection, wanted the same one!")
+        print(new_conn.__dict__)
+        self.old_proto = new_conn._protocol
+        return new_conn
class ReuseableSanicTestClient(SanicTestClient):
@@ -135,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient):
             try:
                 request, response = results
                 return request, response
-            except:
+            except Exception:
                 raise ValueError(
                     "Request and response object expected, got ({})".format(
                         results))
         else:
             try:
                 return results[-1]
-            except:
+            except Exception:
                 raise ValueError(
                     "Request object expected, got ({})".format(results))

View File

@@ -6,7 +6,23 @@ from sanic.response import text
 from sanic.config import Config
 import aiohttp
 from aiohttp import TCPConnector
-from sanic.testing import SanicTestClient, HOST, version
+from sanic.testing import SanicTestClient, HOST
+
+try:
+    try:
+        # direct use
+        import packaging
+        version = packaging.version
+    except (ImportError, AttributeError):
+        # setuptools v39.0 and above.
+        try:
+            from setuptools.extern import packaging
+        except ImportError:
+            # Before setuptools v39.0
+            from pkg_resources.extern import packaging
+        version = packaging.version
+except ImportError:
+    raise RuntimeError("The 'packaging' library is missing.")
 
 aiohttp_version = version.parse(aiohttp.__version__)
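
The packaging import above lets the tests compare aiohttp versions numerically rather than lexically; a short illustration of the difference (not part of the diff):

from packaging import version

print('3.10.0' >= '3.3.0')                                # False: plain string comparison is lexicographic
print(version.parse('3.10.0') >= version.parse('3.3.0'))  # True: proper release ordering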
@@ -44,8 +60,7 @@ class DelayableTCPConnector(TCPConnector):
                 if aiohttp_version >= version.parse("3.3.0"):
                     ret = await self.orig_start(connection)
                 else:
-                    ret = await self.orig_start(connection,
-                                                read_until_eof)
+                    ret = await self.orig_start(connection, read_until_eof)
             except Exception as e:
                 raise e
             return ret
@@ -59,57 +74,43 @@ class DelayableTCPConnector(TCPConnector):
         async def delayed_send(self, *args, **kwargs):
             req = self.req
             if self.delay and self.delay > 0:
-                #sync_sleep(self.delay)
+                # sync_sleep(self.delay)
                 await asyncio.sleep(self.delay)
             t = req.loop.time()
             print("sending at {}".format(t), flush=True)
             conn = next(iter(args))  # first arg is connection
-            if aiohttp_version >= version.parse("3.1.0"):
-                try:
-                    delayed_resp = await self.orig_send(*args, **kwargs)
-                except Exception as e:
-                    if aiohttp_version >= version.parse("3.3.0"):
-                        return aiohttp.ClientResponse(req.method, req.url,
-                                                      writer=None,
-                                                      continue100=None,
-                                                      timer=None,
-                                                      request_info=None,
-                                                      traces=[],
-                                                      loop=req.loop,
-                                                      session=None)
-                    else:
-                        return aiohttp.ClientResponse(req.method, req.url,
-                                                      writer=None,
-                                                      continue100=None,
-                                                      timer=None,
-                                                      request_info=None,
-                                                      auto_decompress=None,
-                                                      traces=[],
-                                                      loop=req.loop,
-                                                      session=None)
-            else:
-                try:
-                    delayed_resp = self.orig_send(*args, **kwargs)
-                except Exception as e:
-                    return aiohttp.ClientResponse(req.method, req.url)
-            return delayed_resp
+
+            try:
+                return await self.orig_send(*args, **kwargs)
+            except Exception as e:
+                if aiohttp_version < version.parse("3.1.0"):
+                    return aiohttp.ClientResponse(req.method, req.url)
+                kw = dict(
+                    writer=None,
+                    continue100=None,
+                    timer=None,
+                    request_info=None,
+                    traces=[],
+                    loop=req.loop,
+                    session=None
+                )
+                if aiohttp_version < version.parse("3.3.0"):
+                    kw['auto_decompress'] = None
+                return aiohttp.ClientResponse(req.method, req.url, **kw)
+
+        def _send(self, *args, **kwargs):
+            gen = self.delayed_send(*args, **kwargs)
+            task = self.req.loop.create_task(gen)
+            self.send_task = task
+            self._acting_as = task
+            return self
 
         if aiohttp_version >= version.parse("3.1.0"):
             # aiohttp changed the request.send method to async
             async def send(self, *args, **kwargs):
-                gen = self.delayed_send(*args, **kwargs)
-                task = self.req.loop.create_task(gen)
-                self.send_task = task
-                self._acting_as = task
-                return self
+                return self._send(*args, **kwargs)
         else:
-            def send(self, *args, **kwargs):
-                gen = self.delayed_send(*args, **kwargs)
-                task = self.req.loop.create_task(gen)
-                self.send_task = task
-                self._acting_as = task
-                return self
+            send = _send
 
     def __init__(self, *args, **kwargs):
         _post_connect_delay = kwargs.pop('post_connect_delay', 0)
@@ -118,45 +119,18 @@ class DelayableTCPConnector(TCPConnector):
         self._post_connect_delay = _post_connect_delay
         self._pre_request_delay = _pre_request_delay
 
-    if aiohttp_version >= version.parse("3.3.0"):
-        async def connect(self, req, traces, timeout):
-            d_req = DelayableTCPConnector.\
-                RequestContextManager(req, self._pre_request_delay)
-            conn = await super(DelayableTCPConnector, self).\
-                connect(req, traces, timeout)
-            if self._post_connect_delay and self._post_connect_delay > 0:
-                await asyncio.sleep(self._post_connect_delay,
-                                    loop=self._loop)
-            req.send = d_req.send
-            t = req.loop.time()
-            print("Connected at {}".format(t), flush=True)
-            return conn
-    elif aiohttp_version >= version.parse("3.0.0"):
-        async def connect(self, req, traces=None):
-            d_req = DelayableTCPConnector.\
-                RequestContextManager(req, self._pre_request_delay)
-            conn = await super(DelayableTCPConnector, self).\
-                connect(req, traces=traces)
-            if self._post_connect_delay and self._post_connect_delay > 0:
-                await asyncio.sleep(self._post_connect_delay,
-                                    loop=self._loop)
-            req.send = d_req.send
-            t = req.loop.time()
-            print("Connected at {}".format(t), flush=True)
-            return conn
-    else:
-        async def connect(self, req):
-            d_req = DelayableTCPConnector.\
-                RequestContextManager(req, self._pre_request_delay)
-            conn = await super(DelayableTCPConnector, self).connect(req)
-            if self._post_connect_delay and self._post_connect_delay > 0:
-                await asyncio.sleep(self._post_connect_delay,
-                                    loop=self._loop)
-            req.send = d_req.send
-            t = req.loop.time()
-            print("Connected at {}".format(t), flush=True)
-            return conn
+    async def connect(self, req, *args, **kwargs):
+        d_req = DelayableTCPConnector.\
+            RequestContextManager(req, self._pre_request_delay)
+        conn = await super(DelayableTCPConnector, self).\
+            connect(req, *args, **kwargs)
+        if self._post_connect_delay and self._post_connect_delay > 0:
+            await asyncio.sleep(self._post_connect_delay,
+                                loop=self._loop)
+        req.send = d_req.send
+        t = req.loop.time()
+        print("Connected at {}".format(t), flush=True)
+        return conn
 
 
 class DelayableSanicTestClient(SanicTestClient):

View File

@@ -1,3 +1,4 @@
+import sys
 import asyncio
 import inspect
 import os
@@ -12,7 +13,7 @@ from sanic.response import (
     HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
 )
 from sanic.server import HttpProtocol
-from sanic.testing import HOST, PORT, is_windows
+from sanic.testing import HOST, PORT
 from unittest.mock import MagicMock
 
 JSON_DATA = {'ok': True}
@@ -74,6 +75,7 @@ def test_response_header(app):
             'CONTENT-TYPE': 'application/json'
         })
 
+    is_windows = sys.platform in ['win32', 'cygwin']
     request, response = app.test_client.get('/')
     assert dict(response.headers) == {
         'Connection': 'keep-alive',

View File

@@ -4,17 +4,10 @@ import shlex
 import subprocess
 import urllib.request
 from unittest import mock
+from sanic.worker import GunicornWorker
+from sanic.app import Sanic
 
 import asyncio
 import pytest
 
-from sanic.app import Sanic
-try:
-    from sanic.worker import GunicornWorker
-except ImportError:
-    pytestmark = pytest.mark.skip(
-        reason="GunicornWorker Not supported on this platform"
-    )
-    # this has to be defined or pytest will err on import
-    GunicornWorker = object
 
 
 @pytest.fixture(scope='module')
@@ -107,11 +100,12 @@ def test_run_max_requests_exceeded(worker):
     _runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
     loop.run_until_complete(_runner)
 
-    assert worker.alive == False
+    assert not worker.alive
     worker.notify.assert_called_with()
     worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
                                        worker)
+
 
 def test_worker_close(worker):
     loop = asyncio.new_event_loop()
     asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
@@ -124,8 +118,8 @@ def test_worker_close(worker):
     conn = mock.Mock()
     conn.websocket = mock.Mock()
     conn.websocket.close_connection = mock.Mock(
-            wraps=asyncio.coroutine(lambda *a, **kw: None)
-            )
+        wraps=asyncio.coroutine(lambda *a, **kw: None)
+    )
     worker.connections = set([conn])
     worker.log = mock.Mock()
     worker.loop = loop
@@ -141,6 +135,6 @@ def test_worker_close(worker):
     _close = asyncio.ensure_future(worker.close(), loop=loop)
     loop.run_until_complete(_close)
 
-    assert worker.signal.stopped == True
-    assert conn.websocket.close_connection.called == True
+    assert worker.signal.stopped
+    assert conn.websocket.close_connection.called
     assert len(worker.servers) == 0
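
The close_connection mock above relies on wrapping a coroutine so the mocked attribute is both awaitable and call-tracked. A standalone sketch of the same pattern (not part of the test file; asyncio.coroutine was still available on the Python versions Sanic supported in 2018):

import asyncio
from unittest import mock

# A mock whose call passes through to a coroutine function, so the result
# can be awaited while the call itself is still recorded.
fake_close = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))

loop = asyncio.new_event_loop()
loop.run_until_complete(fake_close())
assert fake_close.called
loop.close()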