simplified aiohttp version diffs, reverted worker import policy
This commit is contained in:
parent
7a9e100b0f
commit
b7d74c82ba
@ -1,30 +1,12 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
import traceback
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
from sanic.exceptions import MethodNotSupported
|
from sanic.exceptions import MethodNotSupported
|
||||||
from sanic.response import text
|
from sanic.response import text
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
# direct use
|
|
||||||
import packaging as _packaging
|
|
||||||
version = _packaging.version
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
# setuptools v39.0 and above.
|
|
||||||
try:
|
|
||||||
from setuptools.extern import packaging as _packaging
|
|
||||||
except ImportError:
|
|
||||||
# Before setuptools v39.0
|
|
||||||
from pkg_resources.extern import packaging as _packaging
|
|
||||||
version = _packaging.version
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError("The 'packaging' library is missing.")
|
|
||||||
|
|
||||||
|
|
||||||
HOST = '127.0.0.1'
|
HOST = '127.0.0.1'
|
||||||
PORT = 42101
|
PORT = 42101
|
||||||
is_windows = sys.platform in ['win32', 'cygwin']
|
|
||||||
|
|
||||||
|
|
||||||
class SanicTestClient:
|
class SanicTestClient:
|
||||||
|
326
sanic/worker.py
326
sanic/worker.py
@ -12,207 +12,199 @@ except ImportError:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import uvloop
|
import uvloop
|
||||||
import gunicorn.workers.base as base
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
except ImportError:
|
except ImportError:
|
||||||
base = None
|
|
||||||
pass
|
pass
|
||||||
|
import gunicorn.workers.base as base
|
||||||
|
|
||||||
from sanic.server import trigger_events, serve, HttpProtocol, Signal
|
from sanic.server import trigger_events, serve, HttpProtocol, Signal
|
||||||
from sanic.websocket import WebSocketProtocol
|
from sanic.websocket import WebSocketProtocol
|
||||||
|
|
||||||
|
|
||||||
# gunicorn is not available on windows
|
class GunicornWorker(base.Worker):
|
||||||
if base is not None:
|
|
||||||
class GunicornWorker(base.Worker):
|
|
||||||
|
|
||||||
http_protocol = HttpProtocol
|
http_protocol = HttpProtocol
|
||||||
websocket_protocol = WebSocketProtocol
|
websocket_protocol = WebSocketProtocol
|
||||||
|
|
||||||
def __init__(self, *args, **kw): # pragma: no cover
|
def __init__(self, *args, **kw): # pragma: no cover
|
||||||
super().__init__(*args, **kw)
|
super().__init__(*args, **kw)
|
||||||
cfg = self.cfg
|
cfg = self.cfg
|
||||||
if cfg.is_ssl:
|
if cfg.is_ssl:
|
||||||
self.ssl_context = self._create_ssl_context(cfg)
|
self.ssl_context = self._create_ssl_context(cfg)
|
||||||
else:
|
else:
|
||||||
self.ssl_context = None
|
self.ssl_context = None
|
||||||
self.servers = {}
|
self.servers = {}
|
||||||
self.connections = set()
|
self.connections = set()
|
||||||
self.exit_code = 0
|
self.exit_code = 0
|
||||||
self.signal = Signal()
|
self.signal = Signal()
|
||||||
|
|
||||||
def init_process(self):
|
def init_process(self):
|
||||||
# create new event_loop after fork
|
# create new event_loop after fork
|
||||||
asyncio.get_event_loop().close()
|
asyncio.get_event_loop().close()
|
||||||
|
|
||||||
self.loop = asyncio.new_event_loop()
|
self.loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(self.loop)
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
super().init_process()
|
super().init_process()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
is_debug = self.log.loglevel == logging.DEBUG
|
is_debug = self.log.loglevel == logging.DEBUG
|
||||||
protocol = (
|
protocol = (
|
||||||
self.websocket_protocol if self.app.callable.websocket_enabled
|
self.websocket_protocol if self.app.callable.websocket_enabled
|
||||||
else self.http_protocol)
|
else self.http_protocol)
|
||||||
self._server_settings = self.app.callable._helper(
|
self._server_settings = self.app.callable._helper(
|
||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
debug=is_debug,
|
debug=is_debug,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
ssl=self.ssl_context,
|
ssl=self.ssl_context,
|
||||||
run_async=True)
|
run_async=True)
|
||||||
self._server_settings['signal'] = self.signal
|
self._server_settings['signal'] = self.signal
|
||||||
self._server_settings.pop('sock')
|
self._server_settings.pop('sock')
|
||||||
trigger_events(self._server_settings.get('before_start', []),
|
trigger_events(self._server_settings.get('before_start', []),
|
||||||
|
self.loop)
|
||||||
|
self._server_settings['before_start'] = ()
|
||||||
|
|
||||||
|
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
|
||||||
|
try:
|
||||||
|
self.loop.run_until_complete(self._runner)
|
||||||
|
self.app.callable.is_running = True
|
||||||
|
trigger_events(self._server_settings.get('after_start', []),
|
||||||
self.loop)
|
self.loop)
|
||||||
self._server_settings['before_start'] = ()
|
self.loop.run_until_complete(self._check_alive())
|
||||||
|
trigger_events(self._server_settings.get('before_stop', []),
|
||||||
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
|
self.loop)
|
||||||
|
self.loop.run_until_complete(self.close())
|
||||||
|
except BaseException:
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
try:
|
try:
|
||||||
self.loop.run_until_complete(self._runner)
|
trigger_events(self._server_settings.get('after_stop', []),
|
||||||
self.app.callable.is_running = True
|
|
||||||
trigger_events(self._server_settings.get('after_start', []),
|
|
||||||
self.loop)
|
self.loop)
|
||||||
self.loop.run_until_complete(self._check_alive())
|
|
||||||
trigger_events(self._server_settings.get('before_stop', []),
|
|
||||||
self.loop)
|
|
||||||
self.loop.run_until_complete(self.close())
|
|
||||||
except BaseException:
|
except BaseException:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
try:
|
self.loop.close()
|
||||||
trigger_events(self._server_settings.get('after_stop', []),
|
|
||||||
self.loop)
|
|
||||||
except BaseException:
|
|
||||||
traceback.print_exc()
|
|
||||||
finally:
|
|
||||||
self.loop.close()
|
|
||||||
|
|
||||||
sys.exit(self.exit_code)
|
sys.exit(self.exit_code)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
if self.servers:
|
if self.servers:
|
||||||
# stop accepting connections
|
# stop accepting connections
|
||||||
self.log.info("Stopping server: %s, connections: %s",
|
self.log.info("Stopping server: %s, connections: %s",
|
||||||
self.pid, len(self.connections))
|
self.pid, len(self.connections))
|
||||||
for server in self.servers:
|
for server in self.servers:
|
||||||
server.close()
|
server.close()
|
||||||
await server.wait_closed()
|
await server.wait_closed()
|
||||||
self.servers.clear()
|
self.servers.clear()
|
||||||
|
|
||||||
# prepare connections for closing
|
# prepare connections for closing
|
||||||
self.signal.stopped = True
|
self.signal.stopped = True
|
||||||
for conn in self.connections:
|
for conn in self.connections:
|
||||||
conn.close_if_idle()
|
conn.close_if_idle()
|
||||||
|
|
||||||
# gracefully shutdown timeout
|
# gracefully shutdown timeout
|
||||||
start_shutdown = 0
|
start_shutdown = 0
|
||||||
graceful_shutdown_timeout = self.cfg.graceful_timeout
|
graceful_shutdown_timeout = self.cfg.graceful_timeout
|
||||||
while self.connections and \
|
while self.connections and \
|
||||||
(start_shutdown < graceful_shutdown_timeout):
|
(start_shutdown < graceful_shutdown_timeout):
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
start_shutdown = start_shutdown + 0.1
|
start_shutdown = start_shutdown + 0.1
|
||||||
|
|
||||||
# Force close non-idle connection after waiting for
|
# Force close non-idle connection after waiting for
|
||||||
# graceful_shutdown_timeout
|
# graceful_shutdown_timeout
|
||||||
coros = []
|
coros = []
|
||||||
for conn in self.connections:
|
for conn in self.connections:
|
||||||
if hasattr(conn, "websocket") and conn.websocket:
|
if hasattr(conn, "websocket") and conn.websocket:
|
||||||
coros.append(
|
coros.append(
|
||||||
conn.websocket.close_connection()
|
conn.websocket.close_connection()
|
||||||
)
|
|
||||||
else:
|
|
||||||
conn.close()
|
|
||||||
_shutdown = asyncio.gather(*coros, loop=self.loop)
|
|
||||||
await _shutdown
|
|
||||||
|
|
||||||
async def _run(self):
|
|
||||||
for sock in self.sockets:
|
|
||||||
state = dict(requests_count=0)
|
|
||||||
self._server_settings["host"] = None
|
|
||||||
self._server_settings["port"] = None
|
|
||||||
server = await serve(
|
|
||||||
sock=sock,
|
|
||||||
connections=self.connections,
|
|
||||||
state=state,
|
|
||||||
**self._server_settings
|
|
||||||
)
|
|
||||||
self.servers[server] = state
|
|
||||||
|
|
||||||
async def _check_alive(self):
|
|
||||||
# If our parent changed then we shut down.
|
|
||||||
pid = os.getpid()
|
|
||||||
try:
|
|
||||||
while self.alive:
|
|
||||||
self.notify()
|
|
||||||
|
|
||||||
req_count = sum(
|
|
||||||
srv['requests_count'] for srv in self.servers.values()
|
|
||||||
)
|
)
|
||||||
if self.max_requests and req_count > self.max_requests:
|
else:
|
||||||
self.alive = False
|
conn.close()
|
||||||
self.log.info(
|
_shutdown = asyncio.gather(*coros, loop=self.loop)
|
||||||
"Max requests exceeded, shutting down: %s",
|
await _shutdown
|
||||||
self
|
|
||||||
)
|
|
||||||
elif pid == os.getpid() and self.ppid != os.getppid():
|
|
||||||
self.alive = False
|
|
||||||
self.log.info(
|
|
||||||
"Parent changed, shutting down: %s",
|
|
||||||
self
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(1.0, loop=self.loop)
|
|
||||||
except (BaseException, GeneratorExit, KeyboardInterrupt):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
async def _run(self):
|
||||||
def _create_ssl_context(cfg):
|
for sock in self.sockets:
|
||||||
""" Creates SSLContext instance for usage in asyncio.create_server.
|
state = dict(requests_count=0)
|
||||||
See ssl.SSLSocket.__init__ for more details.
|
self._server_settings["host"] = None
|
||||||
"""
|
self._server_settings["port"] = None
|
||||||
ctx = ssl.SSLContext(cfg.ssl_version)
|
server = await serve(
|
||||||
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
|
sock=sock,
|
||||||
ctx.verify_mode = cfg.cert_reqs
|
connections=self.connections,
|
||||||
if cfg.ca_certs:
|
state=state,
|
||||||
ctx.load_verify_locations(cfg.ca_certs)
|
**self._server_settings
|
||||||
if cfg.ciphers:
|
)
|
||||||
ctx.set_ciphers(cfg.ciphers)
|
self.servers[server] = state
|
||||||
return ctx
|
|
||||||
|
|
||||||
def init_signals(self):
|
async def _check_alive(self):
|
||||||
# Set up signals through the event loop API.
|
# If our parent changed then we shut down.
|
||||||
|
pid = os.getpid()
|
||||||
|
try:
|
||||||
|
while self.alive:
|
||||||
|
self.notify()
|
||||||
|
|
||||||
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
|
req_count = sum(
|
||||||
signal.SIGQUIT, None)
|
self.servers[srv]["requests_count"] for srv in self.servers
|
||||||
|
)
|
||||||
|
if self.max_requests and req_count > self.max_requests:
|
||||||
|
self.alive = False
|
||||||
|
self.log.info("Max requests exceeded, shutting down: %s",
|
||||||
|
self)
|
||||||
|
elif pid == os.getpid() and self.ppid != os.getppid():
|
||||||
|
self.alive = False
|
||||||
|
self.log.info("Parent changed, shutting down: %s", self)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(1.0, loop=self.loop)
|
||||||
|
except (Exception, BaseException, GeneratorExit, KeyboardInterrupt):
|
||||||
|
pass
|
||||||
|
|
||||||
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
|
@staticmethod
|
||||||
signal.SIGTERM, None)
|
def _create_ssl_context(cfg):
|
||||||
|
""" Creates SSLContext instance for usage in asyncio.create_server.
|
||||||
|
See ssl.SSLSocket.__init__ for more details.
|
||||||
|
"""
|
||||||
|
ctx = ssl.SSLContext(cfg.ssl_version)
|
||||||
|
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
|
||||||
|
ctx.verify_mode = cfg.cert_reqs
|
||||||
|
if cfg.ca_certs:
|
||||||
|
ctx.load_verify_locations(cfg.ca_certs)
|
||||||
|
if cfg.ciphers:
|
||||||
|
ctx.set_ciphers(cfg.ciphers)
|
||||||
|
return ctx
|
||||||
|
|
||||||
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
|
def init_signals(self):
|
||||||
signal.SIGINT, None)
|
# Set up signals through the event loop API.
|
||||||
|
|
||||||
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
|
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
|
||||||
signal.SIGWINCH, None)
|
signal.SIGQUIT, None)
|
||||||
|
|
||||||
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
|
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
|
||||||
signal.SIGUSR1, None)
|
signal.SIGTERM, None)
|
||||||
|
|
||||||
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
|
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
|
||||||
signal.SIGABRT, None)
|
signal.SIGINT, None)
|
||||||
|
|
||||||
# Don't let SIGTERM and SIGUSR1 disturb active requests
|
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
|
||||||
# by interrupting system calls
|
signal.SIGWINCH, None)
|
||||||
signal.siginterrupt(signal.SIGTERM, False)
|
|
||||||
signal.siginterrupt(signal.SIGUSR1, False)
|
|
||||||
|
|
||||||
def handle_quit(self, sig, frame):
|
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
|
||||||
self.alive = False
|
signal.SIGUSR1, None)
|
||||||
self.app.callable.is_running = False
|
|
||||||
self.cfg.worker_int(self)
|
|
||||||
|
|
||||||
def handle_abort(self, sig, frame):
|
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
|
||||||
self.alive = False
|
signal.SIGABRT, None)
|
||||||
self.exit_code = 1
|
|
||||||
self.cfg.worker_abort(self)
|
# Don't let SIGTERM and SIGUSR1 disturb active requests
|
||||||
sys.exit(1)
|
# by interrupting system calls
|
||||||
|
signal.siginterrupt(signal.SIGTERM, False)
|
||||||
|
signal.siginterrupt(signal.SIGUSR1, False)
|
||||||
|
|
||||||
|
def handle_quit(self, sig, frame):
|
||||||
|
self.alive = False
|
||||||
|
self.app.callable.is_running = False
|
||||||
|
self.cfg.worker_int(self)
|
||||||
|
|
||||||
|
def handle_abort(self, sig, frame):
|
||||||
|
self.alive = False
|
||||||
|
self.exit_code = 1
|
||||||
|
self.cfg.worker_abort(self)
|
||||||
|
sys.exit(1)
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from sanic import Sanic
|
from sanic import Sanic
|
||||||
|
|
||||||
|
if sys.platform in ['win32', 'cygwin']:
|
||||||
|
collect_ignore = ["test_worker.py"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(request):
|
def app(request):
|
||||||
|
@ -7,10 +7,7 @@ from sanic.config import Config
|
|||||||
from sanic import server
|
from sanic import server
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import TCPConnector
|
from aiohttp import TCPConnector
|
||||||
from sanic.testing import SanicTestClient, HOST, PORT, version
|
from sanic.testing import SanicTestClient, HOST, PORT
|
||||||
|
|
||||||
|
|
||||||
aiohttp_version = version.parse(aiohttp.__version__)
|
|
||||||
|
|
||||||
|
|
||||||
class ReuseableTCPConnector(TCPConnector):
|
class ReuseableTCPConnector(TCPConnector):
|
||||||
@ -18,39 +15,16 @@ class ReuseableTCPConnector(TCPConnector):
|
|||||||
super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
|
super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
|
||||||
self.old_proto = None
|
self.old_proto = None
|
||||||
|
|
||||||
if aiohttp_version >= version.parse('3.3.0'):
|
async def connect(self, req, *args, **kwargs):
|
||||||
async def connect(self, req, traces, timeout):
|
new_conn = await super(ReuseableTCPConnector, self)\
|
||||||
new_conn = await super(ReuseableTCPConnector, self)\
|
.connect(req, *args, **kwargs)
|
||||||
.connect(req, traces, timeout)
|
if self.old_proto is not None:
|
||||||
if self.old_proto is not None:
|
if self.old_proto != new_conn._protocol:
|
||||||
if self.old_proto != new_conn._protocol:
|
raise RuntimeError(
|
||||||
raise RuntimeError(
|
"We got a new connection, wanted the same one!")
|
||||||
"We got a new connection, wanted the same one!")
|
print(new_conn.__dict__)
|
||||||
print(new_conn.__dict__)
|
self.old_proto = new_conn._protocol
|
||||||
self.old_proto = new_conn._protocol
|
return new_conn
|
||||||
return new_conn
|
|
||||||
elif aiohttp_version >= version.parse('3.0.0'):
|
|
||||||
async def connect(self, req, traces=None):
|
|
||||||
new_conn = await super(ReuseableTCPConnector, self)\
|
|
||||||
.connect(req, traces=traces)
|
|
||||||
if self.old_proto is not None:
|
|
||||||
if self.old_proto != new_conn._protocol:
|
|
||||||
raise RuntimeError(
|
|
||||||
"We got a new connection, wanted the same one!")
|
|
||||||
print(new_conn.__dict__)
|
|
||||||
self.old_proto = new_conn._protocol
|
|
||||||
return new_conn
|
|
||||||
else:
|
|
||||||
async def connect(self, req):
|
|
||||||
new_conn = await super(ReuseableTCPConnector, self)\
|
|
||||||
.connect(req)
|
|
||||||
if self.old_proto is not None:
|
|
||||||
if self.old_proto != new_conn._protocol:
|
|
||||||
raise RuntimeError(
|
|
||||||
"We got a new connection, wanted the same one!")
|
|
||||||
print(new_conn.__dict__)
|
|
||||||
self.old_proto = new_conn._protocol
|
|
||||||
return new_conn
|
|
||||||
|
|
||||||
|
|
||||||
class ReuseableSanicTestClient(SanicTestClient):
|
class ReuseableSanicTestClient(SanicTestClient):
|
||||||
@ -135,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient):
|
|||||||
try:
|
try:
|
||||||
request, response = results
|
request, response = results
|
||||||
return request, response
|
return request, response
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Request and response object expected, got ({})".format(
|
"Request and response object expected, got ({})".format(
|
||||||
results))
|
results))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return results[-1]
|
return results[-1]
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Request object expected, got ({})".format(results))
|
"Request object expected, got ({})".format(results))
|
||||||
|
|
||||||
|
@ -6,7 +6,23 @@ from sanic.response import text
|
|||||||
from sanic.config import Config
|
from sanic.config import Config
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import TCPConnector
|
from aiohttp import TCPConnector
|
||||||
from sanic.testing import SanicTestClient, HOST, version
|
from sanic.testing import SanicTestClient, HOST
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
# direct use
|
||||||
|
import packaging
|
||||||
|
version = packaging.version
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# setuptools v39.0 and above.
|
||||||
|
try:
|
||||||
|
from setuptools.extern import packaging
|
||||||
|
except ImportError:
|
||||||
|
# Before setuptools v39.0
|
||||||
|
from pkg_resources.extern import packaging
|
||||||
|
version = packaging.version
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError("The 'packaging' library is missing.")
|
||||||
|
|
||||||
|
|
||||||
aiohttp_version = version.parse(aiohttp.__version__)
|
aiohttp_version = version.parse(aiohttp.__version__)
|
||||||
@ -44,8 +60,7 @@ class DelayableTCPConnector(TCPConnector):
|
|||||||
if aiohttp_version >= version.parse("3.3.0"):
|
if aiohttp_version >= version.parse("3.3.0"):
|
||||||
ret = await self.orig_start(connection)
|
ret = await self.orig_start(connection)
|
||||||
else:
|
else:
|
||||||
ret = await self.orig_start(connection,
|
ret = await self.orig_start(connection, read_until_eof)
|
||||||
read_until_eof)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
return ret
|
return ret
|
||||||
@ -59,57 +74,43 @@ class DelayableTCPConnector(TCPConnector):
|
|||||||
async def delayed_send(self, *args, **kwargs):
|
async def delayed_send(self, *args, **kwargs):
|
||||||
req = self.req
|
req = self.req
|
||||||
if self.delay and self.delay > 0:
|
if self.delay and self.delay > 0:
|
||||||
#sync_sleep(self.delay)
|
# sync_sleep(self.delay)
|
||||||
await asyncio.sleep(self.delay)
|
await asyncio.sleep(self.delay)
|
||||||
t = req.loop.time()
|
t = req.loop.time()
|
||||||
print("sending at {}".format(t), flush=True)
|
print("sending at {}".format(t), flush=True)
|
||||||
conn = next(iter(args)) # first arg is connection
|
conn = next(iter(args)) # first arg is connection
|
||||||
|
|
||||||
if aiohttp_version >= version.parse("3.1.0"):
|
try:
|
||||||
try:
|
return await self.orig_send(*args, **kwargs)
|
||||||
delayed_resp = await self.orig_send(*args, **kwargs)
|
except Exception as e:
|
||||||
except Exception as e:
|
if aiohttp_version < version.parse("3.1.0"):
|
||||||
if aiohttp_version >= version.parse("3.3.0"):
|
|
||||||
return aiohttp.ClientResponse(req.method, req.url,
|
|
||||||
writer=None,
|
|
||||||
continue100=None,
|
|
||||||
timer=None,
|
|
||||||
request_info=None,
|
|
||||||
traces=[],
|
|
||||||
loop=req.loop,
|
|
||||||
session=None)
|
|
||||||
else:
|
|
||||||
return aiohttp.ClientResponse(req.method, req.url,
|
|
||||||
writer=None,
|
|
||||||
continue100=None,
|
|
||||||
timer=None,
|
|
||||||
request_info=None,
|
|
||||||
auto_decompress=None,
|
|
||||||
traces=[],
|
|
||||||
loop=req.loop,
|
|
||||||
session=None)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
delayed_resp = self.orig_send(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
return aiohttp.ClientResponse(req.method, req.url)
|
return aiohttp.ClientResponse(req.method, req.url)
|
||||||
return delayed_resp
|
kw = dict(
|
||||||
|
writer=None,
|
||||||
|
continue100=None,
|
||||||
|
timer=None,
|
||||||
|
request_info=None,
|
||||||
|
traces=[],
|
||||||
|
loop=req.loop,
|
||||||
|
session=None
|
||||||
|
)
|
||||||
|
if aiohttp_version < version.parse("3.3.0"):
|
||||||
|
kw['auto_decompress'] = None
|
||||||
|
return aiohttp.ClientResponse(req.method, req.url, **kw)
|
||||||
|
|
||||||
|
def _send(self, *args, **kwargs):
|
||||||
|
gen = self.delayed_send(*args, **kwargs)
|
||||||
|
task = self.req.loop.create_task(gen)
|
||||||
|
self.send_task = task
|
||||||
|
self._acting_as = task
|
||||||
|
return self
|
||||||
|
|
||||||
if aiohttp_version >= version.parse("3.1.0"):
|
if aiohttp_version >= version.parse("3.1.0"):
|
||||||
# aiohttp changed the request.send method to async
|
# aiohttp changed the request.send method to async
|
||||||
async def send(self, *args, **kwargs):
|
async def send(self, *args, **kwargs):
|
||||||
gen = self.delayed_send(*args, **kwargs)
|
return self._send(*args, **kwargs)
|
||||||
task = self.req.loop.create_task(gen)
|
|
||||||
self.send_task = task
|
|
||||||
self._acting_as = task
|
|
||||||
return self
|
|
||||||
else:
|
else:
|
||||||
def send(self, *args, **kwargs):
|
send = _send
|
||||||
gen = self.delayed_send(*args, **kwargs)
|
|
||||||
task = self.req.loop.create_task(gen)
|
|
||||||
self.send_task = task
|
|
||||||
self._acting_as = task
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
_post_connect_delay = kwargs.pop('post_connect_delay', 0)
|
_post_connect_delay = kwargs.pop('post_connect_delay', 0)
|
||||||
@ -118,45 +119,18 @@ class DelayableTCPConnector(TCPConnector):
|
|||||||
self._post_connect_delay = _post_connect_delay
|
self._post_connect_delay = _post_connect_delay
|
||||||
self._pre_request_delay = _pre_request_delay
|
self._pre_request_delay = _pre_request_delay
|
||||||
|
|
||||||
if aiohttp_version >= version.parse("3.3.0"):
|
async def connect(self, req, *args, **kwargs):
|
||||||
async def connect(self, req, traces, timeout):
|
d_req = DelayableTCPConnector.\
|
||||||
d_req = DelayableTCPConnector.\
|
RequestContextManager(req, self._pre_request_delay)
|
||||||
RequestContextManager(req, self._pre_request_delay)
|
conn = await super(DelayableTCPConnector, self).\
|
||||||
conn = await super(DelayableTCPConnector, self).\
|
connect(req, *args, **kwargs)
|
||||||
connect(req, traces, timeout)
|
if self._post_connect_delay and self._post_connect_delay > 0:
|
||||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
await asyncio.sleep(self._post_connect_delay,
|
||||||
await asyncio.sleep(self._post_connect_delay,
|
loop=self._loop)
|
||||||
loop=self._loop)
|
req.send = d_req.send
|
||||||
req.send = d_req.send
|
t = req.loop.time()
|
||||||
t = req.loop.time()
|
print("Connected at {}".format(t), flush=True)
|
||||||
print("Connected at {}".format(t), flush=True)
|
return conn
|
||||||
return conn
|
|
||||||
elif aiohttp_version >= version.parse("3.0.0"):
|
|
||||||
async def connect(self, req, traces=None):
|
|
||||||
d_req = DelayableTCPConnector.\
|
|
||||||
RequestContextManager(req, self._pre_request_delay)
|
|
||||||
conn = await super(DelayableTCPConnector, self).\
|
|
||||||
connect(req, traces=traces)
|
|
||||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
|
||||||
await asyncio.sleep(self._post_connect_delay,
|
|
||||||
loop=self._loop)
|
|
||||||
req.send = d_req.send
|
|
||||||
t = req.loop.time()
|
|
||||||
print("Connected at {}".format(t), flush=True)
|
|
||||||
return conn
|
|
||||||
else:
|
|
||||||
|
|
||||||
async def connect(self, req):
|
|
||||||
d_req = DelayableTCPConnector.\
|
|
||||||
RequestContextManager(req, self._pre_request_delay)
|
|
||||||
conn = await super(DelayableTCPConnector, self).connect(req)
|
|
||||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
|
||||||
await asyncio.sleep(self._post_connect_delay,
|
|
||||||
loop=self._loop)
|
|
||||||
req.send = d_req.send
|
|
||||||
t = req.loop.time()
|
|
||||||
print("Connected at {}".format(t), flush=True)
|
|
||||||
return conn
|
|
||||||
|
|
||||||
|
|
||||||
class DelayableSanicTestClient(SanicTestClient):
|
class DelayableSanicTestClient(SanicTestClient):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
@ -12,7 +13,7 @@ from sanic.response import (
|
|||||||
HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
|
HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
|
||||||
)
|
)
|
||||||
from sanic.server import HttpProtocol
|
from sanic.server import HttpProtocol
|
||||||
from sanic.testing import HOST, PORT, is_windows
|
from sanic.testing import HOST, PORT
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
JSON_DATA = {'ok': True}
|
JSON_DATA = {'ok': True}
|
||||||
@ -74,6 +75,7 @@ def test_response_header(app):
|
|||||||
'CONTENT-TYPE': 'application/json'
|
'CONTENT-TYPE': 'application/json'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
is_windows = sys.platform in ['win32', 'cygwin']
|
||||||
request, response = app.test_client.get('/')
|
request, response = app.test_client.get('/')
|
||||||
assert dict(response.headers) == {
|
assert dict(response.headers) == {
|
||||||
'Connection': 'keep-alive',
|
'Connection': 'keep-alive',
|
||||||
|
@ -4,17 +4,10 @@ import shlex
|
|||||||
import subprocess
|
import subprocess
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from sanic.worker import GunicornWorker
|
||||||
|
from sanic.app import Sanic
|
||||||
import asyncio
|
import asyncio
|
||||||
import pytest
|
import pytest
|
||||||
from sanic.app import Sanic
|
|
||||||
try:
|
|
||||||
from sanic.worker import GunicornWorker
|
|
||||||
except ImportError:
|
|
||||||
pytestmark = pytest.mark.skip(
|
|
||||||
reason="GunicornWorker Not supported on this platform"
|
|
||||||
)
|
|
||||||
# this has to be defined or pytest will err on import
|
|
||||||
GunicornWorker = object
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
@ -107,11 +100,12 @@ def test_run_max_requests_exceeded(worker):
|
|||||||
_runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
|
_runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
|
||||||
loop.run_until_complete(_runner)
|
loop.run_until_complete(_runner)
|
||||||
|
|
||||||
assert worker.alive == False
|
assert not worker.alive
|
||||||
worker.notify.assert_called_with()
|
worker.notify.assert_called_with()
|
||||||
worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
|
worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
|
||||||
worker)
|
worker)
|
||||||
|
|
||||||
|
|
||||||
def test_worker_close(worker):
|
def test_worker_close(worker):
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
|
asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
|
||||||
@ -124,8 +118,8 @@ def test_worker_close(worker):
|
|||||||
conn = mock.Mock()
|
conn = mock.Mock()
|
||||||
conn.websocket = mock.Mock()
|
conn.websocket = mock.Mock()
|
||||||
conn.websocket.close_connection = mock.Mock(
|
conn.websocket.close_connection = mock.Mock(
|
||||||
wraps=asyncio.coroutine(lambda *a, **kw: None)
|
wraps=asyncio.coroutine(lambda *a, **kw: None)
|
||||||
)
|
)
|
||||||
worker.connections = set([conn])
|
worker.connections = set([conn])
|
||||||
worker.log = mock.Mock()
|
worker.log = mock.Mock()
|
||||||
worker.loop = loop
|
worker.loop = loop
|
||||||
@ -141,6 +135,6 @@ def test_worker_close(worker):
|
|||||||
_close = asyncio.ensure_future(worker.close(), loop=loop)
|
_close = asyncio.ensure_future(worker.close(), loop=loop)
|
||||||
loop.run_until_complete(_close)
|
loop.run_until_complete(_close)
|
||||||
|
|
||||||
assert worker.signal.stopped == True
|
assert worker.signal.stopped
|
||||||
assert conn.websocket.close_connection.called == True
|
assert conn.websocket.close_connection.called
|
||||||
assert len(worker.servers) == 0
|
assert len(worker.servers) == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user