simplified aiohttp version diffs, reverted worker import policy

This commit is contained in:
Alec Buckenheimer 2018-10-08 22:40:36 -04:00
parent 7a9e100b0f
commit b7d74c82ba
7 changed files with 243 additions and 321 deletions

View File

@ -1,30 +1,12 @@
import sys
import traceback import traceback
from json import JSONDecodeError from json import JSONDecodeError
from sanic.log import logger from sanic.log import logger
from sanic.exceptions import MethodNotSupported from sanic.exceptions import MethodNotSupported
from sanic.response import text from sanic.response import text
try:
try:
# direct use
import packaging as _packaging
version = _packaging.version
except (ImportError, AttributeError):
# setuptools v39.0 and above.
try:
from setuptools.extern import packaging as _packaging
except ImportError:
# Before setuptools v39.0
from pkg_resources.extern import packaging as _packaging
version = _packaging.version
except ImportError:
raise RuntimeError("The 'packaging' library is missing.")
HOST = '127.0.0.1' HOST = '127.0.0.1'
PORT = 42101 PORT = 42101
is_windows = sys.platform in ['win32', 'cygwin']
class SanicTestClient: class SanicTestClient:

View File

@ -12,19 +12,16 @@ except ImportError:
try: try:
import uvloop import uvloop
import gunicorn.workers.base as base
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError: except ImportError:
base = None
pass pass
import gunicorn.workers.base as base
from sanic.server import trigger_events, serve, HttpProtocol, Signal from sanic.server import trigger_events, serve, HttpProtocol, Signal
from sanic.websocket import WebSocketProtocol from sanic.websocket import WebSocketProtocol
# gunicorn is not available on windows class GunicornWorker(base.Worker):
if base is not None:
class GunicornWorker(base.Worker):
http_protocol = HttpProtocol http_protocol = HttpProtocol
websocket_protocol = WebSocketProtocol websocket_protocol = WebSocketProtocol
@ -147,23 +144,18 @@ if base is not None:
self.notify() self.notify()
req_count = sum( req_count = sum(
srv['requests_count'] for srv in self.servers.values() self.servers[srv]["requests_count"] for srv in self.servers
) )
if self.max_requests and req_count > self.max_requests: if self.max_requests and req_count > self.max_requests:
self.alive = False self.alive = False
self.log.info( self.log.info("Max requests exceeded, shutting down: %s",
"Max requests exceeded, shutting down: %s", self)
self
)
elif pid == os.getpid() and self.ppid != os.getppid(): elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False self.alive = False
self.log.info( self.log.info("Parent changed, shutting down: %s", self)
"Parent changed, shutting down: %s",
self
)
else: else:
await asyncio.sleep(1.0, loop=self.loop) await asyncio.sleep(1.0, loop=self.loop)
except (BaseException, GeneratorExit, KeyboardInterrupt): except (Exception, BaseException, GeneratorExit, KeyboardInterrupt):
pass pass
@staticmethod @staticmethod

View File

@ -1,7 +1,11 @@
import sys
import pytest import pytest
from sanic import Sanic from sanic import Sanic
if sys.platform in ['win32', 'cygwin']:
collect_ignore = ["test_worker.py"]
@pytest.fixture @pytest.fixture
def app(request): def app(request):

View File

@ -7,10 +7,7 @@ from sanic.config import Config
from sanic import server from sanic import server
import aiohttp import aiohttp
from aiohttp import TCPConnector from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST, PORT, version from sanic.testing import SanicTestClient, HOST, PORT
aiohttp_version = version.parse(aiohttp.__version__)
class ReuseableTCPConnector(TCPConnector): class ReuseableTCPConnector(TCPConnector):
@ -18,32 +15,9 @@ class ReuseableTCPConnector(TCPConnector):
super(ReuseableTCPConnector, self).__init__(*args, **kwargs) super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
self.old_proto = None self.old_proto = None
if aiohttp_version >= version.parse('3.3.0'): async def connect(self, req, *args, **kwargs):
async def connect(self, req, traces, timeout):
new_conn = await super(ReuseableTCPConnector, self)\ new_conn = await super(ReuseableTCPConnector, self)\
.connect(req, traces, timeout) .connect(req, *args, **kwargs)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
elif aiohttp_version >= version.parse('3.0.0'):
async def connect(self, req, traces=None):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req, traces=traces)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
else:
async def connect(self, req):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req)
if self.old_proto is not None: if self.old_proto is not None:
if self.old_proto != new_conn._protocol: if self.old_proto != new_conn._protocol:
raise RuntimeError( raise RuntimeError(
@ -135,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient):
try: try:
request, response = results request, response = results
return request, response return request, response
except: except Exception:
raise ValueError( raise ValueError(
"Request and response object expected, got ({})".format( "Request and response object expected, got ({})".format(
results)) results))
else: else:
try: try:
return results[-1] return results[-1]
except: except Exception:
raise ValueError( raise ValueError(
"Request object expected, got ({})".format(results)) "Request object expected, got ({})".format(results))

View File

@ -6,7 +6,23 @@ from sanic.response import text
from sanic.config import Config from sanic.config import Config
import aiohttp import aiohttp
from aiohttp import TCPConnector from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST, version from sanic.testing import SanicTestClient, HOST
try:
try:
# direct use
import packaging
version = packaging.version
except (ImportError, AttributeError):
# setuptools v39.0 and above.
try:
from setuptools.extern import packaging
except ImportError:
# Before setuptools v39.0
from pkg_resources.extern import packaging
version = packaging.version
except ImportError:
raise RuntimeError("The 'packaging' library is missing.")
aiohttp_version = version.parse(aiohttp.__version__) aiohttp_version = version.parse(aiohttp.__version__)
@ -44,8 +60,7 @@ class DelayableTCPConnector(TCPConnector):
if aiohttp_version >= version.parse("3.3.0"): if aiohttp_version >= version.parse("3.3.0"):
ret = await self.orig_start(connection) ret = await self.orig_start(connection)
else: else:
ret = await self.orig_start(connection, ret = await self.orig_start(connection, read_until_eof)
read_until_eof)
except Exception as e: except Exception as e:
raise e raise e
return ret return ret
@ -59,57 +74,43 @@ class DelayableTCPConnector(TCPConnector):
async def delayed_send(self, *args, **kwargs): async def delayed_send(self, *args, **kwargs):
req = self.req req = self.req
if self.delay and self.delay > 0: if self.delay and self.delay > 0:
#sync_sleep(self.delay) # sync_sleep(self.delay)
await asyncio.sleep(self.delay) await asyncio.sleep(self.delay)
t = req.loop.time() t = req.loop.time()
print("sending at {}".format(t), flush=True) print("sending at {}".format(t), flush=True)
conn = next(iter(args)) # first arg is connection conn = next(iter(args)) # first arg is connection
if aiohttp_version >= version.parse("3.1.0"):
try: try:
delayed_resp = await self.orig_send(*args, **kwargs) return await self.orig_send(*args, **kwargs)
except Exception as e:
if aiohttp_version >= version.parse("3.3.0"):
return aiohttp.ClientResponse(req.method, req.url,
writer=None,
continue100=None,
timer=None,
request_info=None,
traces=[],
loop=req.loop,
session=None)
else:
return aiohttp.ClientResponse(req.method, req.url,
writer=None,
continue100=None,
timer=None,
request_info=None,
auto_decompress=None,
traces=[],
loop=req.loop,
session=None)
else:
try:
delayed_resp = self.orig_send(*args, **kwargs)
except Exception as e: except Exception as e:
if aiohttp_version < version.parse("3.1.0"):
return aiohttp.ClientResponse(req.method, req.url) return aiohttp.ClientResponse(req.method, req.url)
return delayed_resp kw = dict(
writer=None,
continue100=None,
timer=None,
request_info=None,
traces=[],
loop=req.loop,
session=None
)
if aiohttp_version < version.parse("3.3.0"):
kw['auto_decompress'] = None
return aiohttp.ClientResponse(req.method, req.url, **kw)
def _send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
if aiohttp_version >= version.parse("3.1.0"): if aiohttp_version >= version.parse("3.1.0"):
# aiohttp changed the request.send method to async # aiohttp changed the request.send method to async
async def send(self, *args, **kwargs): async def send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs) return self._send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
else: else:
def send(self, *args, **kwargs): send = _send
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
_post_connect_delay = kwargs.pop('post_connect_delay', 0) _post_connect_delay = kwargs.pop('post_connect_delay', 0)
@ -118,38 +119,11 @@ class DelayableTCPConnector(TCPConnector):
self._post_connect_delay = _post_connect_delay self._post_connect_delay = _post_connect_delay
self._pre_request_delay = _pre_request_delay self._pre_request_delay = _pre_request_delay
if aiohttp_version >= version.parse("3.3.0"): async def connect(self, req, *args, **kwargs):
async def connect(self, req, traces, timeout):
d_req = DelayableTCPConnector.\ d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay) RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).\ conn = await super(DelayableTCPConnector, self).\
connect(req, traces, timeout) connect(req, *args, **kwargs)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
elif aiohttp_version >= version.parse("3.0.0"):
async def connect(self, req, traces=None):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).\
connect(req, traces=traces)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
else:
async def connect(self, req):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).connect(req)
if self._post_connect_delay and self._post_connect_delay > 0: if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay, await asyncio.sleep(self._post_connect_delay,
loop=self._loop) loop=self._loop)

View File

@ -1,3 +1,4 @@
import sys
import asyncio import asyncio
import inspect import inspect
import os import os
@ -12,7 +13,7 @@ from sanic.response import (
HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
) )
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT, is_windows from sanic.testing import HOST, PORT
from unittest.mock import MagicMock from unittest.mock import MagicMock
JSON_DATA = {'ok': True} JSON_DATA = {'ok': True}
@ -74,6 +75,7 @@ def test_response_header(app):
'CONTENT-TYPE': 'application/json' 'CONTENT-TYPE': 'application/json'
}) })
is_windows = sys.platform in ['win32', 'cygwin']
request, response = app.test_client.get('/') request, response = app.test_client.get('/')
assert dict(response.headers) == { assert dict(response.headers) == {
'Connection': 'keep-alive', 'Connection': 'keep-alive',

View File

@ -4,17 +4,10 @@ import shlex
import subprocess import subprocess
import urllib.request import urllib.request
from unittest import mock from unittest import mock
from sanic.worker import GunicornWorker
from sanic.app import Sanic
import asyncio import asyncio
import pytest import pytest
from sanic.app import Sanic
try:
from sanic.worker import GunicornWorker
except ImportError:
pytestmark = pytest.mark.skip(
reason="GunicornWorker Not supported on this platform"
)
# this has to be defined or pytest will err on import
GunicornWorker = object
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
@ -107,11 +100,12 @@ def test_run_max_requests_exceeded(worker):
_runner = asyncio.ensure_future(worker._check_alive(), loop=loop) _runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
loop.run_until_complete(_runner) loop.run_until_complete(_runner)
assert worker.alive == False assert not worker.alive
worker.notify.assert_called_with() worker.notify.assert_called_with()
worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s", worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
worker) worker)
def test_worker_close(worker): def test_worker_close(worker):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None)) asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
@ -141,6 +135,6 @@ def test_worker_close(worker):
_close = asyncio.ensure_future(worker.close(), loop=loop) _close = asyncio.ensure_future(worker.close(), loop=loop)
loop.run_until_complete(_close) loop.run_until_complete(_close)
assert worker.signal.stopped == True assert worker.signal.stopped
assert conn.websocket.close_connection.called == True assert conn.websocket.close_connection.called
assert len(worker.servers) == 0 assert len(worker.servers) == 0